Data parallelization and batch prefetching
- class mlip.data.helpers.data_prefetching.ParallelGraphDataset(graph_dataset: GraphDataset, num_parallel: int)
A graph dataset that loads multiple batches in parallel.
- __init__(graph_dataset: GraphDataset, num_parallel: int)
Constructor.
- Parameters:
graph_dataset – The standard graph dataset to parallelize.
num_parallel – Number of parallel batches to process.
- __iter__()
The iterator for this parallel graph dataset.
- __len__()
Returns the number of batches in the underlying graph dataset.
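How `ParallelGraphDataset` combines batches internally is not described here. As a rough illustration of the general idea of handing out `num_parallel` batches per iteration step, the following minimal sketch groups consecutive items from any iterable. The helper name `group_batches` and the choice to drop an incomplete trailing group are assumptions made for this example and are not part of the mlip API.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def group_batches(batches: Iterable[T], num_parallel: int) -> Iterator[List[T]]:
    """Hypothetical helper: yield lists of `num_parallel` consecutive batches."""
    if num_parallel < 1:
        raise ValueError("num_parallel must be at least 1")
    it = iter(batches)
    while True:
        group = list(islice(it, num_parallel))
        if len(group) < num_parallel:
            # Assumption for this sketch: an incomplete trailing group is dropped.
            return
        yield group


# Integers stand in for real graph batches here.
groups = list(group_batches(range(7), num_parallel=3))
print(groups)
```

With seven stand-in batches and `num_parallel=3`, the sketch yields two full groups and discards the seventh item. The real class may treat leftover batches differently.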
- class mlip.data.helpers.data_prefetching.PrefetchIterator(iterable: Iterable, prefetch_count: int = 1, preprocess_fn: Callable | None = None)
A class to prefetch items from an iterable, with an option to preprocess each item.
- iterable
The original iterable.
- queue
A queue to hold the prefetched items.
- preprocess_fn
An optional function to preprocess each item.
- thread
The thread used for prefetching.
Example:
    from mlip.data.helpers.data_prefetching import PrefetchIterator

    def double(x):
        return x * 2

    it = PrefetchIterator(range(5), prefetch_count=2, preprocess_fn=double)
    for i in it:
        print(i)
    # Outputs: 0, 2, 4, 6, 8
- __init__(iterable: Iterable, prefetch_count: int = 1, preprocess_fn: Callable | None = None)
Constructor.
- Parameters:
iterable – The iterable to prefetch from.
prefetch_count – The maximum number of items to prefetch. Defaults to 1.
preprocess_fn – A function to preprocess each item. Should accept a single argument and return the processed item. Defaults to None.
- __iter__()
Iterates over the prefetched items. A new prefetching thread is started once iteration has completed.
- __len__()
Returns the length of the underlying iterable.
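
To make the prefetching idea concrete, here is a minimal standard-library sketch of the general technique: a background thread fills a bounded queue while the consumer reads from it. It is not the mlip implementation. The function name `prefetch` and the sentinel object are assumptions made for illustration, and only the parameter names mirror those documented above.

```python
import queue
import threading
from typing import Any, Callable, Iterable, Iterator, Optional

_SENTINEL = object()  # marks the end of the stream (assumption of this sketch)


def prefetch(
    iterable: Iterable[Any],
    prefetch_count: int = 1,
    preprocess_fn: Optional[Callable[[Any], Any]] = None,
) -> Iterator[Any]:
    """Hypothetical generator that prefetches items on a background thread."""
    # The bounded queue holds at most `prefetch_count` items ahead of the consumer.
    q: "queue.Queue[Any]" = queue.Queue(maxsize=prefetch_count)

    def worker() -> None:
        for item in iterable:
            # put() blocks while the queue is full, which limits the look-ahead.
            q.put(preprocess_fn(item) if preprocess_fn is not None else item)
        q.put(_SENTINEL)

    # Error handling is omitted: an exception in the worker would leave
    # the consumer blocked on q.get().
    threading.Thread(target=worker, daemon=True).start()

    while True:
        item = q.get()
        if item is _SENTINEL:
            return
        yield item


result = list(prefetch(range(5), prefetch_count=2, preprocess_fn=lambda x: x * 2))
print(result)
```

Because a single worker thread feeds a FIFO queue, items arrive in their original order. The queue's `maxsize` plays the role of `prefetch_count` by bounding how far ahead the worker can run.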