Coverage for instanovo/utils/data_handler.py: 61%

798 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

1from __future__ import annotations 

2 

3import asyncio 

4import glob 

5import os 

6import random 

7import re 

8import shutil 

9import tempfile 

10import time 

11import uuid 

12from concurrent.futures import ThreadPoolExecutor 

13from enum import Enum 

14from itertools import chain 

15from pathlib import Path 

16from typing import TYPE_CHECKING, Any, Callable, Iterator, Union, cast 

17 

18import numpy as np 

19import pandas as pd 

20import polars as pl 

21from datasets import Dataset, Features, Sequence, Value, VerificationMode, load_dataset 

22from matchms import Spectrum 

23from matchms.exporting import save_as_mgf 

24from matchms.importing import load_from_mgf 

25 

26from instanovo.constants import ( 

27 ANNOTATED_COLUMN, 

28 ANNOTATION_ERROR, 

29 MS_TYPES, 

30 PROTON_MASS_AMU, 

31 MSColumns, 

32) 

33 

34if TYPE_CHECKING: 

35 from instanovo.utils.metrics import Metrics 

36from instanovo.__init__ import console 

37from instanovo.utils.colorlogging import ColorLog 

38from instanovo.utils.msreader import read_mzml, read_mzxml 

39 

40logger = ColorLog(console, __name__).logger 

41 

42 

43class SpectrumDataFrame: 

44 """Spectra data class. 

45 

46 A data class for interacting with mass spectra data, including loading, processing, 

47 and saving mass spectrometry data in various formats such as Parquet, MGF, and others. 

48 

49 Supports lazy loading, shuffling, and handling of large datasets by processing them 

50 in chunks. 

51 

52 Attributes: 

53 is_annotated: Whether the dataset is annotated with peptide sequences. 

54 has_predictions: Whether the dataset contains predictions. 

55 is_lazy: Whether lazy loading mode is used. 

56 """ 

57 

58 def __init__( 

59 self, 

60 df: pl.DataFrame | None = None, 

61 file_paths: str | list[str] | None = None, 

62 is_annotated: bool = False, 

63 has_predictions: bool = False, 

64 is_lazy: bool = False, 

65 shuffle: bool = False, 

66 preshuffle_across_shards: bool = True, 

67 max_shard_size: int = 100_000, 

68 custom_load_fn: Callable | None = None, 

69 column_mapping: dict[str, str] | None = None, 

70 add_source_file_column: bool = False, 

71 preprocess_fn: Callable | None = None, 

72 force_convert_to_native: bool = False, 

73 verbose: bool = False, 

74 add_spectrum_id: bool = False, 

75 force_spectrum_id: bool = False, 

76 ) -> None: 

77 """Initialize SpectrumDataFrame. 

78 

79 Args: 

80 df (pl.DataFrame | None): In-memory polars DataFrame with mass spectra data. 

81 file_paths (str | list[str] | None): Path(s) to the input data files. 

82 custom_load_fn (Callable | None): Custom function for loading data files. 

83 is_annotated (bool): Whether the dataset is annotated. 

84 has_predictions (bool): Whether predictions are present in the dataset. 

85 is_lazy (bool): Whether to use lazy loading mode. 

86 shuffle (bool): Whether to shuffle the dataset. 

87 max_shard_size (int): Maximum size of data shards for chunking large datasets. 

88 add_source_file_column (bool): Add source file column to the data. 

89 preprocess_fn (Callable | None): Preprocess function for the data on load. 

90 force_convert_to_native (bool): Force conversion to native format. 

91 verbose (bool): Whether to print verbose output. 

92 add_spectrum_id (bool): Add spectrum id column to the data. 

93 force_spectrum_id (bool): Force addition of spectrum id column to the data. 

94 

95 Raises: 

96 ValueError: If neither `df` nor `file_paths` is specified, or both are given. 

97 FileNotFoundError: If no files are found matching the given `file_paths`. 

98 """ 

99 self._is_annotated: bool = is_annotated 

100 self._has_predictions: bool = has_predictions # if we have outputs/predictions 

101 self._is_lazy: bool = is_lazy # or streaming 

102 self._shuffle: bool = shuffle 

103 self._max_shard_size = max_shard_size 

104 self._custom_load_fn = custom_load_fn 

105 self._add_source_file_column = add_source_file_column 

106 self._verbose = verbose 

107 self._add_spectrum_id = add_spectrum_id 

108 self._force_spectrum_id = force_spectrum_id 

109 self.executor = None 

110 self._temp_directory = None 

111 # String representation values: 

112 self.max_items_per_column = 3 

113 self.max_colname_length = 20 

114 self.preprocess_fn = preprocess_fn 

115 self.df: pl.DataFrame | None = None 

116 

117 if df is None and file_paths is None: 

118 raise ValueError("Must specify either df or file_paths, both are None.") 

119 

120 # native refers to data being stored as a list of parquet files. 

121 self._is_native = file_paths is not None 

122 if df is not None and self._is_native: 

123 raise ValueError("Must specify either df or file_paths, not both.") 

124 

125 if self._is_native: 

126 # Get all file paths 

127 self._file_paths = SpectrumDataFrame._convert_file_paths(cast(str | list[str], file_paths)) 

128 

129 if len(self._file_paths) == 0: 

130 raise FileNotFoundError(f"No files matching '{file_paths}' were found.") 

131 

132 # If any of the files are not .parquet, create a tempdir with the converted files. 

133 if not all(fp.lower().endswith(".parquet") for fp in self._file_paths) or force_convert_to_native: 

134 # If lazy make tempdir, if not convert to non-native and load all contents into df 

135 # Only iterate over non-parquet files 

136 df_iterator = SpectrumDataFrame.get_data_shards( 

137 self._file_paths, 

138 max_shard_size=self._max_shard_size, 

139 custom_load_fn=custom_load_fn, 

140 column_mapping=column_mapping, 

141 add_source_file_column=add_source_file_column, 

142 force_convert_to_native=force_convert_to_native, 

143 add_spectrum_id_column=add_spectrum_id, 

144 verbose=verbose, 

145 ) 

146 

147 if self._is_lazy: 

148 self._temp_directory = tempfile.mkdtemp() 

149 

150 new_file_paths = [fp for fp in self._file_paths if (fp.lower().endswith(".parquet") and not force_convert_to_native)] 

151 for temp_df in df_iterator: 

152 # TODO: better way to generate id than hash? 

153 temp_parquet_path = os.path.join(self._temp_directory, f"temp_{uuid.uuid4().hex}.parquet") 

154 temp_df.write_parquet(temp_parquet_path) 

155 new_file_paths.append(temp_parquet_path) 

156 self._log(f"Saving temporary file to {temp_parquet_path}") 

157 self._file_paths = new_file_paths 

158 else: 

159 self.df = None 

160 for temp_df in df_iterator: 

161 temp_df = SpectrumDataFrame._map_columns(temp_df, column_mapping=column_mapping) 

162 temp_df = SpectrumDataFrame._cast_columns(temp_df) 

163 if self.df is None: 

164 self.df = temp_df 

165 else: 

166 self.df = SpectrumDataFrame._concat_dataframes(self.df, temp_df) 

167 

168 # Ensure parquet files are re-added 

169 for fp in self._file_paths: 

170 if not fp.lower().endswith(".parquet") or force_convert_to_native: 

171 continue 

172 temp_df = SpectrumDataFrame._map_columns(pl.read_parquet(fp), column_mapping=column_mapping) 

173 temp_df = SpectrumDataFrame._cast_columns(temp_df) 

174 temp_df = SpectrumDataFrame._ensure_experiment_name( 

175 temp_df, 

176 fp, 

177 add_source=add_source_file_column, 

178 force_source=True, 

179 add_spectrum_id=add_spectrum_id, 

180 force_spectrum_id=force_spectrum_id, 

181 ) 

182 self.df = SpectrumDataFrame._concat_dataframes(self.df, temp_df) 

183 

184 # Native is disabled if not lazy 

185 self._is_native = False 

186 self._file_paths = [] 

187 elif not self._is_lazy: 

188 # Loaded native, convert to lazy 

189 self.df = None 

190 for fp in self._file_paths: 

191 temp_df = SpectrumDataFrame._map_columns(pl.read_parquet(fp), column_mapping=column_mapping) 

192 temp_df = SpectrumDataFrame._cast_columns(temp_df) 

193 temp_df = SpectrumDataFrame._ensure_experiment_name( 

194 temp_df, 

195 fp, 

196 add_source=add_source_file_column, 

197 force_source=True, 

198 add_spectrum_id=add_spectrum_id, 

199 force_spectrum_id=force_spectrum_id, 

200 ) 

201 if self.df is None: 

202 self.df = temp_df 

203 else: 

204 self.df = SpectrumDataFrame._concat_dataframes(self.df, temp_df) 

205 

206 # Native is disabled if not lazy 

207 self._is_native = False 

208 self._file_paths = [] 

209 

210 # Create filter series for native mode 

211 if self._file_paths is not None: 

212 with ThreadPoolExecutor() as executor: 

213 # Process files in parallel 

214 self._filter_series_per_file = dict(executor.map(SpectrumDataFrame._create_filter_series, self._file_paths)) 

215 else: 

216 self.df = df 

217 

218 # Check all columns 

219 self._log("Verifying loaded data") 

220 self._check_type_spec() 

221 self._reset_current_file() 

222 

223 if preprocess_fn is not None and not self._is_lazy: 

224 self._log("Preprocessing data in memory.") 

225 self.df = preprocess_fn(self.df) 

226 

227 if self._shuffle: 

228 if self._is_native: 

229 if preshuffle_across_shards: 

230 self._preshuffle_files() 

231 self._shuffle_file_order() 

232 else: 

233 self.df = SpectrumDataFrame._shuffle_df(self.df) 

234 elif self._is_native: 

235 # Sort files alphabetically 

236 # TODO: Do we want this? Keeps consistent when loading files across devices 

237 self._file_paths.sort() 

238 self._update_file_indices() 

239 

240 if self._is_lazy: 

241 # When lazy loading, use async loaders to keep next file ready at all times. 

242 self.executor = ThreadPoolExecutor(max_workers=1) 

243 self.loop = asyncio.get_event_loop() 

244 self._next_file: str | None = None # Future file name 

245 self._next_file_future: pl.DataFrame | None = None # Future file data 

246 self._preload_task: asyncio.Task | None = None 

247 

248 @staticmethod 

249 def _create_filter_series(file_path: str) -> tuple[str, pl.Series]: 

250 # Use lazy evaluation to get height without loading full data 

251 height = pl.scan_parquet(file_path).select(pl.len()).collect().item() 

252 return file_path, pl.Series(np.ones(height, dtype=bool)) 

253 

254 @staticmethod 

255 def _shuffle_df(df: pl.DataFrame) -> pl.DataFrame: 

256 """Shuffle the rows of the given DataFrame.""" 

257 shuffled_indices = np.random.permutation(len(df)) 

258 return df[shuffled_indices] 

259 # return df.with_row_count("row_nr").select(pl.all().shuffle()) 

260 

261 @staticmethod 

262 def _sanitise_peptide(peptide: str) -> str: 

263 """Sanitise peptide sequence.""" 

264 # Some datasets save sequence wrapped with _ or . 

265 if peptide is None: 

266 return None 

267 if peptide[0] == "_" and peptide[-1] == "_": 

268 peptide = peptide[1:-1] 

269 if peptide[0] == "." and peptide[-1] == ".": 

270 peptide = peptide[1:-1] 

271 return peptide 

272 

273 @staticmethod 

274 def _ensure_experiment_name( 

275 df: pl.DataFrame, 

276 file_path: str, 

277 add_source: bool = False, 

278 force_source: bool = False, 

279 add_spectrum_id: bool = False, 

280 force_spectrum_id: bool = False, 

281 ) -> pl.DataFrame: 

282 """Ensure experiment_name is a column in the df.""" 

283 if "experiment_name" not in df.columns: 

284 exp_name = Path(file_path).stem 

285 df = df.with_columns(pl.lit(exp_name).alias("experiment_name").cast(pl.Utf8)) 

286 if add_source and (("source_file" not in df.columns) or force_source): 

287 df = df.with_columns(pl.lit(file_path).alias("source_file").cast(pl.Utf8)) 

288 if add_spectrum_id and (("spectrum_id" not in df.columns) or force_spectrum_id): 

289 if "scan_number" not in df.columns: 

290 logger.warning("Scan number column not found. Creating row-based idx column for spectrum_id.") 

291 df = df.with_row_index("idx") 

292 df = df.with_columns((pl.col("experiment_name") + ":" + pl.col("idx").cast(pl.Utf8)).alias("spectrum_id")) 

293 df = df.drop("idx") 

294 else: 

295 df = df.with_columns((pl.col("experiment_name") + ":" + pl.col("scan_number").cast(pl.Utf8)).alias("spectrum_id")) 

296 return df 

297 

298 @staticmethod 

299 def _map_columns(df: pl.DataFrame, column_mapping: dict[str, str] | None = None) -> pl.DataFrame: 

300 """Map the columns of the DataFrame to the appropriate data types based on MS_TYPES.""" 

301 if column_mapping is None: 

302 return df 

303 return df.rename({k: v for k, v in column_mapping.items() if k in df.columns}) 

304 

305 @staticmethod 

306 def _cast_columns(df: pl.DataFrame) -> pl.DataFrame: 

307 """Cast the columns of the DataFrame to the appropriate data types based on MS_TYPES.""" 

308 return df.with_columns([pl.col(column.value).cast(dtype) for column, dtype in MS_TYPES.items() if column.value in df.columns]) 

309 

310 @staticmethod 

311 def _is_glob(path: str) -> bool: 

312 return "*" in path or "?" in path or "[" in path 

313 

314 @staticmethod 

315 def _convert_file_paths(file_paths: str | list[str]) -> list[str]: 

316 """Convert a string or list of file paths to a list of file paths. 

317 

318 Args: 

319 file_paths (str | list[str]): File path or list of file paths. 

320 

321 Returns: 

322 list[str]: A list of resolved file paths. 

323 

324 Raises: 

325 ValueError: If input is a directory or not a valid file path. 

326 """ 

327 if isinstance(file_paths, str): 

328 if os.path.isdir(file_paths): 

329 raise ValueError("Input must be a string (filepath or glob) or a list of file paths. Found directory.") 

330 if SpectrumDataFrame._is_glob(file_paths): 

331 # Glob notation: path/to/data/*.parquet 

332 return glob.glob(file_paths) 

333 else: 

334 # Single file 

335 return [file_paths] 

336 elif not isinstance(file_paths, list): 

337 ValueError("Input must be a string (filepath or glob) or a list of file paths.") 

338 

339 # Expand if list of globs 

340 file_paths = list(chain.from_iterable([glob.glob(path) if SpectrumDataFrame._is_glob(path) else [path] for path in file_paths])) 

341 

342 return file_paths 

343 

344 @staticmethod 

345 def _concat_dataframes(df1: pl.DataFrame, df2: pl.DataFrame) -> pl.DataFrame: 

346 df1_columns = set(df1.columns) 

347 df2_columns = set(df2.columns) 

348 

349 # Find missing columns in both DataFrames 

350 missing_in_df1 = df2_columns - df1_columns 

351 missing_in_df2 = df1_columns - df2_columns 

352 

353 # Add missing columns to df1 with None values 

354 for col in missing_in_df1: 

355 df1 = df1.with_columns(pl.lit(None).cast(df2[col].dtype).alias(col)) 

356 

357 # Add missing columns to df2 with None values 

358 for col in missing_in_df2: 

359 df2 = df2.with_columns(pl.lit(None).cast(df1[col].dtype).alias(col)) 

360 

361 # Rearrange df2 to have the same order as df1 

362 df2 = df2.select(df1.columns) 

363 

364 return pl.concat([df1, df2], how="vertical_relaxed") 

365 

366 @staticmethod 

367 def get_data_shards( 

368 file_paths: list[str], 

369 custom_load_fn: Callable | None = None, 

370 column_mapping: dict[str, str] | None = None, 

371 max_shard_size: int = 100_000, 

372 add_source_file_column: bool = False, 

373 force_convert_to_native: bool = False, 

374 add_spectrum_id_column: bool = False, 

375 verbose: bool = False, 

376 ) -> Iterator[pl.DataFrame]: 

377 """Load data files into DataFrames one at a time to save memory. 

378 

379 Args: 

380 file_paths (list[str]): List of file paths to be loaded. 

381 custom_load_fn (Callable | None): Custom function to load the files. 

382 max_shard_size (int): Maximum size of data shards. 

383 verbose (bool): Whether to using logger 

384 force_convert_to_native (bool): Force conversion to native format. 

385 add_spectrum_id_column (bool): Whether to add spectrum id column. 

386 

387 Yields: 

388 Iterator[pl.DataFrame]: DataFrames containing mass spectra data. 

389 """ 

390 column_mapping = column_mapping or {} 

391 current_shard = None 

392 for i, fp in enumerate(file_paths, 1): 

393 if verbose: 

394 logger.info(f"Loading file {i:03,d} of {len(file_paths):03,d}: {fp}") 

395 

396 if fp.endswith(".parquet") and not force_convert_to_native: 

397 continue 

398 

399 if custom_load_fn is not None: 

400 df = custom_load_fn(fp) 

401 else: 

402 df = SpectrumDataFrame._df_from_any(fp) 

403 

404 if df is None: 

405 logger.warning(f"Unknown filetype of {fp}. Skipping.") 

406 continue 

407 

408 df = SpectrumDataFrame._ensure_experiment_name( 

409 df, file_path=fp, add_source=add_source_file_column, add_spectrum_id=add_spectrum_id_column 

410 ) 

411 

412 df = SpectrumDataFrame._map_columns(df, column_mapping=column_mapping) 

413 df = SpectrumDataFrame._cast_columns(df) 

414 

415 # If df > shard_size, split it up first 

416 while len(df) > max_shard_size: 

417 yield df[:max_shard_size] 

418 df = df[max_shard_size:] 

419 

420 # Assumes df < shard_size 

421 if current_shard is None: 

422 current_shard = df 

423 elif len(current_shard) + len(df) < max_shard_size: 

424 current_shard = SpectrumDataFrame._concat_dataframes(current_shard, df) 

425 else: 

426 yield SpectrumDataFrame._concat_dataframes(current_shard, df[: (max_shard_size - len(current_shard))]) 

427 current_shard = df[(max_shard_size - len(current_shard)) :] 

428 yield current_shard 

429 

430 def check_values(self, min_value: float, max_value: float, column_name: str) -> bool: 

431 """Check the values of the DataFrame to ensure they are within the specified range. 

432 

433 Args: 

434 min_value (float): The minimum value. 

435 max_value (float): The maximum value. 

436 column_name (str): The name of the column to check. 

437 

438 Returns: 

439 bool: True if all values are within the range, False otherwise. 

440 """ 

441 min_value = min_value or -np.inf 

442 max_value = max_value or np.inf 

443 if self._is_native: 

444 for fp in self._file_paths: 

445 out_of_range = ( 

446 pl.scan_parquet(fp).filter((pl.col(column_name) < min_value) | (pl.col(column_name) > max_value)).select(pl.len()).collect() 

447 ) 

448 if out_of_range[0, 0] > 0: 

449 return False 

450 else: 

451 assert self.df is not None 

452 out_of_range = self.df.filter((pl.col(column_name) < min_value) | (pl.col(column_name) > max_value)) 

453 if out_of_range.height > 0: 

454 return False 

455 return True 

456 

457 def filter_rows(self, filter_fn: Callable) -> None: 

458 """Apply a filter function to rows of the DataFrame. 

459 

460 Args: 

461 filter_fn (Callable): Function used to filter rows. 

462 """ 

463 if self._is_native: 

464 for fp in self._file_paths: 

465 df = pl.scan_parquet(fp).collect() 

466 new_filter = df.select( 

467 [ 

468 pl.struct(df.columns).map_elements(filter_fn, return_dtype=bool).alias("result"), 

469 ] 

470 )["result"] 

471 self._filter_series_per_file[fp] &= new_filter 

472 

473 self._reset_current_file() 

474 if not self._shuffle: 

475 self._update_file_indices() 

476 else: 

477 assert self.df is not None 

478 new_filter = self.df.select( 

479 [ 

480 pl.struct(self.df.columns).map_elements(filter_fn, return_dtype=bool).alias("result"), 

481 ] 

482 )["result"] 

483 self.df = self.df.filter(new_filter) 

484 

485 def _log(self, text: str) -> None: 

486 if self._verbose: 

487 logger.info(text) 

488 

489 def reset_filter(self) -> None: 

490 """Reset the filters applied to the DataFrame.""" 

491 if not self._is_native: 

492 raise NotImplementedError("Filter reset is not supported in non-native mode.") 

493 self._filter_series_per_file = {fp: pl.Series(np.full(pl.scan_parquet(fp).collect().height, True, dtype=bool)) for fp in self._file_paths} 

494 self._reset_current_file() 

495 if not self._shuffle: 

496 self._update_file_indices() 

497 

498 def _update_file_indices(self) -> None: 

499 """Update mapping from index to file in native non-shuffle mode.""" 

500 if self._shuffle: 

501 raise ValueError("Cannot use file indexing in shuffle mode.") 

502 

503 self._index_to_file_index = pl.concat( 

504 [pl.Series(np.full(self._filter_series_per_file[fp].sum(), i, dtype=int)) for i, fp in enumerate(self._file_paths)] 

505 ) 

506 

507 cumulative_position = 0 

508 self._file_begin_index: dict[str, int] = {} 

509 for fp in self._file_paths: 

510 self._file_begin_index[fp] = cumulative_position 

511 height = self._filter_series_per_file[fp].sum() 

512 cumulative_position += height 

513 

514 def _preshuffle_files(self) -> None: 

515 """Shuffle across all files.""" 

516 if not self._is_native: 

517 return 

518 num_files = len(self._file_paths) 

519 if num_files <= 1: 

520 return 

521 

522 self._log(f"Pre-shuffling across {num_files:03,d} shards. This may take a while...") 

523 

524 self._log("Computing new mapping per original shard") 

525 index_to_file_index = pl.concat( 

526 [pl.Series(np.full(self._filter_series_per_file[fp].sum(), i, dtype=int)) for i, fp in enumerate(self._file_paths)] 

527 ) 

528 

529 # To ensure consistent shard sizes, we sample based on index permutations 

530 index_to_file_index = pl.Series(np.random.permutation(index_to_file_index.to_numpy())) 

531 

532 offset = 0 

533 mapping_per_file = {} 

534 for fp in self._file_paths: 

535 height = len(self._filter_series_per_file[fp]) 

536 mapping_per_file[fp] = index_to_file_index[offset : offset + height] 

537 offset += height 

538 

539 if self._temp_directory is None: 

540 self._temp_directory = tempfile.mkdtemp() 

541 

542 self._log("Extracting rows to create shuffled shards") 

543 new_file_paths = [] 

544 start = time.time() 

545 for i in range(num_files): 

546 df = None 

547 for fp in self._file_paths: 

548 temp_df = pl.scan_parquet(fp).filter(mapping_per_file[fp] == i).collect() 

549 if df is None: 

550 df = temp_df 

551 else: 

552 df = SpectrumDataFrame._concat_dataframes(df, temp_df) 

553 

554 if df is None: 

555 raise ValueError("No data in shard during reshuffle.") 

556 

557 temp_parquet_path = os.path.join(str(self._temp_directory), f"temp_{uuid.uuid4().hex}.parquet") 

558 df.write_parquet(temp_parquet_path) 

559 new_file_paths.append(temp_parquet_path) 

560 

561 delta = time.time() - start 

562 est_total = delta / (i + 1) * (num_files - i - 1) 

563 self._log( 

564 f"Writing shuffled shard {i:03,d}/{num_files:03,d} to {temp_parquet_path} " 

565 f"[{_format_time(delta)}/{_format_time(est_total)}, {(delta / (i + 1)):.3f}s/it]" 

566 ) 

567 

568 self._log("Removing unshuffled shards") 

569 # Remove old temp files: 

570 for fp in self._file_paths: 

571 if os.path.commonpath([str(self._temp_directory), fp]) == self._temp_directory: 

572 try: 

573 os.remove(fp) 

574 except OSError as e: 

575 self._log(f"Error deleting temporary file {fp}: {e}") 

576 

577 self._file_paths = new_file_paths 

578 self._filter_series_per_file = {fp: pl.Series(np.full(pl.scan_parquet(fp).collect().height, True, dtype=bool)) for fp in self._file_paths} 

579 self._log("Pre-shuffle complete") 

580 

581 def _reset_current_file(self) -> None: 

582 # Shuffled file handling, uses a two-step shuffle to optimise efficiency 

583 self._current_index_in_file = 0 # index in the current file, used in shuffle mode 

584 self._next_file_index = 0 # index of the next file to be loaded in _file_paths 

585 self._current_file: str | None = None # filename of the current file 

586 self._current_file_len = 0 # length of the current file 

587 self._current_file_data: pl.DataFrame | None = None # loaded data of the current file 

588 self._current_file_position = 0 # starting index of the current file, used to 

589 

590 def _shuffle_file_order(self) -> None: 

591 """Shuffle the order of files in native mode.""" 

592 random.shuffle(self._file_paths) 

593 

594 def _load_parquet_data(self, file_path: str) -> pl.DataFrame: 

595 """Load data from a parquet file and apply the filters.""" 

596 # if the experiment_name column is missing, we add it 

597 df = pl.scan_parquet(file_path).filter(self._filter_series_per_file[file_path]).collect() 

598 df = SpectrumDataFrame._ensure_experiment_name( 

599 df, 

600 file_path, 

601 add_source=self._add_source_file_column, 

602 force_source=False, 

603 ) 

604 df = SpectrumDataFrame._cast_columns(df) 

605 if self.preprocess_fn is not None: 

606 df = self.preprocess_fn(df) 

607 return df 

608 

609 def _load_next_file(self) -> None: 

610 """Load the next file in sequence for lazy loading.""" 

611 # This function is exclusive to native mode i.e. always lazy 

612 self._current_file = self._file_paths[self._next_file_index] 

613 # Scan file, filter, and collect 

614 if self._current_file == self._next_file and self._next_file_future is not None: 

615 self._current_file_data = self._next_file_future 

616 else: 

617 self._current_file_data = self._load_parquet_data(self._current_file) 

618 

619 # Update next file loading 

620 if self._shuffle: 

621 self._current_file_data = SpectrumDataFrame._shuffle_df(self._current_file_data) # Shuffle rows 

622 self._current_file_len = self._current_file_data.shape[0] 

623 

624 # Update future 

625 if len(self._file_paths) > 0: 

626 future_file_index = self._next_file_index + 1 

627 if future_file_index >= len(self._file_paths): 

628 if self._shuffle: 

629 self._shuffle_file_order() 

630 future_file_index = 0 

631 

632 self._next_file = self._file_paths[future_file_index] 

633 

634 self._next_file_future = None 

635 self._start_preload_next(self._next_file) 

636 

637 def _start_preload_next(self, file_path: str) -> None: 

638 """Start preloading the next file asynchronously.""" 

639 if self._preload_task is None or self._preload_task.done(): 

640 self._preload_task = self.loop.create_task(self._preload_next_file(file_path)) 

641 

642 async def _preload_next_file(self, file_path: str) -> None: 

643 """Asynchronously preload the next file.""" 

644 try: 

645 self._next_file_future = await self.loop.run_in_executor(self.executor, self._load_parquet_data, file_path) 

646 except Exception as e: 

647 logger.warning(f"Error preloading file {file_path}: {e}") 

648 self._next_file_future = None 

649 

650 def __len__(self) -> int: 

651 """Returns the total number of rows in the SpectrumDataFrame. 

652 

653 Returns: 

654 int: Number of rows in the DataFrame. 

655 """ 

656 if self._is_native: 

657 return sum([v.sum() for v in self._filter_series_per_file.values()]) 

658 assert self.df is not None 

659 return int(self.df.shape[0]) 

660 

661 def __getitem__(self, idx: int) -> dict[str, Any]: 

662 """Return the item at the specified index. 

663 

664 Args: 

665 idx (int): Index of the item to retrieve. 

666 

667 Returns: 

668 dict[str, Any]: Dictionary containing the data from the specified row. 

669 

670 Raises: 

671 IndexError: If the DataFrame is empty or the index is out of range. 

672 """ 

673 length = len(self) 

674 if length == 0: 

675 raise IndexError("Attempt to index empty SpectrumDataFrame") 

676 if idx >= length: 

677 raise IndexError 

678 

679 # In shuffle, idx is ignored. 

680 if self._is_native: 

681 if self._shuffle: 

682 # If no file is loaded or we have finished the current file 

683 if self._current_file_data is None or self._current_index_in_file >= self._current_file_len: 

684 self._current_index_in_file = 0 

685 

686 self._load_next_file() 

687 self._next_file_index += 1 

688 if self._next_file_index >= len(self._file_paths): 

689 self._next_file_index = 0 

690 

691 # for mypy 

692 assert self._current_file_data is not None 

693 

694 row = self._current_file_data[self._current_index_in_file] 

695 

696 self._current_index_in_file += 1 

697 else: 

698 # In native mode without shuffle, idx is used. 

699 selected_file_index = self._index_to_file_index[idx] 

700 

701 # If the index is outside the currently loaded file, load the new file 

702 if self._current_file_data is None or self._file_paths[selected_file_index] != self._current_file: 

703 self._next_file_index = selected_file_index 

704 self._load_next_file() 

705 

706 # for mypy 

707 assert self._current_file is not None 

708 assert self._current_file_data is not None 

709 

710 # Find the relative index within the current file 

711 file_begin_index = self._file_begin_index[self._current_file] 

712 index_in_file = idx 

713 if file_begin_index > 0: 

714 index_in_file = idx % self._file_begin_index[self._current_file] 

715 

716 row = self._current_file_data[index_in_file] 

717 else: 

718 assert self.df is not None 

719 # We're in non-native non-lazy mode 

720 if self._shuffle: 

721 # Shuffle if we have passed through all entries 

722 if self._current_index_in_file >= self.df.height: 

723 self.df = SpectrumDataFrame._shuffle_df(self.df) 

724 self._current_index_in_file = 0 

725 

726 row = self.df[self._current_index_in_file] 

727 

728 self._current_index_in_file += 1 

729 else: 

730 row = self.df[idx] 

731 

732 # row = SpectrumDataFrame._cast_columns(row) 

733 

734 # Squeeze all entries 

735 row_dict: dict[str, Any] = {k: v[0] for k, v in row.to_dict(as_series=False).items()} 

736 

737 if self.is_annotated: 

738 row_dict[ANNOTATED_COLUMN] = SpectrumDataFrame._sanitise_peptide(row_dict[ANNOTATED_COLUMN]) 

739 

740 return row_dict 

741 

742 # def iterable(self, df: pl.DataFrame) -> None: 

743 # """Iterates dataset. Supports streaming?""" 

744 # pass 

745 

746 @property 

747 def is_annotated(self) -> bool: 

748 """Check if the dataset is annotated. 

749 

750 Returns: 

751 bool: True if annotated, False otherwise. 

752 """ 

753 return self._is_annotated 

754 

755 @property 

756 def has_predictions(self) -> bool: 

757 """Check if the dataset contains predictions. 

758 

759 Returns: 

760 bool: True if predictions are present, False otherwise. 

761 """ 

762 return self._has_predictions 

763 

764 @property 

765 def is_lazy(self) -> bool: 

766 """Check if lazy loading mode is enabled. 

767 

768 Returns: 

769 bool: True if lazy loading is enabled, False otherwise. 

770 """ 

771 return self._is_lazy 

772 

773 def save( 

774 self, 

775 target: Path, 

776 partition: str | None = None, 

777 name: str | None = None, 

778 max_shard_size: int | None = None, 

779 ) -> None: 

780 """Save the dataset in parquet format with the option to partition and shard the data. 

781 

782 Args: 

783 target: Directory to save the dataset. 

784 partition: Partition name to be included in the file names. 

785 name: Dataset name to be included in the file names. 

786 max_shard_size: Maximum size of the data shards. 

787 """ 

788 max_shard_size = max_shard_size or self._max_shard_size 

789 partition = partition or "default" 

790 name = name or "ms" 

791 

792 total_num_files = (len(self) // max_shard_size) + 1 

793 

794 shards = self._to_parquet_chunks(target, max_shard_size) 

795 

796 Path(target).mkdir(parents=True, exist_ok=True) 

797 

798 for i, shard in enumerate(shards): 

799 filename = f"dataset-{name}-{partition}-{i:04d}-{total_num_files:04d}.parquet" 

800 shard_path = os.path.join(target, filename) 

801 self._log(f"Writing {shard_path}") 

802 shard.write_parquet(shard_path) 

803 

804 def _to_parquet_chunks(self, target: Path, max_shard_size: int = 1_000_000) -> Iterator[pl.DataFrame]: 

805 """Generate DataFrame chunks to be saved as parquet files. 

806 

807 Args: 

808 target: Directory to save the parquet files. 

809 max_shard_size: Maximum size of the data shards. 

810 

811 Yields: 

812 Chunks of DataFrames to be saved. 

813 """ 

814 if self._is_native: 

815 current_shard = None 

816 for fp in self._file_paths: 

817 # Load each file with filtering 

818 df = self._load_parquet_data(fp) 

819 

820 while len(df) > max_shard_size: 

821 yield df[:max_shard_size] 

822 df = df[max_shard_size:] 

823 

824 # Assumes df < shard_size 

825 if current_shard is None: 

826 current_shard = df 

827 elif len(current_shard) + len(df) < max_shard_size: 

828 current_shard = pl.concat([current_shard, df]) 

829 else: 

830 yield pl.concat([current_shard, df[: (max_shard_size - len(current_shard))]]) 

831 current_shard = df[(max_shard_size - len(current_shard)) :] 

832 yield current_shard 

833 else: 

834 assert self.df is not None 

835 df = cast(pl.DataFrame, self.df) 

836 while len(df) > max_shard_size: 

837 yield df[:max_shard_size] 

838 df = df[max_shard_size:] 

839 yield df 

840 

841 def write_csv(self, target: str) -> None: 

842 """Write the dataset to a CSV file. 

843 

844 Args: 

845 target (str): Path to the output CSV file. 

846 """ 

847 self.to_pandas().to_csv(target, index=False) 

848 

849 def write_ipc(self, target: str) -> None: 

850 """Write the dataset to a Polars ipc file. 

851 

852 Args: 

853 target (str): Path to the output ipc file. 

854 """ 

855 df = self.to_polars() 

856 if self._is_native: 

857 df = df.collect().rechunk() 

858 df.write_ipc(target) 

859 

860 def write_mgf(self, target: str, export_style: str | None = None) -> None: 

861 """Write the dataset to an MGF file using Matchms format. 

862 

863 Args: 

864 target (str): Path to the output MGF file. 

865 export_style (str | None): Style of export to be used (optional). 

866 """ 

867 export_style = export_style or "matchms" 

868 spectra = self.to_matchms() 

869 

870 # Check if the file exists and delete it if it does 

871 if os.path.exists(target): 

872 try: 

873 os.remove(target) 

874 except OSError as e: 

875 logger.warning(f"Error deleting existing file '{target}': {e}") 

876 return # Exit the method if we can't delete the file 

877 

878 save_as_mgf(spectra, target, export_style=export_style) 

879 

880 def write_pointnovo(self, spectrum_source: str, feature_target: str) -> None: 

881 """Write the dataset in PointNovo format. 

882 

883 Args: 

884 spectrum_source (str): Source of the spectrum data. 

885 feature_target (str): Target for the features. 

886 """ 

887 raise NotImplementedError() 

888 

889 def write_mzxml(self, target: str) -> None: 

890 """Write the dataset in mzXML format. 

891 

892 Args: 

893 target (str): Path to the output mzXML file. 

894 """ 

895 raise NotImplementedError() 

896 

897 def write_mzml(self, target: str) -> None: 

898 """Write the dataset in mzML format. 

899 

900 Args: 

901 target (str): Path to the output mzML file. 

902 """ 

903 raise NotImplementedError() 

904 

905 def to_pandas(self) -> pd.DataFrame: 

906 """Convert the dataset to a pandas DataFrame. 

907 

908 Warning: 

909 This function loads the entire dataset into memory. For large datasets, 

910 this may consume a significant amount of RAM. 

911 

912 Returns: 

913 pd.DataFrame: The dataset in pandas DataFrame format. 

914 """ 

915 return cast(pd.DataFrame, self.to_polars(return_lazy=False).to_pandas()) 

916 

917 def to_polars(self, return_lazy: bool = True) -> pl.DataFrame | pl.LazyFrame: 

918 """Convert the dataset to a polars DataFrame. 

919 

920 Args: 

921 return_lazy (bool): Return LazyFrame when in lazy mode. Defaults to True. 

922 

923 Returns: 

924 pl.DataFrame | pl.LazyFrame: The dataset in polars DataFrame format 

925 """ 

926 if self._is_native: 

927 if return_lazy: 

928 dfs = [] 

929 for fp in self._file_paths: 

930 dfs.append(pl.scan_parquet(fp).filter(self._filter_series_per_file[fp])) 

931 df = pl.concat(dfs) 

932 return df 

933 

934 df = None 

935 for fp in self._file_paths: 

936 temp_df = pl.scan_parquet(fp).filter(self._filter_series_per_file[fp]).collect() 

937 temp_df = SpectrumDataFrame._ensure_experiment_name(temp_df, fp, add_source=self._add_source_file_column, force_source=True) 

938 if df is None: 

939 df = temp_df 

940 else: 

941 df = SpectrumDataFrame._concat_dataframes(df, temp_df) 

942 return df 

943 return self.df 

944 

945 @staticmethod 

946 def _get_unified_schema(file_paths: list[str]) -> dict[str, Any]: 

947 """Get the unified schema for a list of parquet files. 

948 

949 Args: 

950 file_paths (list[str]): List of file paths to the dataset. 

951 

952 Returns: 

953 dict[str, Any]: Unified schema for the dataset. 

954 """ 

955 # Get union of all columns and their types 

956 unified_features = {} 

957 

958 for file_path in file_paths: 

959 # Read just the schema (no data) using scan_parquet 

960 df_lazy = pl.scan_parquet(file_path) 

961 schema = df_lazy.collect_schema() 

962 

963 for col_name, dtype in schema.items(): 

964 if col_name not in unified_features: 

965 # Map Polars types to HuggingFace Features 

966 if dtype == pl.String or dtype == pl.Utf8: 

967 unified_features[col_name] = Value("string") 

968 elif dtype == pl.Int64: 

969 unified_features[col_name] = Value("int64") 

970 elif dtype == pl.Float64: 

971 unified_features[col_name] = Value("float64") 

972 elif dtype == pl.Float32: 

973 unified_features[col_name] = Value("float32") 

974 elif dtype == pl.Int32: 

975 unified_features[col_name] = Value("int32") 

976 elif isinstance(dtype, pl.List): 

977 # Handle list types 

978 inner_type = dtype.inner 

979 if inner_type == pl.Float64: 

980 unified_features[col_name] = Sequence(Value("float64")) 

981 elif inner_type == pl.Float32: 

982 unified_features[col_name] = Sequence(Value("float32")) 

983 elif inner_type == pl.Int64: 

984 unified_features[col_name] = Sequence(Value("int64")) 

985 elif inner_type == pl.String or inner_type == pl.Utf8: 

986 unified_features[col_name] = Sequence(Value("string")) 

987 else: 

988 logger.warning(f"Unknown list inner type for {file_path} {col_name}: {inner_type}") 

989 elif isinstance(dtype, pl.Array): 

990 # Handle array types (fixed-size arrays) 

991 inner_type = dtype.inner 

992 if inner_type == pl.Float64: 

993 unified_features[col_name] = Sequence(Value("float64")) 

994 elif inner_type == pl.Float32: 

995 unified_features[col_name] = Sequence(Value("float32")) 

996 elif inner_type == pl.Int64: 

997 unified_features[col_name] = Sequence(Value("int64")) 

998 else: 

999 logger.warning(f"Unknown array inner type for {file_path} {col_name}: {inner_type}") 

1000 else: 

1001 logger.warning(f"Unknown type for {file_path} {col_name}: {dtype}") 

1002 

1003 return unified_features 

1004 

1005 def to_dataset( 

1006 self, 

1007 in_memory: bool = False, 

1008 force_unified_schema: bool = False, 

1009 **kwargs: Any, 

1010 ) -> Dataset: 

1011 """Convert the dataset to a HuggingFace Dataset. 

1012 

1013 Returns: 

1014 Dataset: HuggingFace Dataset. 

1015 """ 

1016 if self._is_native and not in_memory: 

1017 if force_unified_schema: 

1018 features = Features(self._get_unified_schema(self._file_paths)) 

1019 

1020 return load_dataset( 

1021 "parquet", 

1022 data_files=self._file_paths, 

1023 streaming=True, 

1024 split="train", 

1025 features=features if force_unified_schema else None, 

1026 verification_mode=VerificationMode.NO_CHECKS, 

1027 **kwargs, 

1028 ) 

1029 

1030 return Dataset.from_pandas(self.to_pandas(), **kwargs) 

1031 

1032 def to_matchms(self) -> list[Spectrum]: 

1033 """Convert the dataset to a list of Matchms spectrum objects. 

1034 

1035 Warning: 

1036 This function loads the entire dataset into memory. For large datasets, 

1037 this may consume a significant amount of RAM. 

1038 

1039 Returns: 

1040 list[Spectrum]: List of Matchms spectrum objects. 

1041 """ 

1042 df = self.to_polars(return_lazy=False) 

1043 return SpectrumDataFrame._df_to_matchms(df) 

1044 

1045 def export_predictions(self, target: str, export_type: str | Enum) -> None: 

1046 """Export the predictions from the dataset. 

1047 

1048 Args: 

1049 target (str): Target path to save the predictions. 

1050 export_type (str | Enum): Type of export format. 

1051 """ 

1052 if isinstance(export_type, str): 

1053 pass 

1054 raise NotImplementedError() 

1055 

1056 @classmethod 

1057 def load( 

1058 cls, 

1059 source: str | list[str] | Path, 

1060 source_type: str = "default", 

1061 is_annotated: bool = False, 

1062 shuffle: bool = False, 

1063 name: str | None = None, 

1064 partition: str | None = None, 

1065 custom_load: Callable | None = None, 

1066 column_mapping: dict[str, str] | None = None, 

1067 lazy: bool = True, 

1068 max_shard_size: int = 1_000_000, 

1069 preshuffle_across_shards: bool = False, 

1070 add_source_file_column: bool = False, 

1071 preprocess_fn: Callable | None = None, 

1072 add_spectrum_id: bool = False, 

1073 force_spectrum_id: bool = False, 

1074 force_convert_to_native: bool = False, 

1075 verbose: bool = False, 

1076 ) -> "SpectrumDataFrame": 

1077 """Load a SpectrumDataFrame from a source. 

1078 

1079 Args: 

1080 source (str | Path): Path to the source file or directory. 

1081 source_type (str): Type of the source (default is "default"). 

1082 is_annotated (bool): Whether the dataset is annotated. 

1083 shuffle (bool): Whether to shuffle the dataset. 

1084 name (str | None): Name of the dataset. 

1085 partition (str | None): Partition name of the dataset. 

1086 lazy (bool): Whether to use lazy loading mode. 

1087 max_shard_size (int): Maximum size of data shards. 

1088 preshuffle_across_shards (bool): Whether to perform a preshuffle across shards. 

1089 add_source_file_column (bool): Whether to add the source file column. 

1090 preprocess_fn (Callable | None): Preprocess function for the data on load. 

1091 add_spectrum_id (bool): Whether to add spectrum id column. 

1092 force_spectrum_id (bool): Force adding spectrum id column. 

1093 force_convert_to_native (bool): Force conversion to native format when working 

1094 with parquet files. 

1095 verbose (bool): Whether to print verbose output. 

1096 

1097 Returns: 

1098 SpectrumDataFrame: The loaded SpectrumDataFrame. 

1099 """ 

1100 partition = partition or "default" 

1101 name = name or "ms" 

1102 

1103 # Native mode 

1104 if isinstance(source, str) and os.path.isdir(source) and source_type == "default": 

1105 # /path/to/folder/dataset-name-train-0000-of-0001.parquet 

1106 source = os.path.join(source, f"dataset-{name}-{partition}-*-*.parquet") 

1107 

1108 return cls( 

1109 file_paths=cast(str, source), # We don't support Path directly 

1110 is_lazy=lazy, 

1111 custom_load_fn=custom_load, 

1112 column_mapping=column_mapping, 

1113 max_shard_size=max_shard_size, 

1114 shuffle=shuffle, 

1115 is_annotated=is_annotated, 

1116 preshuffle_across_shards=preshuffle_across_shards, 

1117 add_source_file_column=add_source_file_column, 

1118 preprocess_fn=preprocess_fn, 

1119 add_spectrum_id=add_spectrum_id, 

1120 force_spectrum_id=force_spectrum_id, 

1121 force_convert_to_native=force_convert_to_native, 

1122 verbose=verbose, 

1123 ) 

1124 

1125 @staticmethod 

1126 def _df_from_any(source: str, source_type: str | None = None) -> pl.DataFrame | None: 

1127 """Load a DataFrame from various source formats (MGF, IPC, etc.). 

1128 

1129 Args: 

1130 source (str): Path to the source file. 

1131 source_type (str | None): Type of the source file. 

1132 

1133 Returns: 

1134 pl.DataFrame: The loaded DataFrame. 

1135 """ 

1136 if source_type is None: 

1137 # Try to infer 

1138 source_type = source.split(".")[-1].lower() 

1139 

1140 match source_type: 

1141 case "ipc": 

1142 return SpectrumDataFrame._df_from_ipc(source) 

1143 case "mgf": 

1144 return SpectrumDataFrame._df_from_mgf(source) 

1145 case "mzml": 

1146 return SpectrumDataFrame._df_from_mzml(source) 

1147 case "mzxml": 

1148 return SpectrumDataFrame._df_from_mzxml(source) 

1149 case "csv": 

1150 return SpectrumDataFrame._df_from_csv(source) 

1151 case "parquet": 

1152 return SpectrumDataFrame._df_from_parquet(source) 

1153 # case "_": 

1154 

1155 return None 

1156 

1157 @classmethod 

1158 def load_mgf(cls, source: str) -> "SpectrumDataFrame": 

1159 """Load a SpectrumDataFrame from an MGF file. 

1160 

1161 Args: 

1162 source (str): Path to the MGF file. 

1163 

1164 Returns: 

1165 SpectrumDataFrame: The loaded SpectrumDataFrame. 

1166 """ 

1167 spectra = list(load_from_mgf(source)) 

1168 return cls.from_matchms(spectra) 

1169 

1170 @staticmethod 

1171 def _df_from_mgf(source: str) -> pl.DataFrame: 

1172 """Load a polars DataFrame from an MGF file. 

1173 

1174 Args: 

1175 source (str): Path to the MGF file. 

1176 

1177 Returns: 

1178 pl.DataFrame: The loaded polars DataFrame. 

1179 """ 

1180 spectra = list(load_from_mgf(source)) 

1181 return SpectrumDataFrame._df_from_matchms(spectra) 

1182 

1183 @classmethod 

1184 def load_pointnovo(cls, spectrum_source: str, feature_source: str) -> "SpectrumDataFrame": 

1185 """Load a SpectrumDataFrame from PointNovo format. 

1186 

1187 Args: 

1188 spectrum_source (str): Source of spectrum data. 

1189 feature_source (str): Source of feature data. 

1190 """ 

1191 raise NotImplementedError() 

1192 

1193 @classmethod 

1194 def load_csv( 

1195 cls, 

1196 source: str, 

1197 column_mapping: dict[str, str] | None = None, 

1198 lazy: bool = False, 

1199 annotated: bool = False, 

1200 ) -> "SpectrumDataFrame": 

1201 """Load a SpectrumDataFrame from a CSV file. 

1202 

1203 Args: 

1204 source (str): Path to the CSV file. 

1205 column_mapping (dict[str, str] | None): Mapping of columns to rename. 

1206 lazy (bool): Whether to use lazy loading mode. 

1207 annotated (bool): Whether the dataset is annotated. 

1208 

1209 Returns: 

1210 SpectrumDataFrame: The loaded SpectrumDataFrame. 

1211 """ 

1212 df = pl.read_csv(source) 

1213 df = SpectrumDataFrame._map_columns(df, column_mapping=column_mapping) 

1214 return cls(df, is_annotated=annotated, is_lazy=lazy) 

1215 

1216 @classmethod 

1217 def load_mzxml(cls, source: str) -> "SpectrumDataFrame": 

1218 """Load a SpectrumDataFrame from an mzXML file. 

1219 

1220 Args: 

1221 source (str): Path to the mzXML file. 

1222 

1223 Returns: 

1224 SpectrumDataFrame: The loaded SpectrumDataFrame. 

1225 """ 

1226 # spectra = list(load_from_mzxml(source)) 

1227 # return cls.from_matchms(spectra) 

1228 df = SpectrumDataFrame._df_from_dict(read_mzxml(source)) 

1229 return cls.from_polars(df) 

1230 

1231 @staticmethod 

1232 def _df_from_mzxml(source: str) -> pl.DataFrame: 

1233 """Load a polars DataFrame from an MGF file. 

1234 

1235 Args: 

1236 source (str): Path to the MGF file. 

1237 

1238 Returns: 

1239 pl.DataFrame: The loaded polars DataFrame. 

1240 """ 

1241 # spectra = list(load_from_mzxml(source)) 

1242 # return SpectrumDataFrame._df_from_matchms(spectra) 

1243 return SpectrumDataFrame._df_from_dict(read_mzxml(source)) 

1244 

1245 @classmethod 

1246 def load_mzml(cls, source: str) -> "SpectrumDataFrame": 

1247 """Load a SpectrumDataFrame from an mzML file. 

1248 

1249 Args: 

1250 source (str): Path to the mzML file. 

1251 

1252 Returns: 

1253 SpectrumDataFrame: The loaded SpectrumDataFrame. 

1254 """ 

1255 # spectra = list(load_from_mzml(source)) 

1256 # return cls.from_matchms(spectra) 

1257 df = SpectrumDataFrame._df_from_dict(read_mzml(source)) 

1258 return cls.from_polars(df) 

1259 

1260 @staticmethod 

1261 def _df_from_mzml(source: str) -> pl.DataFrame: 

1262 """Load a polars DataFrame from an MGF file. 

1263 

1264 Args: 

1265 source (str): Path to the MGF file. 

1266 

1267 Returns: 

1268 pl.DataFrame: The loaded polars DataFrame. 

1269 """ 

1270 # spectra = list(load_from_mzml(source)) 

1271 # return SpectrumDataFrame._df_from_matchms(spectra) 

1272 return SpectrumDataFrame._df_from_dict(read_mzml(source)) 

1273 

1274 @classmethod 

1275 def from_huggingface( 

1276 cls, 

1277 dataset: str | Dataset, 

1278 shuffle: bool = False, 

1279 is_annotated: bool = False, 

1280 **kwargs: Any, 

1281 ) -> "SpectrumDataFrame": 

1282 """Load a SpectrumDataFrame from HuggingFace directory or Dataset instance. 

1283 

1284 Warning: 

1285 This function loads the entire dataset into memory. For large datasets, 

1286 this may consume a significant amount of RAM. 

1287 

1288 Args: 

1289 dataset (str | Dataset): Path to HuggingFace or Dataset instance. 

1290 

1291 Returns: 

1292 SpectrumDataFrame: The loaded SpectrumDataFrame. 

1293 """ 

1294 if isinstance(dataset, str): 

1295 dataset = load_dataset(dataset, **kwargs) 

1296 # TODO: Explore dataset.to_pandas(batched=True) 

1297 return cls.from_pandas(dataset.to_pandas(), shuffle=shuffle, is_annotated=is_annotated) 

1298 

1299 @classmethod 

1300 def from_pandas(cls, df: pd.DataFrame, shuffle: bool = False, is_annotated: bool = False) -> "SpectrumDataFrame": 

1301 """Create a SpectrumDataFrame from a pandas DataFrame. 

1302 

1303 Args: 

1304 df (pd.DataFrame): The pandas DataFrame. 

1305 

1306 Returns: 

1307 SpectrumDataFrame: The resulting SpectrumDataFrame. 

1308 """ 

1309 df = pl.from_pandas(df) 

1310 return cls.from_polars(df, shuffle=shuffle, is_annotated=is_annotated) 

1311 

1312 @classmethod 

1313 def from_polars(cls, df: pl.DataFrame, shuffle: bool = False, is_annotated: bool = False) -> "SpectrumDataFrame": 

1314 """Create a SpectrumDataFrame from a polars DataFrame. 

1315 

1316 Args: 

1317 df (pl.DataFrame): The polars DataFrame. 

1318 

1319 Returns: 

1320 SpectrumDataFrame: The resulting SpectrumDataFrame. 

1321 """ 

1322 return cls( 

1323 df=df, 

1324 shuffle=shuffle, 

1325 is_annotated=is_annotated, 

1326 ) 

1327 

1328 @classmethod 

1329 def load_ipc(cls, source: str, shuffle: bool = False, is_annotated: bool = False) -> "SpectrumDataFrame": 

1330 """Load a SpectrumDataFrame from IPC format. 

1331 

1332 Args: 

1333 source (str): Path to the IPC file. 

1334 

1335 Returns: 

1336 SpectrumDataFrame: The loaded SpectrumDataFrame. 

1337 """ 

1338 df = cls._df_from_ipc(source) 

1339 return cls( 

1340 df, 

1341 is_lazy=False, 

1342 shuffle=shuffle, 

1343 is_annotated=is_annotated, 

1344 ) 

1345 

1346 @staticmethod 

1347 def _df_from_ipc(source: str) -> pl.DataFrame: 

1348 """Load a polars DataFrame from an IPC file. 

1349 

1350 Args: 

1351 source (str): Path to the IPC file. 

1352 

1353 Returns: 

1354 pl.DataFrame: The loaded polars DataFrame. 

1355 """ 

1356 df = pl.read_ipc(source) 

1357 if "modified_sequence" in df.columns: 

1358 df = df.with_columns(pl.col("modified_sequence").alias(ANNOTATED_COLUMN)) 

1359 return df 

1360 

1361 @staticmethod 

1362 def _df_from_csv(source: str) -> pl.DataFrame: 

1363 """Load a polars DataFrame from a CSV file. 

1364 

1365 Args: 

1366 source (str): Path to the CSV file. 

1367 

1368 Returns: 

1369 pl.DataFrame: The loaded polars DataFrame. 

1370 """ 

1371 df = pl.read_csv(source) 

1372 if "modified_sequence" in df.columns: 

1373 df = df.with_columns(pl.col("modified_sequence").alias(ANNOTATED_COLUMN)) 

1374 return df 

1375 

1376 @staticmethod 

1377 def _df_from_parquet(source: str) -> pl.DataFrame: 

1378 """Load a polars DataFrame from a parquet file. 

1379 

1380 Args: 

1381 source (str): Path to the parquet file. 

1382 

1383 Returns: 

1384 pl.DataFrame: The loaded polars DataFrame. 

1385 """ 

1386 df = pl.read_parquet(source) 

1387 if "modified_sequence" in df.columns: 

1388 df = df.with_columns(pl.col("modified_sequence").alias(ANNOTATED_COLUMN)) 

1389 return df 

1390 

1391 @classmethod 

1392 def from_matchms(cls, spectra: list[Spectrum], shuffle: bool = False, is_annotated: bool = False) -> "SpectrumDataFrame": 

1393 """Create a SpectrumDataFrame from Matchms spectrum objects. 

1394 

1395 Args: 

1396 spectra (list): List of Matchms spectrum objects. 

1397 shuffle (bool, optional): If True, shuffle the data. Defaults to False. 

1398 is_annotated (bool, optional): If True, treat the spectra as annotated. 

1399 Defaults to False. 

1400 

1401 Returns: 

1402 SpectrumDataFrame: The resulting SpectrumDataFrame. 

1403 

1404 Raises: 

1405 ValueError: If the input parameters are invalid or incompatible. 

1406 """ 

1407 df = SpectrumDataFrame._df_from_matchms(spectra) 

1408 return cls( 

1409 df=df, 

1410 shuffle=shuffle, 

1411 is_annotated=is_annotated, 

1412 ) 

1413 

1414 @staticmethod 

1415 def _parse_scan_number(scan_number: str, index: int) -> int | None: 

1416 """Try parse scan number.""" 

1417 if scan_number.isdigit(): 

1418 return int(scan_number) 

1419 

1420 # Use regex to extract the value after 'scan=' 

1421 match = re.search(r"scan=(\d+)", scan_number) 

1422 if match: 

1423 return int(match.group(1)) 

1424 

1425 # use index if scan number cannot be accessed 

1426 return index 

1427 

1428 @staticmethod 

1429 def _df_from_dict(data: dict[str, Any]) -> pl.DataFrame: 

1430 df = pl.DataFrame( 

1431 { 

1432 "scan_number": pl.Series( 

1433 [SpectrumDataFrame._parse_scan_number(str(x), i) for i, x in enumerate(data["scan_number"])], 

1434 dtype=pl.Int64, 

1435 ), 

1436 ANNOTATED_COLUMN: pl.Series(data["sequence"], dtype=pl.Utf8), 

1437 # Calculate precursor mass 

1438 MSColumns.PRECURSOR_MASS.value: pl.Series( 

1439 np.array(data["precursor_mz"]) * np.array(data["precursor_charge"]) - np.array(data["precursor_charge"]) * PROTON_MASS_AMU, 

1440 dtype=MS_TYPES[MSColumns.PRECURSOR_MASS], 

1441 ), 

1442 MSColumns.PRECURSOR_MZ.value: pl.Series(data["precursor_mz"], dtype=MS_TYPES[MSColumns.PRECURSOR_MZ]), 

1443 MSColumns.PRECURSOR_CHARGE.value: pl.Series(data["precursor_charge"], dtype=MS_TYPES[MSColumns.PRECURSOR_CHARGE]), 

1444 MSColumns.RETENTION_TIME.value: pl.Series(data["retention_time"], dtype=MS_TYPES[MSColumns.RETENTION_TIME]), 

1445 MSColumns.MZ_ARRAY.value: pl.Series(data["mz_array"], dtype=MS_TYPES[MSColumns.MZ_ARRAY]), 

1446 MSColumns.INTENSITY_ARRAY.value: pl.Series(data["intensity_array"], dtype=MS_TYPES[MSColumns.INTENSITY_ARRAY]), 

1447 } 

1448 ) 

1449 return df 

1450 

1451 @staticmethod 

1452 def _df_from_matchms(spectra: list[Spectrum]) -> pl.DataFrame: 

1453 """Load a polars DataFrame from a list of Matchms spectra. 

1454 

1455 Args: 

1456 spectra (list[Spectrum]): List of Matchms spectrum objects. 

1457 

1458 Returns: 

1459 pl.DataFrame: The loaded polars DataFrame. 

1460 """ 

1461 data: dict[str, list[Any]] = { 

1462 "scan_number": [], 

1463 "sequence": [], 

1464 "precursor_mass": [], 

1465 "precursor_mz": [], 

1466 "precursor_charge": [], 

1467 "retention_time": [], 

1468 "mz_array": [], 

1469 "intensity_array": [], 

1470 } 

1471 

1472 for i, spectrum in enumerate(spectra): 

1473 data["scan_number"].append(i) 

1474 data["sequence"].append(spectrum.metadata.get("peptide_sequence", "")) 

1475 data["precursor_mass"].append(spectrum.metadata.get("pepmass", 0.0)) 

1476 data["precursor_mz"].append(spectrum.metadata.get("precursor_mz", 0.0)) 

1477 data["precursor_charge"].append(spectrum.metadata.get("charge", 0)) 

1478 data["retention_time"].append(spectrum.metadata.get("retention_time", 0.0)) 

1479 data["mz_array"].append(spectrum.peaks.mz) 

1480 data["intensity_array"].append(spectrum.peaks.intensities) 

1481 

1482 df = SpectrumDataFrame._df_from_dict(data) 

1483 

1484 return df 

1485 

1486 @staticmethod 

1487 def _df_to_matchms(df: pl.DataFrame) -> list[Spectrum]: 

1488 """Convert a polars DataFrame to a list of Matchms spectra. 

1489 

1490 Args: 

1491 df (pl.DataFrame): The input polars DataFrame. 

1492 

1493 Returns: 

1494 list[Spectrum]: List of Matchms spectrum objects. 

1495 """ 

1496 spectra = [] 

1497 

1498 for row in df.iter_rows(named=True): 

1499 metadata = { 

1500 "peptide_sequence": row[ANNOTATED_COLUMN], 

1501 # "pepmass": row[MSColumns.PRECURSOR_MASS.value], 

1502 "precursor_mz": row[MSColumns.PRECURSOR_MZ.value], 

1503 "charge": row[MSColumns.PRECURSOR_CHARGE.value], 

1504 "retention_time": row[MSColumns.RETENTION_TIME.value], 

1505 } 

1506 

1507 mz_array = np.array(row[MSColumns.MZ_ARRAY.value]) 

1508 intensity_array = np.array(row[MSColumns.INTENSITY_ARRAY.value]) 

1509 

1510 spectrum = Spectrum(mz_array, intensity_array, metadata=metadata) 

1511 spectra.append(spectrum) 

1512 

1513 return spectra 

1514 

1515 def _check_type_spec(self) -> None: 

1516 """Check the data type specifications for the DataFrame columns. 

1517 

1518 This method validates that important columns have the correct data types. 

1519 """ 

1520 # Check expected columns, use ENUM constant for this. 

1521 expected_cols = [ 

1522 c.value 

1523 for c in [ 

1524 MSColumns.MZ_ARRAY, 

1525 MSColumns.INTENSITY_ARRAY, 

1526 MSColumns.PRECURSOR_MZ, 

1527 MSColumns.PRECURSOR_CHARGE, 

1528 ] 

1529 ] 

1530 if self.is_annotated: 

1531 expected_cols.append(ANNOTATED_COLUMN) 

1532 

1533 missing_cols = [] 

1534 if self._is_native: 

1535 # Check all parquet files in parallel 

1536 with ThreadPoolExecutor() as executor: 

1537 # First check for missing columns 

1538 def check_columns(file_path: str) -> list[str]: 

1539 columns = pl.scan_parquet(file_path).collect_schema().keys() 

1540 return [col for col in expected_cols if col not in columns] 

1541 

1542 missing_cols_results = list(executor.map(check_columns, self._file_paths)) 

1543 missing_cols = next((cols for cols in missing_cols_results if cols), []) 

1544 

1545 # If no missing columns and we need to check annotations 

1546 if not missing_cols and self.is_annotated: 

1547 

1548 def check_annotations(file_path: str) -> Any: 

1549 return ( 

1550 pl.scan_parquet(file_path) 

1551 .select(((pl.col(ANNOTATED_COLUMN).is_not_null()) & (pl.col(ANNOTATED_COLUMN) != "")).all()) 

1552 .collect() 

1553 .to_numpy()[0] 

1554 ) 

1555 

1556 has_annotations_results = list(executor.map(check_annotations, self._file_paths)) 

1557 if not all(has_annotations_results): 

1558 raise ValueError(ANNOTATION_ERROR) 

1559 else: 

1560 assert self.df is not None 

1561 # Check only self.df 

1562 missing_cols = [col for col in expected_cols if col not in self.df.columns] 

1563 if not missing_cols and self.is_annotated: 

1564 has_annotations = self.df.select(((pl.col(ANNOTATED_COLUMN).is_not_null()) & (pl.col(ANNOTATED_COLUMN) != "")).all()).to_numpy()[0] 

1565 if not has_annotations: 

1566 raise ValueError(ANNOTATION_ERROR) 

1567 

1568 if not missing_cols: 

1569 # In non-native mode also check the charge column is not all zeros (DIA) 

1570 if self.df is not None: 

1571 if self.df[MSColumns.PRECURSOR_CHARGE.value].sum() == 0: 

1572 logger.warning("The charge column is all zeros. This could indicate a DIA dataset or contains invalid values.") 

1573 

1574 if missing_cols: 

1575 plural_s = "s" if len(missing_cols) > 1 else "" 

1576 missing_col_names = ", ".join(missing_cols) 

1577 raise ValueError(f"Column{plural_s} missing! Missing column{plural_s}: {missing_col_names}") 

1578 

1579 @classmethod 

1580 def concatenate(cls, sdf: list["SpectrumDataFrame"], strict: bool = True) -> "SpectrumDataFrame": 

1581 """Concatenate a list of SpectrumDataFrames. 

1582 

1583 Warning: 

1584 This function loads the entire dataset into memory. For large datasets, 

1585 this may consume a significant amount of RAM. 

1586 

1587 Args: 

1588 df (list[SpectrumDataFrame]): List of SpectrumDataFrames to concatenate. 

1589 strict (bool): Whether to perform strict concatenation. 

1590 

1591 Returns: 

1592 SpectrumDataFrame: The resulting concatenated SpectrumDataFrame. 

1593 """ 

1594 return cls(pl.concat([x.to_polars(return_lazy=False) for x in sdf])) 

1595 

1596 def get_unique_sequences(self) -> set[str]: 

1597 """Retrieve unique peptide sequences from the dataset. 

1598 

1599 Returns: 

1600 set[str]: A set of unique peptide sequences. 

1601 

1602 Raises: 

1603 ValueError: If the dataset is not annotated. 

1604 """ 

1605 if not self.is_annotated: 

1606 raise ValueError("Only annotated datasets have sequences.") 

1607 

1608 if self._is_native: 

1609 sequences = set() 

1610 for fp in self._file_paths: 

1611 df_unique = pl.scan_parquet(fp).filter(self._filter_series_per_file[fp]).select(pl.col(ANNOTATED_COLUMN).unique()).collect() 

1612 sequences.update(set(df_unique[ANNOTATED_COLUMN].to_list())) 

1613 return sequences 

1614 else: 

1615 assert self.df is not None 

1616 return set(self.df[ANNOTATED_COLUMN].unique()) 

1617 

1618 def get_vocabulary(self, tokenize_fn: Callable) -> set[str]: 

1619 """Get the vocabulary of unique residues from peptide sequences. 

1620 

1621 Args: 

1622 tokenize_fn (Callable): Function to tokenize peptide sequences. 

1623 

1624 Returns: 

1625 set[str]: A set of unique residues. 

1626 

1627 Raises: 

1628 ValueError: If the dataset is not annotated. 

1629 """ 

1630 if not self.is_annotated: 

1631 raise ValueError("Only annotated datasets have residue vocabularies.") 

1632 

1633 sequences = self.get_unique_sequences() 

1634 residues = set() 

1635 for x in sequences: 

1636 residues.update(set(tokenize_fn(x))) 

1637 return residues 

1638 

1639 def validate_precursor_mass(self, metrics: Metrics, tolerance: float = 50) -> int: 

1640 """Validate precursor mz matching the annotations. 

1641 

1642 Args: 

1643 metrics (Metrics): InstaNovo metrics class for calculating sequence mass. 

1644 tolerance (float): Tolerance to match precursor mass in ppm. 

1645 

1646 Returns: 

1647 int: Number of precursor matches 

1648 

1649 Raises: 

1650 ValueError: If none of the sequences match the precursor mz. 

1651 ValueError: If SpectrumDataFrame is not annotated. 

1652 """ 

1653 if not self.is_annotated: 

1654 raise ValueError("Cannot verify precursor mass without annotations.") 

1655 

1656 if self._is_native: 

1657 num_matches_precursor = 0 

1658 for fp in self._file_paths: 

1659 result = ( 

1660 pl.scan_parquet(fp) 

1661 .filter(self._filter_series_per_file[fp]) 

1662 .select( 

1663 [ 

1664 pl.col(ANNOTATED_COLUMN), 

1665 pl.col(MSColumns.PRECURSOR_MZ.value), 

1666 pl.col(MSColumns.PRECURSOR_CHARGE.value), 

1667 ] 

1668 ) 

1669 .with_columns( 

1670 [ 

1671 pl.struct( 

1672 [ 

1673 ANNOTATED_COLUMN, 

1674 pl.col(MSColumns.PRECURSOR_MZ.value), 

1675 pl.col(MSColumns.PRECURSOR_CHARGE.value), 

1676 ] 

1677 ) 

1678 .map_elements( 

1679 lambda x: metrics.matches_precursor( 

1680 x[ANNOTATED_COLUMN], 

1681 x[MSColumns.PRECURSOR_MZ.value], 

1682 x[MSColumns.PRECURSOR_CHARGE.value], 

1683 prec_tol=tolerance, 

1684 )[0], 

1685 return_dtype=bool, 

1686 ) 

1687 .alias("precursor_match") 

1688 ] 

1689 ) 

1690 .select(pl.col("precursor_match").sum().alias("num_matches_precursor")) 

1691 ) 

1692 num_matches_precursor += result.collect()["num_matches_precursor"][0] 

1693 else: 

1694 assert self.df is not None 

1695 result = ( 

1696 self.df.select( 

1697 [ 

1698 pl.col(ANNOTATED_COLUMN), 

1699 pl.col(MSColumns.PRECURSOR_MZ.value), 

1700 pl.col(MSColumns.PRECURSOR_CHARGE.value), 

1701 ] 

1702 ) 

1703 .with_columns( 

1704 [ 

1705 pl.struct( 

1706 pl.col(ANNOTATED_COLUMN), 

1707 pl.col(MSColumns.PRECURSOR_MZ.value), 

1708 pl.col(MSColumns.PRECURSOR_CHARGE.value), 

1709 ) 

1710 .map_elements( 

1711 lambda x: metrics.matches_precursor( 

1712 x[ANNOTATED_COLUMN], 

1713 x[MSColumns.PRECURSOR_MZ.value], 

1714 x[MSColumns.PRECURSOR_CHARGE.value], 

1715 prec_tol=tolerance, 

1716 )[0], 

1717 return_dtype=bool, 

1718 ) 

1719 .alias("precursor_match") 

1720 ] 

1721 ) 

1722 .select(pl.col("precursor_match").sum().alias("num_matches_precursor")) 

1723 ) 

1724 num_matches_precursor = result["num_matches_precursor"][0] 

1725 

1726 if num_matches_precursor == 0: 

1727 raise ValueError("None of the sequence labels in the dataset match the precursor mz. Check sequences and residue set for errors.") 

1728 elif num_matches_precursor < len(self): 

1729 logger.warning( 

1730 f"{len(self) - num_matches_precursor:,d} " 

1731 f"({(1 - num_matches_precursor / len(self)) * 100:.2f}%) of the sequence labels " 

1732 f"do not match the precursor mz to {tolerance}ppm." 

1733 ) 

1734 

1735 return num_matches_precursor 

1736 

1737 # def filter(self, *args, **kwargs) -> "SpectrumDataFrame": 

1738 # return self.df.filter(*args, **kwargs) 

1739 

1740 def sample_subset(self, fraction: float, seed: int | None = None) -> None: 

1741 """Sample a subset of the dataset. 

1742 

1743 Args: 

1744 fraction (float): Fraction of the dataset to sample. 

1745 seed (int): Random seed for reproducibility. 

1746 """ 

1747 if fraction >= 1: 

1748 return 

1749 if self._is_native: 

1750 for fp in self._file_paths: 

1751 if seed: 

1752 np.random.seed(seed) 

1753 # TODO: variable "filter" is shadowing a python builtin 

1754 filter = self._filter_series_per_file[fp].to_numpy() # noqa 

1755 filter[filter] = np.random.choice([True, False], size=filter.sum(), p=[fraction, 1 - fraction]) 

1756 self._filter_series_per_file[fp] &= pl.Series(filter) 

1757 if not self._shuffle: 

1758 self._update_file_indices() 

1759 self._reset_current_file() 

1760 else: 

1761 assert self.df is not None 

1762 self.df = self.df.sample(fraction=fraction, seed=seed) 

1763 

1764 def validate_data(self) -> bool: 

1765 """Validate the integrity of the dataset. 

1766 

1767 Returns: 

1768 bool: True if the data is valid, False otherwise. 

1769 """ 

1770 raise NotImplementedError() 

1771 

1772 def __del__(self) -> None: 

1773 """Clean up the resources when the object is destroyed. 

1774 

1775 This includes shutting down the thread pool executor and removing temporary files. 

1776 """ 

1777 if self.executor: 

1778 self.executor.shutdown(wait=True) 

1779 if self._temp_directory is not None and os.path.exists(self._temp_directory): 

1780 shutil.rmtree(self._temp_directory) 

1781 

1782 def _strip_shape_info(self, s: str) -> str: 

1783 """Adjust the string representation of a SpectrumDataFrame for lazy loading. 

1784 

1785 Strips shape information that would only correspond to the first temporary file. 

1786 

1787 Parameters: 

1788 s (str): The string representation of the SpectrumDataFrame. 

1789 

1790 Returns: 

1791 str: The adjusted string representation of the SpectrumDataFrame. 

1792 """ 

1793 # Pattern to match rows and columns lines 

1794 pattern = re.compile(r"(Rows:\s*\d+\s*Columns:\s*\d+)") 

1795 # Replace the shape with a placeholder 

1796 return pattern.sub("Shape: unknown in lazy loading mode.", s) 

1797 

1798 @staticmethod 

1799 def _truncate_list_repr(s: str, max_items: int = 3) -> str: 

1800 """Find and truncate long lists within the SpectrumDataFrame string preview. 

1801 

1802 Args: 

1803 s (str): String representation of SpectrumDataFrame. 

1804 max_items (int): Maximum number of list items to display at the list's head and tail. 

1805 

1806 Returns: 

1807 str: SpectrumDataFrame string representation with truncated list items, if necessary. 

1808 """ 

1809 

1810 def process_list(match: re.Match) -> str: 

1811 # Extract the list content 

1812 list_content = match.group(1) 

1813 # Convert to a Python list 

1814 values = list(map(str.strip, list_content.split(","))) 

1815 # Truncate if necessary 

1816 if len(values) > 2 * max_items: 

1817 truncated = values[:max_items] + ["..."] + values[-max_items:] 

1818 else: 

1819 truncated = values 

1820 # Rebuild the list as a string 

1821 return f"[{', '.join(truncated)}]" 

1822 

1823 # Regex to find lists in the string 

1824 list_pattern = re.compile(r"\[([^\]]*?)\]") 

1825 # Apply truncation to all matches 

1826 return list_pattern.sub(process_list, s) 

1827 

1828 def _display_string_preview(self, df: Union[pl.DataFrame, "SpectrumDataFrame"]) -> str: 

1829 """String preview of SpectrumDataFrame, truncating long list items. 

1830 

1831 Args: 

1832 df (Union(pl.DataFrame, "SpectrumDataFrame")): SpectrumDataFrame or Polars form of 

1833 SpectrumDataFrame for string representation. 

1834 

1835 Returns: 

1836 str: String representation of SpectrumDataFrame. 

1837 """ 

1838 if type(df) is not pl.DataFrame: 

1839 df = df.to_polars(return_lazy=False) 

1840 

1841 preview = df.glimpse( 

1842 max_items_per_column=self.max_items_per_column, 

1843 max_colname_length=self.max_colname_length, 

1844 return_as_string=True, 

1845 ) 

1846 

1847 if self.is_lazy: 

1848 preview = self._strip_shape_info(preview) 

1849 

1850 preview = self._truncate_list_repr(preview) 

1851 

1852 return preview 

1853 

1854 def __str__(self) -> str: 

1855 """A user-friendly string representation of the SpectrumDataFrame object. 

1856 

1857 Returns: 

1858 str: String representation of SpectrumDataFrame. 

1859 """ 

1860 # Metadata summary 

1861 output = f"<SpectrumDataFrame | Lazy Loaded: {'Yes' if self.is_lazy else 'No'}>\n" 

1862 

1863 if not self.is_lazy and self._file_paths == []: 

1864 # Eager loading and non-parquet input case: the df is already loaded in memory 

1865 output += self._display_string_preview(self.df) 

1866 

1867 else: 

1868 # Lazy loading and parquet cases: only load the first temp file. 

1869 temp_sdf = self.load(self._file_paths[0]) 

1870 output += self._display_string_preview(temp_sdf) 

1871 

1872 return output 

1873 

1874 def __repr__(self) -> str: 

1875 """An unambiguous string representation of the SpectrumDataFrame object. 

1876 

1877 This method is intended to provide enough detail to reconstruct the SpectrumDataFrame 

1878 object (if possible) or for debugging purposes. It includes all relevant attributes 

1879 in a nested format. 

1880 

1881 Returns: 

1882 str: String representation of SpectrumDataFrame. 

1883 """ 

1884 

1885 def adjust_indentation(input_string: str) -> str: 

1886 """Adjusts the indentation of a multiline string based on the number of leading tabs. 

1887 

1888 Args: 

1889 input_string (str): The input string with potential leading tabs. 

1890 

1891 Returns: 

1892 str: A string with consistent indentation after each newline. 

1893 """ 

1894 # Find the leading tabs at the beginning of the string 

1895 match = re.match(r"^(\t*)", input_string) 

1896 if match: 

1897 leading_tabs = match.group(1) # Capture the leading tabs 

1898 else: 

1899 leading_tabs = "" 

1900 

1901 # Add the leading tabs after every newline 

1902 adjusted_string = input_string.replace("\n", "\n" + leading_tabs) 

1903 

1904 return adjusted_string 

1905 

1906 def pretty(d: dict, indent: int = 1) -> str: 

1907 """Recursively formats a dictionary into a pretty indented string. 

1908 

1909 Args: 

1910 d (dict): The dictionary to format. 

1911 indent (int): The current indentation level. 

1912 

1913 Returns: 

1914 str: A pretty-formatted string representation of the dictionary. 

1915 """ 

1916 lines = [] 

1917 for key, value in d.items(): 

1918 lines.append("\t" * indent + str(key) + " =") # Add the key 

1919 if isinstance(value, dict): 

1920 # Recursively format nested dictionary 

1921 lines.append(pretty(value, indent + 1)) 

1922 else: 

1923 # Add the value 

1924 next_rep = "\t" * (indent + 1) + str(value) 

1925 if isinstance(value, pl.DataFrame) or isinstance(value, pl.Series): 

1926 # Find how many tabs are at the front of the string, 

1927 # and add that many tabs after every newline entry. 

1928 next_rep = adjust_indentation(next_rep) 

1929 

1930 lines.append(next_rep) 

1931 

1932 return "\n".join(lines) 

1933 

1934 class_name = self.__class__.__name__ 

1935 attributes = pretty(vars(self)) 

1936 rep = f"{class_name}(\n{attributes}\n)" 

1937 

1938 return rep 

1939 

1940 

1941def _format_time(seconds: float) -> str: 

1942 seconds = int(seconds) 

1943 return f"{seconds // 3600:02d}:{(seconds % 3600) // 60:02d}:{seconds % 60:02d}"