Coverage for instanovo/utils/data_handler.py: 61%
798 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
1from __future__ import annotations
3import asyncio
4import glob
5import os
6import random
7import re
8import shutil
9import tempfile
10import time
11import uuid
12from concurrent.futures import ThreadPoolExecutor
13from enum import Enum
14from itertools import chain
15from pathlib import Path
16from typing import TYPE_CHECKING, Any, Callable, Iterator, Union, cast
18import numpy as np
19import pandas as pd
20import polars as pl
21from datasets import Dataset, Features, Sequence, Value, VerificationMode, load_dataset
22from matchms import Spectrum
23from matchms.exporting import save_as_mgf
24from matchms.importing import load_from_mgf
26from instanovo.constants import (
27 ANNOTATED_COLUMN,
28 ANNOTATION_ERROR,
29 MS_TYPES,
30 PROTON_MASS_AMU,
31 MSColumns,
32)
34if TYPE_CHECKING:
35 from instanovo.utils.metrics import Metrics
36from instanovo.__init__ import console
37from instanovo.utils.colorlogging import ColorLog
38from instanovo.utils.msreader import read_mzml, read_mzxml
40logger = ColorLog(console, __name__).logger
43class SpectrumDataFrame:
44 """Spectra data class.
46 A data class for interacting with mass spectra data, including loading, processing,
47 and saving mass spectrometry data in various formats such as Parquet, MGF, and others.
49 Supports lazy loading, shuffling, and handling of large datasets by processing them
50 in chunks.
52 Attributes:
53 is_annotated: Whether the dataset is annotated with peptide sequences.
54 has_predictions: Whether the dataset contains predictions.
55 is_lazy: Whether lazy loading mode is used.
56 """
58 def __init__(
59 self,
60 df: pl.DataFrame | None = None,
61 file_paths: str | list[str] | None = None,
62 is_annotated: bool = False,
63 has_predictions: bool = False,
64 is_lazy: bool = False,
65 shuffle: bool = False,
66 preshuffle_across_shards: bool = True,
67 max_shard_size: int = 100_000,
68 custom_load_fn: Callable | None = None,
69 column_mapping: dict[str, str] | None = None,
70 add_source_file_column: bool = False,
71 preprocess_fn: Callable | None = None,
72 force_convert_to_native: bool = False,
73 verbose: bool = False,
74 add_spectrum_id: bool = False,
75 force_spectrum_id: bool = False,
76 ) -> None:
77 """Initialize SpectrumDataFrame.
79 Args:
80 df (pl.DataFrame | None): In-memory polars DataFrame with mass spectra data.
81 file_paths (str | list[str] | None): Path(s) to the input data files.
82 custom_load_fn (Callable | None): Custom function for loading data files.
83 is_annotated (bool): Whether the dataset is annotated.
84 has_predictions (bool): Whether predictions are present in the dataset.
85 is_lazy (bool): Whether to use lazy loading mode.
86 shuffle (bool): Whether to shuffle the dataset.
87 max_shard_size (int): Maximum size of data shards for chunking large datasets.
88 add_source_file_column (bool): Add source file column to the data.
89 preprocess_fn (Callable | None): Preprocess function for the data on load.
90 force_convert_to_native (bool): Force conversion to native format.
91 verbose (bool): Whether to print verbose output.
92 add_spectrum_id (bool): Add spectrum id column to the data.
93 force_spectrum_id (bool): Force addition of spectrum id column to the data.
95 Raises:
96 ValueError: If neither `df` nor `file_paths` is specified, or both are given.
97 FileNotFoundError: If no files are found matching the given `file_paths`.
98 """
99 self._is_annotated: bool = is_annotated
100 self._has_predictions: bool = has_predictions # if we have outputs/predictions
101 self._is_lazy: bool = is_lazy # or streaming
102 self._shuffle: bool = shuffle
103 self._max_shard_size = max_shard_size
104 self._custom_load_fn = custom_load_fn
105 self._add_source_file_column = add_source_file_column
106 self._verbose = verbose
107 self._add_spectrum_id = add_spectrum_id
108 self._force_spectrum_id = force_spectrum_id
109 self.executor = None
110 self._temp_directory = None
111 # String representation values:
112 self.max_items_per_column = 3
113 self.max_colname_length = 20
114 self.preprocess_fn = preprocess_fn
115 self.df: pl.DataFrame | None = None
117 if df is None and file_paths is None:
118 raise ValueError("Must specify either df or file_paths, both are None.")
120 # native refers to data being stored as a list of parquet files.
121 self._is_native = file_paths is not None
122 if df is not None and self._is_native:
123 raise ValueError("Must specify either df or file_paths, not both.")
125 if self._is_native:
126 # Get all file paths
127 self._file_paths = SpectrumDataFrame._convert_file_paths(cast(str | list[str], file_paths))
129 if len(self._file_paths) == 0:
130 raise FileNotFoundError(f"No files matching '{file_paths}' were found.")
132 # If any of the files are not .parquet, create a tempdir with the converted files.
133 if not all(fp.lower().endswith(".parquet") for fp in self._file_paths) or force_convert_to_native:
134 # If lazy make tempdir, if not convert to non-native and load all contents into df
135 # Only iterate over non-parquet files
136 df_iterator = SpectrumDataFrame.get_data_shards(
137 self._file_paths,
138 max_shard_size=self._max_shard_size,
139 custom_load_fn=custom_load_fn,
140 column_mapping=column_mapping,
141 add_source_file_column=add_source_file_column,
142 force_convert_to_native=force_convert_to_native,
143 add_spectrum_id_column=add_spectrum_id,
144 verbose=verbose,
145 )
147 if self._is_lazy:
148 self._temp_directory = tempfile.mkdtemp()
150 new_file_paths = [fp for fp in self._file_paths if (fp.lower().endswith(".parquet") and not force_convert_to_native)]
151 for temp_df in df_iterator:
152 # TODO: better way to generate id than hash?
153 temp_parquet_path = os.path.join(self._temp_directory, f"temp_{uuid.uuid4().hex}.parquet")
154 temp_df.write_parquet(temp_parquet_path)
155 new_file_paths.append(temp_parquet_path)
156 self._log(f"Saving temporary file to {temp_parquet_path}")
157 self._file_paths = new_file_paths
158 else:
159 self.df = None
160 for temp_df in df_iterator:
161 temp_df = SpectrumDataFrame._map_columns(temp_df, column_mapping=column_mapping)
162 temp_df = SpectrumDataFrame._cast_columns(temp_df)
163 if self.df is None:
164 self.df = temp_df
165 else:
166 self.df = SpectrumDataFrame._concat_dataframes(self.df, temp_df)
168 # Ensure parquet files are re-added
169 for fp in self._file_paths:
170 if not fp.lower().endswith(".parquet") or force_convert_to_native:
171 continue
172 temp_df = SpectrumDataFrame._map_columns(pl.read_parquet(fp), column_mapping=column_mapping)
173 temp_df = SpectrumDataFrame._cast_columns(temp_df)
174 temp_df = SpectrumDataFrame._ensure_experiment_name(
175 temp_df,
176 fp,
177 add_source=add_source_file_column,
178 force_source=True,
179 add_spectrum_id=add_spectrum_id,
180 force_spectrum_id=force_spectrum_id,
181 )
182 self.df = SpectrumDataFrame._concat_dataframes(self.df, temp_df)
184 # Native is disabled if not lazy
185 self._is_native = False
186 self._file_paths = []
187 elif not self._is_lazy:
188 # Loaded native, convert to lazy
189 self.df = None
190 for fp in self._file_paths:
191 temp_df = SpectrumDataFrame._map_columns(pl.read_parquet(fp), column_mapping=column_mapping)
192 temp_df = SpectrumDataFrame._cast_columns(temp_df)
193 temp_df = SpectrumDataFrame._ensure_experiment_name(
194 temp_df,
195 fp,
196 add_source=add_source_file_column,
197 force_source=True,
198 add_spectrum_id=add_spectrum_id,
199 force_spectrum_id=force_spectrum_id,
200 )
201 if self.df is None:
202 self.df = temp_df
203 else:
204 self.df = SpectrumDataFrame._concat_dataframes(self.df, temp_df)
206 # Native is disabled if not lazy
207 self._is_native = False
208 self._file_paths = []
210 # Create filter series for native mode
211 if self._file_paths is not None:
212 with ThreadPoolExecutor() as executor:
213 # Process files in parallel
214 self._filter_series_per_file = dict(executor.map(SpectrumDataFrame._create_filter_series, self._file_paths))
215 else:
216 self.df = df
218 # Check all columns
219 self._log("Verifying loaded data")
220 self._check_type_spec()
221 self._reset_current_file()
223 if preprocess_fn is not None and not self._is_lazy:
224 self._log("Preprocessing data in memory.")
225 self.df = preprocess_fn(self.df)
227 if self._shuffle:
228 if self._is_native:
229 if preshuffle_across_shards:
230 self._preshuffle_files()
231 self._shuffle_file_order()
232 else:
233 self.df = SpectrumDataFrame._shuffle_df(self.df)
234 elif self._is_native:
235 # Sort files alphabetically
236 # TODO: Do we want this? Keeps consistent when loading files across devices
237 self._file_paths.sort()
238 self._update_file_indices()
240 if self._is_lazy:
241 # When lazy loading, use async loaders to keep next file ready at all times.
242 self.executor = ThreadPoolExecutor(max_workers=1)
243 self.loop = asyncio.get_event_loop()
244 self._next_file: str | None = None # Future file name
245 self._next_file_future: pl.DataFrame | None = None # Future file data
246 self._preload_task: asyncio.Task | None = None
248 @staticmethod
249 def _create_filter_series(file_path: str) -> tuple[str, pl.Series]:
250 # Use lazy evaluation to get height without loading full data
251 height = pl.scan_parquet(file_path).select(pl.len()).collect().item()
252 return file_path, pl.Series(np.ones(height, dtype=bool))
254 @staticmethod
255 def _shuffle_df(df: pl.DataFrame) -> pl.DataFrame:
256 """Shuffle the rows of the given DataFrame."""
257 shuffled_indices = np.random.permutation(len(df))
258 return df[shuffled_indices]
259 # return df.with_row_count("row_nr").select(pl.all().shuffle())
261 @staticmethod
262 def _sanitise_peptide(peptide: str) -> str:
263 """Sanitise peptide sequence."""
264 # Some datasets save sequence wrapped with _ or .
265 if peptide is None:
266 return None
267 if peptide[0] == "_" and peptide[-1] == "_":
268 peptide = peptide[1:-1]
269 if peptide[0] == "." and peptide[-1] == ".":
270 peptide = peptide[1:-1]
271 return peptide
273 @staticmethod
274 def _ensure_experiment_name(
275 df: pl.DataFrame,
276 file_path: str,
277 add_source: bool = False,
278 force_source: bool = False,
279 add_spectrum_id: bool = False,
280 force_spectrum_id: bool = False,
281 ) -> pl.DataFrame:
282 """Ensure experiment_name is a column in the df."""
283 if "experiment_name" not in df.columns:
284 exp_name = Path(file_path).stem
285 df = df.with_columns(pl.lit(exp_name).alias("experiment_name").cast(pl.Utf8))
286 if add_source and (("source_file" not in df.columns) or force_source):
287 df = df.with_columns(pl.lit(file_path).alias("source_file").cast(pl.Utf8))
288 if add_spectrum_id and (("spectrum_id" not in df.columns) or force_spectrum_id):
289 if "scan_number" not in df.columns:
290 logger.warning("Scan number column not found. Creating row-based idx column for spectrum_id.")
291 df = df.with_row_index("idx")
292 df = df.with_columns((pl.col("experiment_name") + ":" + pl.col("idx").cast(pl.Utf8)).alias("spectrum_id"))
293 df = df.drop("idx")
294 else:
295 df = df.with_columns((pl.col("experiment_name") + ":" + pl.col("scan_number").cast(pl.Utf8)).alias("spectrum_id"))
296 return df
298 @staticmethod
299 def _map_columns(df: pl.DataFrame, column_mapping: dict[str, str] | None = None) -> pl.DataFrame:
300 """Map the columns of the DataFrame to the appropriate data types based on MS_TYPES."""
301 if column_mapping is None:
302 return df
303 return df.rename({k: v for k, v in column_mapping.items() if k in df.columns})
305 @staticmethod
306 def _cast_columns(df: pl.DataFrame) -> pl.DataFrame:
307 """Cast the columns of the DataFrame to the appropriate data types based on MS_TYPES."""
308 return df.with_columns([pl.col(column.value).cast(dtype) for column, dtype in MS_TYPES.items() if column.value in df.columns])
310 @staticmethod
311 def _is_glob(path: str) -> bool:
312 return "*" in path or "?" in path or "[" in path
314 @staticmethod
315 def _convert_file_paths(file_paths: str | list[str]) -> list[str]:
316 """Convert a string or list of file paths to a list of file paths.
318 Args:
319 file_paths (str | list[str]): File path or list of file paths.
321 Returns:
322 list[str]: A list of resolved file paths.
324 Raises:
325 ValueError: If input is a directory or not a valid file path.
326 """
327 if isinstance(file_paths, str):
328 if os.path.isdir(file_paths):
329 raise ValueError("Input must be a string (filepath or glob) or a list of file paths. Found directory.")
330 if SpectrumDataFrame._is_glob(file_paths):
331 # Glob notation: path/to/data/*.parquet
332 return glob.glob(file_paths)
333 else:
334 # Single file
335 return [file_paths]
336 elif not isinstance(file_paths, list):
337 ValueError("Input must be a string (filepath or glob) or a list of file paths.")
339 # Expand if list of globs
340 file_paths = list(chain.from_iterable([glob.glob(path) if SpectrumDataFrame._is_glob(path) else [path] for path in file_paths]))
342 return file_paths
344 @staticmethod
345 def _concat_dataframes(df1: pl.DataFrame, df2: pl.DataFrame) -> pl.DataFrame:
346 df1_columns = set(df1.columns)
347 df2_columns = set(df2.columns)
349 # Find missing columns in both DataFrames
350 missing_in_df1 = df2_columns - df1_columns
351 missing_in_df2 = df1_columns - df2_columns
353 # Add missing columns to df1 with None values
354 for col in missing_in_df1:
355 df1 = df1.with_columns(pl.lit(None).cast(df2[col].dtype).alias(col))
357 # Add missing columns to df2 with None values
358 for col in missing_in_df2:
359 df2 = df2.with_columns(pl.lit(None).cast(df1[col].dtype).alias(col))
361 # Rearrange df2 to have the same order as df1
362 df2 = df2.select(df1.columns)
364 return pl.concat([df1, df2], how="vertical_relaxed")
366 @staticmethod
367 def get_data_shards(
368 file_paths: list[str],
369 custom_load_fn: Callable | None = None,
370 column_mapping: dict[str, str] | None = None,
371 max_shard_size: int = 100_000,
372 add_source_file_column: bool = False,
373 force_convert_to_native: bool = False,
374 add_spectrum_id_column: bool = False,
375 verbose: bool = False,
376 ) -> Iterator[pl.DataFrame]:
377 """Load data files into DataFrames one at a time to save memory.
379 Args:
380 file_paths (list[str]): List of file paths to be loaded.
381 custom_load_fn (Callable | None): Custom function to load the files.
382 max_shard_size (int): Maximum size of data shards.
383 verbose (bool): Whether to using logger
384 force_convert_to_native (bool): Force conversion to native format.
385 add_spectrum_id_column (bool): Whether to add spectrum id column.
387 Yields:
388 Iterator[pl.DataFrame]: DataFrames containing mass spectra data.
389 """
390 column_mapping = column_mapping or {}
391 current_shard = None
392 for i, fp in enumerate(file_paths, 1):
393 if verbose:
394 logger.info(f"Loading file {i:03,d} of {len(file_paths):03,d}: {fp}")
396 if fp.endswith(".parquet") and not force_convert_to_native:
397 continue
399 if custom_load_fn is not None:
400 df = custom_load_fn(fp)
401 else:
402 df = SpectrumDataFrame._df_from_any(fp)
404 if df is None:
405 logger.warning(f"Unknown filetype of {fp}. Skipping.")
406 continue
408 df = SpectrumDataFrame._ensure_experiment_name(
409 df, file_path=fp, add_source=add_source_file_column, add_spectrum_id=add_spectrum_id_column
410 )
412 df = SpectrumDataFrame._map_columns(df, column_mapping=column_mapping)
413 df = SpectrumDataFrame._cast_columns(df)
415 # If df > shard_size, split it up first
416 while len(df) > max_shard_size:
417 yield df[:max_shard_size]
418 df = df[max_shard_size:]
420 # Assumes df < shard_size
421 if current_shard is None:
422 current_shard = df
423 elif len(current_shard) + len(df) < max_shard_size:
424 current_shard = SpectrumDataFrame._concat_dataframes(current_shard, df)
425 else:
426 yield SpectrumDataFrame._concat_dataframes(current_shard, df[: (max_shard_size - len(current_shard))])
427 current_shard = df[(max_shard_size - len(current_shard)) :]
428 yield current_shard
430 def check_values(self, min_value: float, max_value: float, column_name: str) -> bool:
431 """Check the values of the DataFrame to ensure they are within the specified range.
433 Args:
434 min_value (float): The minimum value.
435 max_value (float): The maximum value.
436 column_name (str): The name of the column to check.
438 Returns:
439 bool: True if all values are within the range, False otherwise.
440 """
441 min_value = min_value or -np.inf
442 max_value = max_value or np.inf
443 if self._is_native:
444 for fp in self._file_paths:
445 out_of_range = (
446 pl.scan_parquet(fp).filter((pl.col(column_name) < min_value) | (pl.col(column_name) > max_value)).select(pl.len()).collect()
447 )
448 if out_of_range[0, 0] > 0:
449 return False
450 else:
451 assert self.df is not None
452 out_of_range = self.df.filter((pl.col(column_name) < min_value) | (pl.col(column_name) > max_value))
453 if out_of_range.height > 0:
454 return False
455 return True
457 def filter_rows(self, filter_fn: Callable) -> None:
458 """Apply a filter function to rows of the DataFrame.
460 Args:
461 filter_fn (Callable): Function used to filter rows.
462 """
463 if self._is_native:
464 for fp in self._file_paths:
465 df = pl.scan_parquet(fp).collect()
466 new_filter = df.select(
467 [
468 pl.struct(df.columns).map_elements(filter_fn, return_dtype=bool).alias("result"),
469 ]
470 )["result"]
471 self._filter_series_per_file[fp] &= new_filter
473 self._reset_current_file()
474 if not self._shuffle:
475 self._update_file_indices()
476 else:
477 assert self.df is not None
478 new_filter = self.df.select(
479 [
480 pl.struct(self.df.columns).map_elements(filter_fn, return_dtype=bool).alias("result"),
481 ]
482 )["result"]
483 self.df = self.df.filter(new_filter)
485 def _log(self, text: str) -> None:
486 if self._verbose:
487 logger.info(text)
489 def reset_filter(self) -> None:
490 """Reset the filters applied to the DataFrame."""
491 if not self._is_native:
492 raise NotImplementedError("Filter reset is not supported in non-native mode.")
493 self._filter_series_per_file = {fp: pl.Series(np.full(pl.scan_parquet(fp).collect().height, True, dtype=bool)) for fp in self._file_paths}
494 self._reset_current_file()
495 if not self._shuffle:
496 self._update_file_indices()
498 def _update_file_indices(self) -> None:
499 """Update mapping from index to file in native non-shuffle mode."""
500 if self._shuffle:
501 raise ValueError("Cannot use file indexing in shuffle mode.")
503 self._index_to_file_index = pl.concat(
504 [pl.Series(np.full(self._filter_series_per_file[fp].sum(), i, dtype=int)) for i, fp in enumerate(self._file_paths)]
505 )
507 cumulative_position = 0
508 self._file_begin_index: dict[str, int] = {}
509 for fp in self._file_paths:
510 self._file_begin_index[fp] = cumulative_position
511 height = self._filter_series_per_file[fp].sum()
512 cumulative_position += height
514 def _preshuffle_files(self) -> None:
515 """Shuffle across all files."""
516 if not self._is_native:
517 return
518 num_files = len(self._file_paths)
519 if num_files <= 1:
520 return
522 self._log(f"Pre-shuffling across {num_files:03,d} shards. This may take a while...")
524 self._log("Computing new mapping per original shard")
525 index_to_file_index = pl.concat(
526 [pl.Series(np.full(self._filter_series_per_file[fp].sum(), i, dtype=int)) for i, fp in enumerate(self._file_paths)]
527 )
529 # To ensure consistent shard sizes, we sample based on index permutations
530 index_to_file_index = pl.Series(np.random.permutation(index_to_file_index.to_numpy()))
532 offset = 0
533 mapping_per_file = {}
534 for fp in self._file_paths:
535 height = len(self._filter_series_per_file[fp])
536 mapping_per_file[fp] = index_to_file_index[offset : offset + height]
537 offset += height
539 if self._temp_directory is None:
540 self._temp_directory = tempfile.mkdtemp()
542 self._log("Extracting rows to create shuffled shards")
543 new_file_paths = []
544 start = time.time()
545 for i in range(num_files):
546 df = None
547 for fp in self._file_paths:
548 temp_df = pl.scan_parquet(fp).filter(mapping_per_file[fp] == i).collect()
549 if df is None:
550 df = temp_df
551 else:
552 df = SpectrumDataFrame._concat_dataframes(df, temp_df)
554 if df is None:
555 raise ValueError("No data in shard during reshuffle.")
557 temp_parquet_path = os.path.join(str(self._temp_directory), f"temp_{uuid.uuid4().hex}.parquet")
558 df.write_parquet(temp_parquet_path)
559 new_file_paths.append(temp_parquet_path)
561 delta = time.time() - start
562 est_total = delta / (i + 1) * (num_files - i - 1)
563 self._log(
564 f"Writing shuffled shard {i:03,d}/{num_files:03,d} to {temp_parquet_path} "
565 f"[{_format_time(delta)}/{_format_time(est_total)}, {(delta / (i + 1)):.3f}s/it]"
566 )
568 self._log("Removing unshuffled shards")
569 # Remove old temp files:
570 for fp in self._file_paths:
571 if os.path.commonpath([str(self._temp_directory), fp]) == self._temp_directory:
572 try:
573 os.remove(fp)
574 except OSError as e:
575 self._log(f"Error deleting temporary file {fp}: {e}")
577 self._file_paths = new_file_paths
578 self._filter_series_per_file = {fp: pl.Series(np.full(pl.scan_parquet(fp).collect().height, True, dtype=bool)) for fp in self._file_paths}
579 self._log("Pre-shuffle complete")
581 def _reset_current_file(self) -> None:
582 # Shuffled file handling, uses a two-step shuffle to optimise efficiency
583 self._current_index_in_file = 0 # index in the current file, used in shuffle mode
584 self._next_file_index = 0 # index of the next file to be loaded in _file_paths
585 self._current_file: str | None = None # filename of the current file
586 self._current_file_len = 0 # length of the current file
587 self._current_file_data: pl.DataFrame | None = None # loaded data of the current file
588 self._current_file_position = 0 # starting index of the current file, used to
590 def _shuffle_file_order(self) -> None:
591 """Shuffle the order of files in native mode."""
592 random.shuffle(self._file_paths)
594 def _load_parquet_data(self, file_path: str) -> pl.DataFrame:
595 """Load data from a parquet file and apply the filters."""
596 # if the experiment_name column is missing, we add it
597 df = pl.scan_parquet(file_path).filter(self._filter_series_per_file[file_path]).collect()
598 df = SpectrumDataFrame._ensure_experiment_name(
599 df,
600 file_path,
601 add_source=self._add_source_file_column,
602 force_source=False,
603 )
604 df = SpectrumDataFrame._cast_columns(df)
605 if self.preprocess_fn is not None:
606 df = self.preprocess_fn(df)
607 return df
609 def _load_next_file(self) -> None:
610 """Load the next file in sequence for lazy loading."""
611 # This function is exclusive to native mode i.e. always lazy
612 self._current_file = self._file_paths[self._next_file_index]
613 # Scan file, filter, and collect
614 if self._current_file == self._next_file and self._next_file_future is not None:
615 self._current_file_data = self._next_file_future
616 else:
617 self._current_file_data = self._load_parquet_data(self._current_file)
619 # Update next file loading
620 if self._shuffle:
621 self._current_file_data = SpectrumDataFrame._shuffle_df(self._current_file_data) # Shuffle rows
622 self._current_file_len = self._current_file_data.shape[0]
624 # Update future
625 if len(self._file_paths) > 0:
626 future_file_index = self._next_file_index + 1
627 if future_file_index >= len(self._file_paths):
628 if self._shuffle:
629 self._shuffle_file_order()
630 future_file_index = 0
632 self._next_file = self._file_paths[future_file_index]
634 self._next_file_future = None
635 self._start_preload_next(self._next_file)
637 def _start_preload_next(self, file_path: str) -> None:
638 """Start preloading the next file asynchronously."""
639 if self._preload_task is None or self._preload_task.done():
640 self._preload_task = self.loop.create_task(self._preload_next_file(file_path))
642 async def _preload_next_file(self, file_path: str) -> None:
643 """Asynchronously preload the next file."""
644 try:
645 self._next_file_future = await self.loop.run_in_executor(self.executor, self._load_parquet_data, file_path)
646 except Exception as e:
647 logger.warning(f"Error preloading file {file_path}: {e}")
648 self._next_file_future = None
650 def __len__(self) -> int:
651 """Returns the total number of rows in the SpectrumDataFrame.
653 Returns:
654 int: Number of rows in the DataFrame.
655 """
656 if self._is_native:
657 return sum([v.sum() for v in self._filter_series_per_file.values()])
658 assert self.df is not None
659 return int(self.df.shape[0])
661 def __getitem__(self, idx: int) -> dict[str, Any]:
662 """Return the item at the specified index.
664 Args:
665 idx (int): Index of the item to retrieve.
667 Returns:
668 dict[str, Any]: Dictionary containing the data from the specified row.
670 Raises:
671 IndexError: If the DataFrame is empty or the index is out of range.
672 """
673 length = len(self)
674 if length == 0:
675 raise IndexError("Attempt to index empty SpectrumDataFrame")
676 if idx >= length:
677 raise IndexError
679 # In shuffle, idx is ignored.
680 if self._is_native:
681 if self._shuffle:
682 # If no file is loaded or we have finished the current file
683 if self._current_file_data is None or self._current_index_in_file >= self._current_file_len:
684 self._current_index_in_file = 0
686 self._load_next_file()
687 self._next_file_index += 1
688 if self._next_file_index >= len(self._file_paths):
689 self._next_file_index = 0
691 # for mypy
692 assert self._current_file_data is not None
694 row = self._current_file_data[self._current_index_in_file]
696 self._current_index_in_file += 1
697 else:
698 # In native mode without shuffle, idx is used.
699 selected_file_index = self._index_to_file_index[idx]
701 # If the index is outside the currently loaded file, load the new file
702 if self._current_file_data is None or self._file_paths[selected_file_index] != self._current_file:
703 self._next_file_index = selected_file_index
704 self._load_next_file()
706 # for mypy
707 assert self._current_file is not None
708 assert self._current_file_data is not None
710 # Find the relative index within the current file
711 file_begin_index = self._file_begin_index[self._current_file]
712 index_in_file = idx
713 if file_begin_index > 0:
714 index_in_file = idx % self._file_begin_index[self._current_file]
716 row = self._current_file_data[index_in_file]
717 else:
718 assert self.df is not None
719 # We're in non-native non-lazy mode
720 if self._shuffle:
721 # Shuffle if we have passed through all entries
722 if self._current_index_in_file >= self.df.height:
723 self.df = SpectrumDataFrame._shuffle_df(self.df)
724 self._current_index_in_file = 0
726 row = self.df[self._current_index_in_file]
728 self._current_index_in_file += 1
729 else:
730 row = self.df[idx]
732 # row = SpectrumDataFrame._cast_columns(row)
734 # Squeeze all entries
735 row_dict: dict[str, Any] = {k: v[0] for k, v in row.to_dict(as_series=False).items()}
737 if self.is_annotated:
738 row_dict[ANNOTATED_COLUMN] = SpectrumDataFrame._sanitise_peptide(row_dict[ANNOTATED_COLUMN])
740 return row_dict
742 # def iterable(self, df: pl.DataFrame) -> None:
743 # """Iterates dataset. Supports streaming?"""
744 # pass
746 @property
747 def is_annotated(self) -> bool:
748 """Check if the dataset is annotated.
750 Returns:
751 bool: True if annotated, False otherwise.
752 """
753 return self._is_annotated
755 @property
756 def has_predictions(self) -> bool:
757 """Check if the dataset contains predictions.
759 Returns:
760 bool: True if predictions are present, False otherwise.
761 """
762 return self._has_predictions
764 @property
765 def is_lazy(self) -> bool:
766 """Check if lazy loading mode is enabled.
768 Returns:
769 bool: True if lazy loading is enabled, False otherwise.
770 """
771 return self._is_lazy
773 def save(
774 self,
775 target: Path,
776 partition: str | None = None,
777 name: str | None = None,
778 max_shard_size: int | None = None,
779 ) -> None:
780 """Save the dataset in parquet format with the option to partition and shard the data.
782 Args:
783 target: Directory to save the dataset.
784 partition: Partition name to be included in the file names.
785 name: Dataset name to be included in the file names.
786 max_shard_size: Maximum size of the data shards.
787 """
788 max_shard_size = max_shard_size or self._max_shard_size
789 partition = partition or "default"
790 name = name or "ms"
792 total_num_files = (len(self) // max_shard_size) + 1
794 shards = self._to_parquet_chunks(target, max_shard_size)
796 Path(target).mkdir(parents=True, exist_ok=True)
798 for i, shard in enumerate(shards):
799 filename = f"dataset-{name}-{partition}-{i:04d}-{total_num_files:04d}.parquet"
800 shard_path = os.path.join(target, filename)
801 self._log(f"Writing {shard_path}")
802 shard.write_parquet(shard_path)
804 def _to_parquet_chunks(self, target: Path, max_shard_size: int = 1_000_000) -> Iterator[pl.DataFrame]:
805 """Generate DataFrame chunks to be saved as parquet files.
807 Args:
808 target: Directory to save the parquet files.
809 max_shard_size: Maximum size of the data shards.
811 Yields:
812 Chunks of DataFrames to be saved.
813 """
814 if self._is_native:
815 current_shard = None
816 for fp in self._file_paths:
817 # Load each file with filtering
818 df = self._load_parquet_data(fp)
820 while len(df) > max_shard_size:
821 yield df[:max_shard_size]
822 df = df[max_shard_size:]
824 # Assumes df < shard_size
825 if current_shard is None:
826 current_shard = df
827 elif len(current_shard) + len(df) < max_shard_size:
828 current_shard = pl.concat([current_shard, df])
829 else:
830 yield pl.concat([current_shard, df[: (max_shard_size - len(current_shard))]])
831 current_shard = df[(max_shard_size - len(current_shard)) :]
832 yield current_shard
833 else:
834 assert self.df is not None
835 df = cast(pl.DataFrame, self.df)
836 while len(df) > max_shard_size:
837 yield df[:max_shard_size]
838 df = df[max_shard_size:]
839 yield df
841 def write_csv(self, target: str) -> None:
842 """Write the dataset to a CSV file.
844 Args:
845 target (str): Path to the output CSV file.
846 """
847 self.to_pandas().to_csv(target, index=False)
849 def write_ipc(self, target: str) -> None:
850 """Write the dataset to a Polars ipc file.
852 Args:
853 target (str): Path to the output ipc file.
854 """
855 df = self.to_polars()
856 if self._is_native:
857 df = df.collect().rechunk()
858 df.write_ipc(target)
860 def write_mgf(self, target: str, export_style: str | None = None) -> None:
861 """Write the dataset to an MGF file using Matchms format.
863 Args:
864 target (str): Path to the output MGF file.
865 export_style (str | None): Style of export to be used (optional).
866 """
867 export_style = export_style or "matchms"
868 spectra = self.to_matchms()
870 # Check if the file exists and delete it if it does
871 if os.path.exists(target):
872 try:
873 os.remove(target)
874 except OSError as e:
875 logger.warning(f"Error deleting existing file '{target}': {e}")
876 return # Exit the method if we can't delete the file
878 save_as_mgf(spectra, target, export_style=export_style)
880 def write_pointnovo(self, spectrum_source: str, feature_target: str) -> None:
881 """Write the dataset in PointNovo format.
883 Args:
884 spectrum_source (str): Source of the spectrum data.
885 feature_target (str): Target for the features.
886 """
887 raise NotImplementedError()
889 def write_mzxml(self, target: str) -> None:
890 """Write the dataset in mzXML format.
892 Args:
893 target (str): Path to the output mzXML file.
894 """
895 raise NotImplementedError()
897 def write_mzml(self, target: str) -> None:
898 """Write the dataset in mzML format.
900 Args:
901 target (str): Path to the output mzML file.
902 """
903 raise NotImplementedError()
905 def to_pandas(self) -> pd.DataFrame:
906 """Convert the dataset to a pandas DataFrame.
908 Warning:
909 This function loads the entire dataset into memory. For large datasets,
910 this may consume a significant amount of RAM.
912 Returns:
913 pd.DataFrame: The dataset in pandas DataFrame format.
914 """
915 return cast(pd.DataFrame, self.to_polars(return_lazy=False).to_pandas())
917 def to_polars(self, return_lazy: bool = True) -> pl.DataFrame | pl.LazyFrame:
918 """Convert the dataset to a polars DataFrame.
920 Args:
921 return_lazy (bool): Return LazyFrame when in lazy mode. Defaults to True.
923 Returns:
924 pl.DataFrame | pl.LazyFrame: The dataset in polars DataFrame format
925 """
926 if self._is_native:
927 if return_lazy:
928 dfs = []
929 for fp in self._file_paths:
930 dfs.append(pl.scan_parquet(fp).filter(self._filter_series_per_file[fp]))
931 df = pl.concat(dfs)
932 return df
934 df = None
935 for fp in self._file_paths:
936 temp_df = pl.scan_parquet(fp).filter(self._filter_series_per_file[fp]).collect()
937 temp_df = SpectrumDataFrame._ensure_experiment_name(temp_df, fp, add_source=self._add_source_file_column, force_source=True)
938 if df is None:
939 df = temp_df
940 else:
941 df = SpectrumDataFrame._concat_dataframes(df, temp_df)
942 return df
943 return self.df
945 @staticmethod
946 def _get_unified_schema(file_paths: list[str]) -> dict[str, Any]:
947 """Get the unified schema for a list of parquet files.
949 Args:
950 file_paths (list[str]): List of file paths to the dataset.
952 Returns:
953 dict[str, Any]: Unified schema for the dataset.
954 """
955 # Get union of all columns and their types
956 unified_features = {}
958 for file_path in file_paths:
959 # Read just the schema (no data) using scan_parquet
960 df_lazy = pl.scan_parquet(file_path)
961 schema = df_lazy.collect_schema()
963 for col_name, dtype in schema.items():
964 if col_name not in unified_features:
965 # Map Polars types to HuggingFace Features
966 if dtype == pl.String or dtype == pl.Utf8:
967 unified_features[col_name] = Value("string")
968 elif dtype == pl.Int64:
969 unified_features[col_name] = Value("int64")
970 elif dtype == pl.Float64:
971 unified_features[col_name] = Value("float64")
972 elif dtype == pl.Float32:
973 unified_features[col_name] = Value("float32")
974 elif dtype == pl.Int32:
975 unified_features[col_name] = Value("int32")
976 elif isinstance(dtype, pl.List):
977 # Handle list types
978 inner_type = dtype.inner
979 if inner_type == pl.Float64:
980 unified_features[col_name] = Sequence(Value("float64"))
981 elif inner_type == pl.Float32:
982 unified_features[col_name] = Sequence(Value("float32"))
983 elif inner_type == pl.Int64:
984 unified_features[col_name] = Sequence(Value("int64"))
985 elif inner_type == pl.String or inner_type == pl.Utf8:
986 unified_features[col_name] = Sequence(Value("string"))
987 else:
988 logger.warning(f"Unknown list inner type for {file_path} {col_name}: {inner_type}")
989 elif isinstance(dtype, pl.Array):
990 # Handle array types (fixed-size arrays)
991 inner_type = dtype.inner
992 if inner_type == pl.Float64:
993 unified_features[col_name] = Sequence(Value("float64"))
994 elif inner_type == pl.Float32:
995 unified_features[col_name] = Sequence(Value("float32"))
996 elif inner_type == pl.Int64:
997 unified_features[col_name] = Sequence(Value("int64"))
998 else:
999 logger.warning(f"Unknown array inner type for {file_path} {col_name}: {inner_type}")
1000 else:
1001 logger.warning(f"Unknown type for {file_path} {col_name}: {dtype}")
1003 return unified_features
1005 def to_dataset(
1006 self,
1007 in_memory: bool = False,
1008 force_unified_schema: bool = False,
1009 **kwargs: Any,
1010 ) -> Dataset:
1011 """Convert the dataset to a HuggingFace Dataset.
1013 Returns:
1014 Dataset: HuggingFace Dataset.
1015 """
1016 if self._is_native and not in_memory:
1017 if force_unified_schema:
1018 features = Features(self._get_unified_schema(self._file_paths))
1020 return load_dataset(
1021 "parquet",
1022 data_files=self._file_paths,
1023 streaming=True,
1024 split="train",
1025 features=features if force_unified_schema else None,
1026 verification_mode=VerificationMode.NO_CHECKS,
1027 **kwargs,
1028 )
1030 return Dataset.from_pandas(self.to_pandas(), **kwargs)
1032 def to_matchms(self) -> list[Spectrum]:
1033 """Convert the dataset to a list of Matchms spectrum objects.
1035 Warning:
1036 This function loads the entire dataset into memory. For large datasets,
1037 this may consume a significant amount of RAM.
1039 Returns:
1040 list[Spectrum]: List of Matchms spectrum objects.
1041 """
1042 df = self.to_polars(return_lazy=False)
1043 return SpectrumDataFrame._df_to_matchms(df)
1045 def export_predictions(self, target: str, export_type: str | Enum) -> None:
1046 """Export the predictions from the dataset.
1048 Args:
1049 target (str): Target path to save the predictions.
1050 export_type (str | Enum): Type of export format.
1051 """
1052 if isinstance(export_type, str):
1053 pass
1054 raise NotImplementedError()
1056 @classmethod
1057 def load(
1058 cls,
1059 source: str | list[str] | Path,
1060 source_type: str = "default",
1061 is_annotated: bool = False,
1062 shuffle: bool = False,
1063 name: str | None = None,
1064 partition: str | None = None,
1065 custom_load: Callable | None = None,
1066 column_mapping: dict[str, str] | None = None,
1067 lazy: bool = True,
1068 max_shard_size: int = 1_000_000,
1069 preshuffle_across_shards: bool = False,
1070 add_source_file_column: bool = False,
1071 preprocess_fn: Callable | None = None,
1072 add_spectrum_id: bool = False,
1073 force_spectrum_id: bool = False,
1074 force_convert_to_native: bool = False,
1075 verbose: bool = False,
1076 ) -> "SpectrumDataFrame":
1077 """Load a SpectrumDataFrame from a source.
1079 Args:
1080 source (str | Path): Path to the source file or directory.
1081 source_type (str): Type of the source (default is "default").
1082 is_annotated (bool): Whether the dataset is annotated.
1083 shuffle (bool): Whether to shuffle the dataset.
1084 name (str | None): Name of the dataset.
1085 partition (str | None): Partition name of the dataset.
1086 lazy (bool): Whether to use lazy loading mode.
1087 max_shard_size (int): Maximum size of data shards.
1088 preshuffle_across_shards (bool): Whether to perform a preshuffle across shards.
1089 add_source_file_column (bool): Whether to add the source file column.
1090 preprocess_fn (Callable | None): Preprocess function for the data on load.
1091 add_spectrum_id (bool): Whether to add spectrum id column.
1092 force_spectrum_id (bool): Force adding spectrum id column.
1093 force_convert_to_native (bool): Force conversion to native format when working
1094 with parquet files.
1095 verbose (bool): Whether to print verbose output.
1097 Returns:
1098 SpectrumDataFrame: The loaded SpectrumDataFrame.
1099 """
1100 partition = partition or "default"
1101 name = name or "ms"
1103 # Native mode
1104 if isinstance(source, str) and os.path.isdir(source) and source_type == "default":
1105 # /path/to/folder/dataset-name-train-0000-of-0001.parquet
1106 source = os.path.join(source, f"dataset-{name}-{partition}-*-*.parquet")
1108 return cls(
1109 file_paths=cast(str, source), # We don't support Path directly
1110 is_lazy=lazy,
1111 custom_load_fn=custom_load,
1112 column_mapping=column_mapping,
1113 max_shard_size=max_shard_size,
1114 shuffle=shuffle,
1115 is_annotated=is_annotated,
1116 preshuffle_across_shards=preshuffle_across_shards,
1117 add_source_file_column=add_source_file_column,
1118 preprocess_fn=preprocess_fn,
1119 add_spectrum_id=add_spectrum_id,
1120 force_spectrum_id=force_spectrum_id,
1121 force_convert_to_native=force_convert_to_native,
1122 verbose=verbose,
1123 )
1125 @staticmethod
1126 def _df_from_any(source: str, source_type: str | None = None) -> pl.DataFrame | None:
1127 """Load a DataFrame from various source formats (MGF, IPC, etc.).
1129 Args:
1130 source (str): Path to the source file.
1131 source_type (str | None): Type of the source file.
1133 Returns:
1134 pl.DataFrame: The loaded DataFrame.
1135 """
1136 if source_type is None:
1137 # Try to infer
1138 source_type = source.split(".")[-1].lower()
1140 match source_type:
1141 case "ipc":
1142 return SpectrumDataFrame._df_from_ipc(source)
1143 case "mgf":
1144 return SpectrumDataFrame._df_from_mgf(source)
1145 case "mzml":
1146 return SpectrumDataFrame._df_from_mzml(source)
1147 case "mzxml":
1148 return SpectrumDataFrame._df_from_mzxml(source)
1149 case "csv":
1150 return SpectrumDataFrame._df_from_csv(source)
1151 case "parquet":
1152 return SpectrumDataFrame._df_from_parquet(source)
1153 # case "_":
1155 return None
1157 @classmethod
1158 def load_mgf(cls, source: str) -> "SpectrumDataFrame":
1159 """Load a SpectrumDataFrame from an MGF file.
1161 Args:
1162 source (str): Path to the MGF file.
1164 Returns:
1165 SpectrumDataFrame: The loaded SpectrumDataFrame.
1166 """
1167 spectra = list(load_from_mgf(source))
1168 return cls.from_matchms(spectra)
1170 @staticmethod
1171 def _df_from_mgf(source: str) -> pl.DataFrame:
1172 """Load a polars DataFrame from an MGF file.
1174 Args:
1175 source (str): Path to the MGF file.
1177 Returns:
1178 pl.DataFrame: The loaded polars DataFrame.
1179 """
1180 spectra = list(load_from_mgf(source))
1181 return SpectrumDataFrame._df_from_matchms(spectra)
1183 @classmethod
1184 def load_pointnovo(cls, spectrum_source: str, feature_source: str) -> "SpectrumDataFrame":
1185 """Load a SpectrumDataFrame from PointNovo format.
1187 Args:
1188 spectrum_source (str): Source of spectrum data.
1189 feature_source (str): Source of feature data.
1190 """
1191 raise NotImplementedError()
1193 @classmethod
1194 def load_csv(
1195 cls,
1196 source: str,
1197 column_mapping: dict[str, str] | None = None,
1198 lazy: bool = False,
1199 annotated: bool = False,
1200 ) -> "SpectrumDataFrame":
1201 """Load a SpectrumDataFrame from a CSV file.
1203 Args:
1204 source (str): Path to the CSV file.
1205 column_mapping (dict[str, str] | None): Mapping of columns to rename.
1206 lazy (bool): Whether to use lazy loading mode.
1207 annotated (bool): Whether the dataset is annotated.
1209 Returns:
1210 SpectrumDataFrame: The loaded SpectrumDataFrame.
1211 """
1212 df = pl.read_csv(source)
1213 df = SpectrumDataFrame._map_columns(df, column_mapping=column_mapping)
1214 return cls(df, is_annotated=annotated, is_lazy=lazy)
1216 @classmethod
1217 def load_mzxml(cls, source: str) -> "SpectrumDataFrame":
1218 """Load a SpectrumDataFrame from an mzXML file.
1220 Args:
1221 source (str): Path to the mzXML file.
1223 Returns:
1224 SpectrumDataFrame: The loaded SpectrumDataFrame.
1225 """
1226 # spectra = list(load_from_mzxml(source))
1227 # return cls.from_matchms(spectra)
1228 df = SpectrumDataFrame._df_from_dict(read_mzxml(source))
1229 return cls.from_polars(df)
1231 @staticmethod
1232 def _df_from_mzxml(source: str) -> pl.DataFrame:
1233 """Load a polars DataFrame from an MGF file.
1235 Args:
1236 source (str): Path to the MGF file.
1238 Returns:
1239 pl.DataFrame: The loaded polars DataFrame.
1240 """
1241 # spectra = list(load_from_mzxml(source))
1242 # return SpectrumDataFrame._df_from_matchms(spectra)
1243 return SpectrumDataFrame._df_from_dict(read_mzxml(source))
1245 @classmethod
1246 def load_mzml(cls, source: str) -> "SpectrumDataFrame":
1247 """Load a SpectrumDataFrame from an mzML file.
1249 Args:
1250 source (str): Path to the mzML file.
1252 Returns:
1253 SpectrumDataFrame: The loaded SpectrumDataFrame.
1254 """
1255 # spectra = list(load_from_mzml(source))
1256 # return cls.from_matchms(spectra)
1257 df = SpectrumDataFrame._df_from_dict(read_mzml(source))
1258 return cls.from_polars(df)
1260 @staticmethod
1261 def _df_from_mzml(source: str) -> pl.DataFrame:
1262 """Load a polars DataFrame from an MGF file.
1264 Args:
1265 source (str): Path to the MGF file.
1267 Returns:
1268 pl.DataFrame: The loaded polars DataFrame.
1269 """
1270 # spectra = list(load_from_mzml(source))
1271 # return SpectrumDataFrame._df_from_matchms(spectra)
1272 return SpectrumDataFrame._df_from_dict(read_mzml(source))
1274 @classmethod
1275 def from_huggingface(
1276 cls,
1277 dataset: str | Dataset,
1278 shuffle: bool = False,
1279 is_annotated: bool = False,
1280 **kwargs: Any,
1281 ) -> "SpectrumDataFrame":
1282 """Load a SpectrumDataFrame from HuggingFace directory or Dataset instance.
1284 Warning:
1285 This function loads the entire dataset into memory. For large datasets,
1286 this may consume a significant amount of RAM.
1288 Args:
1289 dataset (str | Dataset): Path to HuggingFace or Dataset instance.
1291 Returns:
1292 SpectrumDataFrame: The loaded SpectrumDataFrame.
1293 """
1294 if isinstance(dataset, str):
1295 dataset = load_dataset(dataset, **kwargs)
1296 # TODO: Explore dataset.to_pandas(batched=True)
1297 return cls.from_pandas(dataset.to_pandas(), shuffle=shuffle, is_annotated=is_annotated)
1299 @classmethod
1300 def from_pandas(cls, df: pd.DataFrame, shuffle: bool = False, is_annotated: bool = False) -> "SpectrumDataFrame":
1301 """Create a SpectrumDataFrame from a pandas DataFrame.
1303 Args:
1304 df (pd.DataFrame): The pandas DataFrame.
1306 Returns:
1307 SpectrumDataFrame: The resulting SpectrumDataFrame.
1308 """
1309 df = pl.from_pandas(df)
1310 return cls.from_polars(df, shuffle=shuffle, is_annotated=is_annotated)
1312 @classmethod
1313 def from_polars(cls, df: pl.DataFrame, shuffle: bool = False, is_annotated: bool = False) -> "SpectrumDataFrame":
1314 """Create a SpectrumDataFrame from a polars DataFrame.
1316 Args:
1317 df (pl.DataFrame): The polars DataFrame.
1319 Returns:
1320 SpectrumDataFrame: The resulting SpectrumDataFrame.
1321 """
1322 return cls(
1323 df=df,
1324 shuffle=shuffle,
1325 is_annotated=is_annotated,
1326 )
1328 @classmethod
1329 def load_ipc(cls, source: str, shuffle: bool = False, is_annotated: bool = False) -> "SpectrumDataFrame":
1330 """Load a SpectrumDataFrame from IPC format.
1332 Args:
1333 source (str): Path to the IPC file.
1335 Returns:
1336 SpectrumDataFrame: The loaded SpectrumDataFrame.
1337 """
1338 df = cls._df_from_ipc(source)
1339 return cls(
1340 df,
1341 is_lazy=False,
1342 shuffle=shuffle,
1343 is_annotated=is_annotated,
1344 )
1346 @staticmethod
1347 def _df_from_ipc(source: str) -> pl.DataFrame:
1348 """Load a polars DataFrame from an IPC file.
1350 Args:
1351 source (str): Path to the IPC file.
1353 Returns:
1354 pl.DataFrame: The loaded polars DataFrame.
1355 """
1356 df = pl.read_ipc(source)
1357 if "modified_sequence" in df.columns:
1358 df = df.with_columns(pl.col("modified_sequence").alias(ANNOTATED_COLUMN))
1359 return df
1361 @staticmethod
1362 def _df_from_csv(source: str) -> pl.DataFrame:
1363 """Load a polars DataFrame from a CSV file.
1365 Args:
1366 source (str): Path to the CSV file.
1368 Returns:
1369 pl.DataFrame: The loaded polars DataFrame.
1370 """
1371 df = pl.read_csv(source)
1372 if "modified_sequence" in df.columns:
1373 df = df.with_columns(pl.col("modified_sequence").alias(ANNOTATED_COLUMN))
1374 return df
1376 @staticmethod
1377 def _df_from_parquet(source: str) -> pl.DataFrame:
1378 """Load a polars DataFrame from a parquet file.
1380 Args:
1381 source (str): Path to the parquet file.
1383 Returns:
1384 pl.DataFrame: The loaded polars DataFrame.
1385 """
1386 df = pl.read_parquet(source)
1387 if "modified_sequence" in df.columns:
1388 df = df.with_columns(pl.col("modified_sequence").alias(ANNOTATED_COLUMN))
1389 return df
1391 @classmethod
1392 def from_matchms(cls, spectra: list[Spectrum], shuffle: bool = False, is_annotated: bool = False) -> "SpectrumDataFrame":
1393 """Create a SpectrumDataFrame from Matchms spectrum objects.
1395 Args:
1396 spectra (list): List of Matchms spectrum objects.
1397 shuffle (bool, optional): If True, shuffle the data. Defaults to False.
1398 is_annotated (bool, optional): If True, treat the spectra as annotated.
1399 Defaults to False.
1401 Returns:
1402 SpectrumDataFrame: The resulting SpectrumDataFrame.
1404 Raises:
1405 ValueError: If the input parameters are invalid or incompatible.
1406 """
1407 df = SpectrumDataFrame._df_from_matchms(spectra)
1408 return cls(
1409 df=df,
1410 shuffle=shuffle,
1411 is_annotated=is_annotated,
1412 )
1414 @staticmethod
1415 def _parse_scan_number(scan_number: str, index: int) -> int | None:
1416 """Try parse scan number."""
1417 if scan_number.isdigit():
1418 return int(scan_number)
1420 # Use regex to extract the value after 'scan='
1421 match = re.search(r"scan=(\d+)", scan_number)
1422 if match:
1423 return int(match.group(1))
1425 # use index if scan number cannot be accessed
1426 return index
1428 @staticmethod
1429 def _df_from_dict(data: dict[str, Any]) -> pl.DataFrame:
1430 df = pl.DataFrame(
1431 {
1432 "scan_number": pl.Series(
1433 [SpectrumDataFrame._parse_scan_number(str(x), i) for i, x in enumerate(data["scan_number"])],
1434 dtype=pl.Int64,
1435 ),
1436 ANNOTATED_COLUMN: pl.Series(data["sequence"], dtype=pl.Utf8),
1437 # Calculate precursor mass
1438 MSColumns.PRECURSOR_MASS.value: pl.Series(
1439 np.array(data["precursor_mz"]) * np.array(data["precursor_charge"]) - np.array(data["precursor_charge"]) * PROTON_MASS_AMU,
1440 dtype=MS_TYPES[MSColumns.PRECURSOR_MASS],
1441 ),
1442 MSColumns.PRECURSOR_MZ.value: pl.Series(data["precursor_mz"], dtype=MS_TYPES[MSColumns.PRECURSOR_MZ]),
1443 MSColumns.PRECURSOR_CHARGE.value: pl.Series(data["precursor_charge"], dtype=MS_TYPES[MSColumns.PRECURSOR_CHARGE]),
1444 MSColumns.RETENTION_TIME.value: pl.Series(data["retention_time"], dtype=MS_TYPES[MSColumns.RETENTION_TIME]),
1445 MSColumns.MZ_ARRAY.value: pl.Series(data["mz_array"], dtype=MS_TYPES[MSColumns.MZ_ARRAY]),
1446 MSColumns.INTENSITY_ARRAY.value: pl.Series(data["intensity_array"], dtype=MS_TYPES[MSColumns.INTENSITY_ARRAY]),
1447 }
1448 )
1449 return df
1451 @staticmethod
1452 def _df_from_matchms(spectra: list[Spectrum]) -> pl.DataFrame:
1453 """Load a polars DataFrame from a list of Matchms spectra.
1455 Args:
1456 spectra (list[Spectrum]): List of Matchms spectrum objects.
1458 Returns:
1459 pl.DataFrame: The loaded polars DataFrame.
1460 """
1461 data: dict[str, list[Any]] = {
1462 "scan_number": [],
1463 "sequence": [],
1464 "precursor_mass": [],
1465 "precursor_mz": [],
1466 "precursor_charge": [],
1467 "retention_time": [],
1468 "mz_array": [],
1469 "intensity_array": [],
1470 }
1472 for i, spectrum in enumerate(spectra):
1473 data["scan_number"].append(i)
1474 data["sequence"].append(spectrum.metadata.get("peptide_sequence", ""))
1475 data["precursor_mass"].append(spectrum.metadata.get("pepmass", 0.0))
1476 data["precursor_mz"].append(spectrum.metadata.get("precursor_mz", 0.0))
1477 data["precursor_charge"].append(spectrum.metadata.get("charge", 0))
1478 data["retention_time"].append(spectrum.metadata.get("retention_time", 0.0))
1479 data["mz_array"].append(spectrum.peaks.mz)
1480 data["intensity_array"].append(spectrum.peaks.intensities)
1482 df = SpectrumDataFrame._df_from_dict(data)
1484 return df
1486 @staticmethod
1487 def _df_to_matchms(df: pl.DataFrame) -> list[Spectrum]:
1488 """Convert a polars DataFrame to a list of Matchms spectra.
1490 Args:
1491 df (pl.DataFrame): The input polars DataFrame.
1493 Returns:
1494 list[Spectrum]: List of Matchms spectrum objects.
1495 """
1496 spectra = []
1498 for row in df.iter_rows(named=True):
1499 metadata = {
1500 "peptide_sequence": row[ANNOTATED_COLUMN],
1501 # "pepmass": row[MSColumns.PRECURSOR_MASS.value],
1502 "precursor_mz": row[MSColumns.PRECURSOR_MZ.value],
1503 "charge": row[MSColumns.PRECURSOR_CHARGE.value],
1504 "retention_time": row[MSColumns.RETENTION_TIME.value],
1505 }
1507 mz_array = np.array(row[MSColumns.MZ_ARRAY.value])
1508 intensity_array = np.array(row[MSColumns.INTENSITY_ARRAY.value])
1510 spectrum = Spectrum(mz_array, intensity_array, metadata=metadata)
1511 spectra.append(spectrum)
1513 return spectra
1515 def _check_type_spec(self) -> None:
1516 """Check the data type specifications for the DataFrame columns.
1518 This method validates that important columns have the correct data types.
1519 """
1520 # Check expected columns, use ENUM constant for this.
1521 expected_cols = [
1522 c.value
1523 for c in [
1524 MSColumns.MZ_ARRAY,
1525 MSColumns.INTENSITY_ARRAY,
1526 MSColumns.PRECURSOR_MZ,
1527 MSColumns.PRECURSOR_CHARGE,
1528 ]
1529 ]
1530 if self.is_annotated:
1531 expected_cols.append(ANNOTATED_COLUMN)
1533 missing_cols = []
1534 if self._is_native:
1535 # Check all parquet files in parallel
1536 with ThreadPoolExecutor() as executor:
1537 # First check for missing columns
1538 def check_columns(file_path: str) -> list[str]:
1539 columns = pl.scan_parquet(file_path).collect_schema().keys()
1540 return [col for col in expected_cols if col not in columns]
1542 missing_cols_results = list(executor.map(check_columns, self._file_paths))
1543 missing_cols = next((cols for cols in missing_cols_results if cols), [])
1545 # If no missing columns and we need to check annotations
1546 if not missing_cols and self.is_annotated:
1548 def check_annotations(file_path: str) -> Any:
1549 return (
1550 pl.scan_parquet(file_path)
1551 .select(((pl.col(ANNOTATED_COLUMN).is_not_null()) & (pl.col(ANNOTATED_COLUMN) != "")).all())
1552 .collect()
1553 .to_numpy()[0]
1554 )
1556 has_annotations_results = list(executor.map(check_annotations, self._file_paths))
1557 if not all(has_annotations_results):
1558 raise ValueError(ANNOTATION_ERROR)
1559 else:
1560 assert self.df is not None
1561 # Check only self.df
1562 missing_cols = [col for col in expected_cols if col not in self.df.columns]
1563 if not missing_cols and self.is_annotated:
1564 has_annotations = self.df.select(((pl.col(ANNOTATED_COLUMN).is_not_null()) & (pl.col(ANNOTATED_COLUMN) != "")).all()).to_numpy()[0]
1565 if not has_annotations:
1566 raise ValueError(ANNOTATION_ERROR)
1568 if not missing_cols:
1569 # In non-native mode also check the charge column is not all zeros (DIA)
1570 if self.df is not None:
1571 if self.df[MSColumns.PRECURSOR_CHARGE.value].sum() == 0:
1572 logger.warning("The charge column is all zeros. This could indicate a DIA dataset or contains invalid values.")
1574 if missing_cols:
1575 plural_s = "s" if len(missing_cols) > 1 else ""
1576 missing_col_names = ", ".join(missing_cols)
1577 raise ValueError(f"Column{plural_s} missing! Missing column{plural_s}: {missing_col_names}")
1579 @classmethod
1580 def concatenate(cls, sdf: list["SpectrumDataFrame"], strict: bool = True) -> "SpectrumDataFrame":
1581 """Concatenate a list of SpectrumDataFrames.
1583 Warning:
1584 This function loads the entire dataset into memory. For large datasets,
1585 this may consume a significant amount of RAM.
1587 Args:
1588 df (list[SpectrumDataFrame]): List of SpectrumDataFrames to concatenate.
1589 strict (bool): Whether to perform strict concatenation.
1591 Returns:
1592 SpectrumDataFrame: The resulting concatenated SpectrumDataFrame.
1593 """
1594 return cls(pl.concat([x.to_polars(return_lazy=False) for x in sdf]))
1596 def get_unique_sequences(self) -> set[str]:
1597 """Retrieve unique peptide sequences from the dataset.
1599 Returns:
1600 set[str]: A set of unique peptide sequences.
1602 Raises:
1603 ValueError: If the dataset is not annotated.
1604 """
1605 if not self.is_annotated:
1606 raise ValueError("Only annotated datasets have sequences.")
1608 if self._is_native:
1609 sequences = set()
1610 for fp in self._file_paths:
1611 df_unique = pl.scan_parquet(fp).filter(self._filter_series_per_file[fp]).select(pl.col(ANNOTATED_COLUMN).unique()).collect()
1612 sequences.update(set(df_unique[ANNOTATED_COLUMN].to_list()))
1613 return sequences
1614 else:
1615 assert self.df is not None
1616 return set(self.df[ANNOTATED_COLUMN].unique())
1618 def get_vocabulary(self, tokenize_fn: Callable) -> set[str]:
1619 """Get the vocabulary of unique residues from peptide sequences.
1621 Args:
1622 tokenize_fn (Callable): Function to tokenize peptide sequences.
1624 Returns:
1625 set[str]: A set of unique residues.
1627 Raises:
1628 ValueError: If the dataset is not annotated.
1629 """
1630 if not self.is_annotated:
1631 raise ValueError("Only annotated datasets have residue vocabularies.")
1633 sequences = self.get_unique_sequences()
1634 residues = set()
1635 for x in sequences:
1636 residues.update(set(tokenize_fn(x)))
1637 return residues
1639 def validate_precursor_mass(self, metrics: Metrics, tolerance: float = 50) -> int:
1640 """Validate precursor mz matching the annotations.
1642 Args:
1643 metrics (Metrics): InstaNovo metrics class for calculating sequence mass.
1644 tolerance (float): Tolerance to match precursor mass in ppm.
1646 Returns:
1647 int: Number of precursor matches
1649 Raises:
1650 ValueError: If none of the sequences match the precursor mz.
1651 ValueError: If SpectrumDataFrame is not annotated.
1652 """
1653 if not self.is_annotated:
1654 raise ValueError("Cannot verify precursor mass without annotations.")
1656 if self._is_native:
1657 num_matches_precursor = 0
1658 for fp in self._file_paths:
1659 result = (
1660 pl.scan_parquet(fp)
1661 .filter(self._filter_series_per_file[fp])
1662 .select(
1663 [
1664 pl.col(ANNOTATED_COLUMN),
1665 pl.col(MSColumns.PRECURSOR_MZ.value),
1666 pl.col(MSColumns.PRECURSOR_CHARGE.value),
1667 ]
1668 )
1669 .with_columns(
1670 [
1671 pl.struct(
1672 [
1673 ANNOTATED_COLUMN,
1674 pl.col(MSColumns.PRECURSOR_MZ.value),
1675 pl.col(MSColumns.PRECURSOR_CHARGE.value),
1676 ]
1677 )
1678 .map_elements(
1679 lambda x: metrics.matches_precursor(
1680 x[ANNOTATED_COLUMN],
1681 x[MSColumns.PRECURSOR_MZ.value],
1682 x[MSColumns.PRECURSOR_CHARGE.value],
1683 prec_tol=tolerance,
1684 )[0],
1685 return_dtype=bool,
1686 )
1687 .alias("precursor_match")
1688 ]
1689 )
1690 .select(pl.col("precursor_match").sum().alias("num_matches_precursor"))
1691 )
1692 num_matches_precursor += result.collect()["num_matches_precursor"][0]
1693 else:
1694 assert self.df is not None
1695 result = (
1696 self.df.select(
1697 [
1698 pl.col(ANNOTATED_COLUMN),
1699 pl.col(MSColumns.PRECURSOR_MZ.value),
1700 pl.col(MSColumns.PRECURSOR_CHARGE.value),
1701 ]
1702 )
1703 .with_columns(
1704 [
1705 pl.struct(
1706 pl.col(ANNOTATED_COLUMN),
1707 pl.col(MSColumns.PRECURSOR_MZ.value),
1708 pl.col(MSColumns.PRECURSOR_CHARGE.value),
1709 )
1710 .map_elements(
1711 lambda x: metrics.matches_precursor(
1712 x[ANNOTATED_COLUMN],
1713 x[MSColumns.PRECURSOR_MZ.value],
1714 x[MSColumns.PRECURSOR_CHARGE.value],
1715 prec_tol=tolerance,
1716 )[0],
1717 return_dtype=bool,
1718 )
1719 .alias("precursor_match")
1720 ]
1721 )
1722 .select(pl.col("precursor_match").sum().alias("num_matches_precursor"))
1723 )
1724 num_matches_precursor = result["num_matches_precursor"][0]
1726 if num_matches_precursor == 0:
1727 raise ValueError("None of the sequence labels in the dataset match the precursor mz. Check sequences and residue set for errors.")
1728 elif num_matches_precursor < len(self):
1729 logger.warning(
1730 f"{len(self) - num_matches_precursor:,d} "
1731 f"({(1 - num_matches_precursor / len(self)) * 100:.2f}%) of the sequence labels "
1732 f"do not match the precursor mz to {tolerance}ppm."
1733 )
1735 return num_matches_precursor
1737 # def filter(self, *args, **kwargs) -> "SpectrumDataFrame":
1738 # return self.df.filter(*args, **kwargs)
1740 def sample_subset(self, fraction: float, seed: int | None = None) -> None:
1741 """Sample a subset of the dataset.
1743 Args:
1744 fraction (float): Fraction of the dataset to sample.
1745 seed (int): Random seed for reproducibility.
1746 """
1747 if fraction >= 1:
1748 return
1749 if self._is_native:
1750 for fp in self._file_paths:
1751 if seed:
1752 np.random.seed(seed)
1753 # TODO: variable "filter" is shadowing a python builtin
1754 filter = self._filter_series_per_file[fp].to_numpy() # noqa
1755 filter[filter] = np.random.choice([True, False], size=filter.sum(), p=[fraction, 1 - fraction])
1756 self._filter_series_per_file[fp] &= pl.Series(filter)
1757 if not self._shuffle:
1758 self._update_file_indices()
1759 self._reset_current_file()
1760 else:
1761 assert self.df is not None
1762 self.df = self.df.sample(fraction=fraction, seed=seed)
1764 def validate_data(self) -> bool:
1765 """Validate the integrity of the dataset.
1767 Returns:
1768 bool: True if the data is valid, False otherwise.
1769 """
1770 raise NotImplementedError()
1772 def __del__(self) -> None:
1773 """Clean up the resources when the object is destroyed.
1775 This includes shutting down the thread pool executor and removing temporary files.
1776 """
1777 if self.executor:
1778 self.executor.shutdown(wait=True)
1779 if self._temp_directory is not None and os.path.exists(self._temp_directory):
1780 shutil.rmtree(self._temp_directory)
1782 def _strip_shape_info(self, s: str) -> str:
1783 """Adjust the string representation of a SpectrumDataFrame for lazy loading.
1785 Strips shape information that would only correspond to the first temporary file.
1787 Parameters:
1788 s (str): The string representation of the SpectrumDataFrame.
1790 Returns:
1791 str: The adjusted string representation of the SpectrumDataFrame.
1792 """
1793 # Pattern to match rows and columns lines
1794 pattern = re.compile(r"(Rows:\s*\d+\s*Columns:\s*\d+)")
1795 # Replace the shape with a placeholder
1796 return pattern.sub("Shape: unknown in lazy loading mode.", s)
1798 @staticmethod
1799 def _truncate_list_repr(s: str, max_items: int = 3) -> str:
1800 """Find and truncate long lists within the SpectrumDataFrame string preview.
1802 Args:
1803 s (str): String representation of SpectrumDataFrame.
1804 max_items (int): Maximum number of list items to display at the list's head and tail.
1806 Returns:
1807 str: SpectrumDataFrame string representation with truncated list items, if necessary.
1808 """
1810 def process_list(match: re.Match) -> str:
1811 # Extract the list content
1812 list_content = match.group(1)
1813 # Convert to a Python list
1814 values = list(map(str.strip, list_content.split(",")))
1815 # Truncate if necessary
1816 if len(values) > 2 * max_items:
1817 truncated = values[:max_items] + ["..."] + values[-max_items:]
1818 else:
1819 truncated = values
1820 # Rebuild the list as a string
1821 return f"[{', '.join(truncated)}]"
1823 # Regex to find lists in the string
1824 list_pattern = re.compile(r"\[([^\]]*?)\]")
1825 # Apply truncation to all matches
1826 return list_pattern.sub(process_list, s)
1828 def _display_string_preview(self, df: Union[pl.DataFrame, "SpectrumDataFrame"]) -> str:
1829 """String preview of SpectrumDataFrame, truncating long list items.
1831 Args:
1832 df (Union(pl.DataFrame, "SpectrumDataFrame")): SpectrumDataFrame or Polars form of
1833 SpectrumDataFrame for string representation.
1835 Returns:
1836 str: String representation of SpectrumDataFrame.
1837 """
1838 if type(df) is not pl.DataFrame:
1839 df = df.to_polars(return_lazy=False)
1841 preview = df.glimpse(
1842 max_items_per_column=self.max_items_per_column,
1843 max_colname_length=self.max_colname_length,
1844 return_as_string=True,
1845 )
1847 if self.is_lazy:
1848 preview = self._strip_shape_info(preview)
1850 preview = self._truncate_list_repr(preview)
1852 return preview
1854 def __str__(self) -> str:
1855 """A user-friendly string representation of the SpectrumDataFrame object.
1857 Returns:
1858 str: String representation of SpectrumDataFrame.
1859 """
1860 # Metadata summary
1861 output = f"<SpectrumDataFrame | Lazy Loaded: {'Yes' if self.is_lazy else 'No'}>\n"
1863 if not self.is_lazy and self._file_paths == []:
1864 # Eager loading and non-parquet input case: the df is already loaded in memory
1865 output += self._display_string_preview(self.df)
1867 else:
1868 # Lazy loading and parquet cases: only load the first temp file.
1869 temp_sdf = self.load(self._file_paths[0])
1870 output += self._display_string_preview(temp_sdf)
1872 return output
1874 def __repr__(self) -> str:
1875 """An unambiguous string representation of the SpectrumDataFrame object.
1877 This method is intended to provide enough detail to reconstruct the SpectrumDataFrame
1878 object (if possible) or for debugging purposes. It includes all relevant attributes
1879 in a nested format.
1881 Returns:
1882 str: String representation of SpectrumDataFrame.
1883 """
1885 def adjust_indentation(input_string: str) -> str:
1886 """Adjusts the indentation of a multiline string based on the number of leading tabs.
1888 Args:
1889 input_string (str): The input string with potential leading tabs.
1891 Returns:
1892 str: A string with consistent indentation after each newline.
1893 """
1894 # Find the leading tabs at the beginning of the string
1895 match = re.match(r"^(\t*)", input_string)
1896 if match:
1897 leading_tabs = match.group(1) # Capture the leading tabs
1898 else:
1899 leading_tabs = ""
1901 # Add the leading tabs after every newline
1902 adjusted_string = input_string.replace("\n", "\n" + leading_tabs)
1904 return adjusted_string
1906 def pretty(d: dict, indent: int = 1) -> str:
1907 """Recursively formats a dictionary into a pretty indented string.
1909 Args:
1910 d (dict): The dictionary to format.
1911 indent (int): The current indentation level.
1913 Returns:
1914 str: A pretty-formatted string representation of the dictionary.
1915 """
1916 lines = []
1917 for key, value in d.items():
1918 lines.append("\t" * indent + str(key) + " =") # Add the key
1919 if isinstance(value, dict):
1920 # Recursively format nested dictionary
1921 lines.append(pretty(value, indent + 1))
1922 else:
1923 # Add the value
1924 next_rep = "\t" * (indent + 1) + str(value)
1925 if isinstance(value, pl.DataFrame) or isinstance(value, pl.Series):
1926 # Find how many tabs are at the front of the string,
1927 # and add that many tabs after every newline entry.
1928 next_rep = adjust_indentation(next_rep)
1930 lines.append(next_rep)
1932 return "\n".join(lines)
1934 class_name = self.__class__.__name__
1935 attributes = pretty(vars(self))
1936 rep = f"{class_name}(\n{attributes}\n)"
1938 return rep
1941def _format_time(seconds: float) -> str:
1942 seconds = int(seconds)
1943 return f"{seconds // 3600:02d}:{(seconds % 3600) // 60:02d}:{seconds % 60:02d}"