Coverage for instanovo/scripts/convert_to_sdf.py: 100%

28 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

1# /// script 

2# requires-python = ">=3.10" 

3# dependencies = [ 

4# "instanovo", 

5# "typer", 

6# ] 

7# /// 

8from __future__ import annotations 

9 

10import logging 

11from enum import Enum 

12from pathlib import Path 

13from typing import Optional 

14 

15import typer 

16from typing_extensions import Annotated 

17 

18from instanovo.__init__ import console 

19from instanovo.utils.colorlogging import ColorLog 

20 

# Module-level color logger bound to this module's name.
logger = ColorLog(console, __name__).logger

# Typer application object; `convert` below is registered as a command on it.
app = typer.Typer()

24 

25 

class Partition(str, Enum):
    """Partition of saved dataset.

    Inherits from ``str`` so members compare equal to their plain string
    values and serialize cleanly (e.g. when passed to ``SpectrumDataFrame``).
    """

    # Member values are the partition names used when saving data shards.
    train = "train"
    valid = "valid"
    test = "test"

32 

33 

@app.command()
def convert(
    source: Annotated[str, typer.Argument(help="Source file(s)")],
    target: Annotated[
        Path,
        typer.Argument(exists=True, file_okay=False, dir_okay=True, help="Target folder to save data shards"),
    ],
    name: Annotated[Optional[str], typer.Option(help="Name of saved dataset")],
    partition: Annotated[Partition, typer.Option(help="Partition of saved dataset")],
    max_charge: Annotated[int, typer.Option(help="Maximum charge to filter out")] = 10,
    shard_size: Annotated[int, typer.Option(help="Length of saved data shards")] = 1_000_000,
    is_annotated: Annotated[bool, typer.Option("--is-annotated", help="whether dataset is annotated")] = False,
    add_spectrum_id: Annotated[bool, typer.Option("--add-spectrum-id", help="Add spectrum id column")] = False,
) -> None:
    """Convert data to SpectrumDataFrame and save as *.parquet file(s)."""
    # Imported lazily so the CLI help stays fast and the heavy dependency is
    # only loaded when the command actually runs.
    from instanovo.utils.data_handler import SpectrumDataFrame

    logging.basicConfig(level=logging.INFO)

    # Load lazily so shards are streamed rather than held in memory at once.
    logger.info(f"Loading {source}")
    split = partition.value
    frame = SpectrumDataFrame.load(
        source,
        is_annotated=is_annotated,
        name=name,
        partition=split,
        max_shard_size=shard_size,
        lazy=True,
        add_spectrum_id=add_spectrum_id,
    )
    logger.info(f"Loaded {len(frame):,d} rows")

    # Drop any spectra whose precursor charge exceeds the configured maximum.
    logger.info(f"Filtering max_charge <= {max_charge}")

    def _charge_ok(row: dict) -> bool:
        # Keep rows at or below the charge threshold.
        return row["precursor_charge"] <= max_charge

    frame.filter_rows(_charge_ok)

    logger.info(f"Saving {len(frame):,d} rows to {target}")
    frame.save(
        target,
        name=name,
        partition=split,
        max_shard_size=shard_size,
    )

    logger.info("Saving complete.")
    # Explicitly release the (potentially large) dataframe handle.
    del frame

78 

79 

if __name__ == "__main__":
    # Dispatch to the Typer CLI when the script is run directly.
    app()