Coverage for instanovo/scripts/convert_to_sdf.py: 100%
28 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
1# /// script
2# requires-python = ">=3.10"
3# dependencies = [
4# "instanovo",
5# "typer",
6# ]
7# ///
8from __future__ import annotations
10import logging
11from enum import Enum
12from pathlib import Path
13from typing import Optional
15import typer
16from typing_extensions import Annotated
18from instanovo.__init__ import console
19from instanovo.utils.colorlogging import ColorLog
# Module-level colored logger bound to the shared rich console from instanovo.
logger = ColorLog(console, __name__).logger

# Typer application object; `convert` below is registered on it via @app.command().
app = typer.Typer()
# str mixin makes members usable directly as their string values
# (e.g. in typer option parsing and when passed to SpectrumDataFrame).
class Partition(str, Enum):
    """Partition of saved dataset."""

    train = "train"
    valid = "valid"
    test = "test"
@app.command()
def convert(
    source: Annotated[str, typer.Argument(help="Source file(s)")],
    target: Annotated[
        Path,
        typer.Argument(exists=True, file_okay=False, dir_okay=True, help="Target folder to save data shards"),
    ],
    name: Annotated[Optional[str], typer.Option(help="Name of saved dataset")],
    partition: Annotated[Partition, typer.Option(help="Partition of saved dataset")],
    max_charge: Annotated[int, typer.Option(help="Maximum charge to filter out")] = 10,
    shard_size: Annotated[int, typer.Option(help="Length of saved data shards")] = 1_000_000,
    is_annotated: Annotated[bool, typer.Option("--is-annotated", help="whether dataset is annotated")] = False,
    add_spectrum_id: Annotated[bool, typer.Option("--add-spectrum-id", help="Add spectrum id column")] = False,
) -> None:
    """Convert data to SpectrumDataFrame and save as *.parquet file(s)."""
    # Deferred import keeps CLI startup fast when this command is not invoked.
    from instanovo.utils.data_handler import SpectrumDataFrame

    logging.basicConfig(level=logging.INFO)

    # Load lazily so very large inputs are not fully materialized up front.
    logger.info(f"Loading {source}")
    frame = SpectrumDataFrame.load(
        source,
        lazy=True,
        name=name,
        partition=partition.value,
        is_annotated=is_annotated,
        add_spectrum_id=add_spectrum_id,
        max_shard_size=shard_size,
    )
    logger.info(f"Loaded {len(frame):,d} rows")

    # Drop rows whose precursor charge exceeds the configured maximum.
    logger.info(f"Filtering max_charge <= {max_charge}")
    frame.filter_rows(lambda row: row["precursor_charge"] <= max_charge)

    logger.info(f"Saving {len(frame):,d} rows to {target}")
    frame.save(
        target,
        name=name,
        partition=partition.value,
        max_shard_size=shard_size,
    )
    logger.info("Saving complete.")

    # Release the (potentially large) dataset handle explicitly.
    del frame
# Entry point: dispatch to the Typer CLI when run as a script.
if __name__ == "__main__":
    app()