Data parallelization and batch prefetching

class mlip.data.helpers.data_prefetching.ParallelGraphDataset(graph_dataset: GraphDataset, num_parallel: int)

A graph dataset that loads multiple batches in parallel.

__init__(graph_dataset: GraphDataset, num_parallel: int)

Constructor.

Parameters:
  • graph_dataset – The standard graph dataset to parallelize.

  • num_parallel – Number of parallel batches to process.

__iter__()

The iterator for this parallel graph dataset.

__len__()

Returns the number of batches in the underlying graph dataset.
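To illustrate what loading several batches in parallel can look like, here is a minimal sketch. It is not mlip's implementation. The class `ParallelBatchLoaderSketch` and the toy `SlowBatches` dataset are hypothetical, and the sketch assumes batches can be fetched by index (`__getitem__`) rather than only by iteration. It keeps up to `num_parallel` loads in flight on a thread pool and yields batches in their original order, so `len()` matches the wrapped dataset, as the documented `__len__` does.

```python
import time
from collections import deque
from collections.abc import Sequence
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Iterator


class ParallelBatchLoaderSketch:
    """Hypothetical illustration of parallel batch loading (not mlip's ParallelGraphDataset)."""

    def __init__(self, batches: Sequence, num_parallel: int):
        if num_parallel < 1:
            raise ValueError("num_parallel must be at least 1")
        self.batches = batches  # assumed to support len() and indexing
        self.num_parallel = num_parallel

    def __len__(self) -> int:
        # Same number of batches as the wrapped dataset.
        return len(self.batches)

    def __iter__(self) -> Iterator[Any]:
        with ThreadPoolExecutor(max_workers=self.num_parallel) as pool:
            pending: "deque[Future]" = deque()
            next_index = 0
            while next_index < len(self.batches) or pending:
                # Keep up to num_parallel batch loads in flight at once.
                while next_index < len(self.batches) and len(pending) < self.num_parallel:
                    pending.append(pool.submit(self.batches.__getitem__, next_index))
                    next_index += 1
                # Yield in the original order, blocking on the oldest load if needed.
                yield pending.popleft().result()


class SlowBatches(Sequence):
    """Toy stand-in for a graph dataset: each lookup simulates a slow batch load."""

    def __init__(self, num_batches: int):
        self.num_batches = num_batches

    def __len__(self) -> int:
        return self.num_batches

    def __getitem__(self, index: int) -> list:
        if not 0 <= index < self.num_batches:
            raise IndexError(index)
        time.sleep(0.05)  # pretend to read and collate graphs
        return [index, index + 1]  # placeholder "batch"


loader = ParallelBatchLoaderSketch(SlowBatches(6), num_parallel=3)
for batch in loader:
    print(batch)
```

Threads pay off when batch loading is I/O bound or spends its time in code that releases the GIL; CPU-heavy pure-Python preprocessing would need processes instead.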

class mlip.data.helpers.data_prefetching.PrefetchIterator(iterable: Iterable, prefetch_count: int = 1, preprocess_fn: Callable | None = None)

A class to prefetch items from an iterable, with an option to preprocess each item.

iterable

The original iterable.

queue

A queue to hold the prefetched items.

preprocess_fn

An optional function to preprocess each item.

thread

The thread used for prefetching.

Example:

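from mlip.data.helpers.data_prefetching import PrefetchIterator

def double(x):
    return x * 2

# Each item is passed through `double`; up to 2 items are prefetched ahead of the loop.
it = PrefetchIterator(range(5), prefetch_count=2, preprocess_fn=double)
for i in it:
    print(i)
# Prints 0, 2, 4, 6 and 8, one per line.

__init__(iterable: Iterable, prefetch_count: int = 1, preprocess_fn: Callable | None = None)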

Constructor.

Parameters:
  • iterable – The iterable to prefetch from.

  • prefetch_count – The maximum number of items to prefetch. Defaults to 1.

  • preprocess_fn – A function to preprocess each item. Should accept a single argument and return the processed item. Defaults to None.
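Because `preprocess_fn` receives only the item, a transform that needs extra arguments can be adapted with `functools.partial`. The names `scale` and `triple` below are hypothetical, chosen only for illustration.

```python
from functools import partial


def scale(x, factor):
    """A two-argument transform that does not fit the single-argument contract directly."""
    return x * factor


# Bind every argument except the item itself, leaving a one-argument callable.
triple = partial(scale, factor=3)

print([triple(x) for x in range(4)])  # [0, 3, 6, 9]
```

Passing `preprocess_fn=triple` to `PrefetchIterator` would then transform each item in the same way that `double` does in the example above.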

__iter__()

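Implementation of the iterator. It yields prefetched items until the underlying iterable is exhausted; once iteration has completed, it starts a new prefetching thread so that the iterable can be traversed again.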

__len__()

Returns the length of the underlying iterable.
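The attributes listed above (`iterable`, `queue`, `preprocess_fn`, `thread`) suggest a producer/consumer design. The following is a minimal sketch under that assumption, not the library's source; `PrefetchIteratorSketch` is a hypothetical name. A daemon thread pulls items from the iterable, applies `preprocess_fn`, and puts them on a `queue.Queue` bounded by `prefetch_count`, so the producer blocks once that many items are waiting. A unique sentinel object marks the end of a pass, and after each completed pass `__iter__` starts a fresh producer thread, which is one reading of the "starts a new thread" behavior described for `__iter__`.

```python
import queue
import threading
from typing import Any, Callable, Iterable, Iterator, Optional

# Unique end-of-pass marker; unlike None, it cannot collide with a real item.
_SENTINEL = object()


class PrefetchIteratorSketch:
    """Hypothetical illustration of a thread-based prefetching iterator."""

    def __init__(
        self,
        iterable: Iterable,
        prefetch_count: int = 1,
        preprocess_fn: Optional[Callable[[Any], Any]] = None,
    ):
        self.iterable = iterable
        self.preprocess_fn = preprocess_fn
        # Bounded queue: put() blocks once prefetch_count items are waiting.
        self.queue: "queue.Queue[Any]" = queue.Queue(maxsize=prefetch_count)
        self.thread = self._start_producer()

    def _produce(self) -> None:
        # Runs in the background thread: read, optionally preprocess, enqueue.
        for item in self.iterable:
            if self.preprocess_fn is not None:
                item = self.preprocess_fn(item)
            self.queue.put(item)
        self.queue.put(_SENTINEL)

    def _start_producer(self) -> threading.Thread:
        thread = threading.Thread(target=self._produce, daemon=True)
        thread.start()
        return thread

    def __iter__(self) -> Iterator[Any]:
        while True:
            item = self.queue.get()
            if item is _SENTINEL:
                break
            yield item
        # The pass is complete: start a new producer so the next loop sees the data again.
        self.thread.join()
        self.thread = self._start_producer()

    def __len__(self) -> int:
        return len(self.iterable)


it = PrefetchIteratorSketch(range(5), prefetch_count=2, preprocess_fn=lambda x: x * 2)
print(list(it))  # [0, 2, 4, 6, 8]
print(list(it))  # [0, 2, 4, 6, 8] again, served by a freshly started thread
```

This sketch assumes a re-iterable source such as a list or `range`; a one-shot generator would be empty on the second pass. It also omits two things a production version would want: forwarding exceptions raised in the producer thread to the consumer, and stopping the producer when the consumer breaks out of a loop early.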