Coverage for instanovo/transformer/cli.py: 79%
61 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
1import glob
2from pathlib import Path
3from typing import List, Optional
5import typer
6from omegaconf import DictConfig
7from typing_extensions import Annotated
9from instanovo.__init__ import console
10from instanovo.constants import (
11 DEFAULT_INFERENCE_CONFIG_NAME,
12 DEFAULT_INFERENCE_CONFIG_PATH,
13 DEFAULT_TRAIN_CONFIG_PATH,
14)
15from instanovo.utils.cli_utils import compose_config
16from instanovo.utils.colorlogging import ColorLog
# Module-level logger, taken from a ColorLog built on the package-level `console`
# and named after this module.
18logger = ColorLog(console, __name__).logger
# Typer application that the commands below register on. Rich markup is enabled for
# help text and Typer's pretty exception rendering is turned off.
20cli = typer.Typer(rich_markup_mode="rich", pretty_exceptions_enable=False)
@cli.command("train")
def transformer_train(
    config_path: Annotated[
        Optional[str],
        typer.Option(
            "--config-path",
            "-cp",
            help="Relative path to config directory.",
        ),
    ] = None,
    config_name: Annotated[
        Optional[str],
        typer.Option(
            "--config-name",
            "-cn",
            help="The name of the config (usually the file name without the .yaml extension).",
        ),
    ] = None,
    overrides: Optional[List[str]] = typer.Argument(None, hidden=True),
) -> None:
    """Train the InstaNovo model."""
    # NOTE: the docstring above is what Typer shows as the command's help text, so it
    # is kept to a single line; implementation notes live in comments instead.
    #
    # Parameters:
    #   config_path: config directory; falls back to DEFAULT_TRAIN_CONFIG_PATH.
    #   config_name: config name; falls back to "instanovo".
    #   overrides:   extra positional arguments forwarded verbatim to compose_config.
    logger.info("Initializing InstaNovo training.")

    # Defer imports to improve cli performance
    from instanovo.transformer.train import TransformerTrainer

    config = compose_config(
        config_path=DEFAULT_TRAIN_CONFIG_PATH if config_path is None else config_path,
        config_name="instanovo" if config_name is None else config_name,
        overrides=overrides,
    )

    # The initialization message is logged once at the top of the command; it was
    # previously repeated here word for word.
    trainer = TransformerTrainer(config)
    trainer.train()
@cli.command("predict")
def transformer_predict(
    data_path: Annotated[
        Optional[str],
        typer.Option(
            "--data-path",
            "-d",
            help="Path to input data file",
        ),
    ] = None,
    output_path: Annotated[
        Optional[Path],
        typer.Option(
            "--output-path",
            "-o",
            help="Path to output file.",
            exists=False,
            file_okay=True,
            dir_okay=False,
        ),
    ] = None,
    instanovo_model: Annotated[
        Optional[str],
        typer.Option(
            "--instanovo-model",
            "-i",
            help=(
                "Either a model ID or a path to an Instanovo checkpoint file (.ckpt format)."
                # Removed: expensive in CLI, TODO: explore re-adding later
                # "Either a model ID (currently supported: "
                # f"""{", ".join(f"'{model_id}'" for model_id in InstaNovo.get_pretrained())})"""
                # " or a path to an Instanovo checkpoint file (.ckpt format)."
            ),
        ),
    ] = None,
    denovo: Annotated[
        Optional[bool],
        typer.Option(
            "--denovo/--evaluation",
            help="Do [i]de novo[/i] predictions or evaluate an annotated file with peptide sequences?",
        ),
    ] = None,
    config_path: Annotated[
        Optional[str],
        typer.Option(
            "--config-path",
            "-cp",
            help="Relative path to config directory.",
        ),
    ] = None,
    config_name: Annotated[
        Optional[str],
        typer.Option(
            "--config-name",
            "-cn",
            help="The name of the config (usually the file name without the .yaml extension).",
        ),
    ] = None,
    overrides: Optional[List[str]] = typer.Argument(None, hidden=True),
) -> DictConfig:
    """Run predictions with InstaNovo."""
    # NOTE: the docstring above is what Typer shows as the command's help text, so it
    # is kept to a single line; implementation notes live in comments instead.
    #
    # Every CLI flag that is given overrides the matching key of the composed config.
    # A ValueError is raised when:
    #   * a glob-style data_path matches no file,
    #   * no data_path is available from either the flag or the config,
    #   * denovo mode is on and no output_path is available,
    #   * instanovo_model is an existing file without the ".ckpt" suffix,
    #   * instanovo_model is neither a file, an "s3://" URI, nor a pretrained model ID,
    #   * no instanovo_model is available from either the flag or the config.
    #
    # NOTE(review): the signature is annotated `-> DictConfig`, but no value is
    # returned on any path. Left unchanged here; confirm the intended return type.
    logger.info("Initializing InstaNovo inference.")

    # Defer imports to improve cli performance
    from instanovo.transformer.model import InstaNovo
    from instanovo.transformer.predict import TransformerPredictor

    # Compose config with overrides
    config = compose_config(
        config_path=DEFAULT_INFERENCE_CONFIG_PATH if config_path is None else config_path,
        config_name=DEFAULT_INFERENCE_CONFIG_NAME if config_name is None else config_name,
        overrides=overrides,
    )

    # --- data_path -------------------------------------------------------------
    if data_path is not None:
        # Only glob patterns (path/to/data/*.parquet) are checked for matches here;
        # plain paths are passed through to the predictor untouched.
        is_glob = any(wildcard in data_path for wildcard in ("*", "?", "["))
        if is_glob and not glob.glob(data_path):
            raise ValueError(f"The data_path '{data_path}' doesn't correspond to any file(s).")
        config.data_path = str(data_path)
    elif not config.get("data_path", None):
        raise ValueError(
            "Expected 'data_path' but found None. Please specify it in the "
            "`config/inference/<your_config>.yaml` configuration file or with the cli flag "
            "`--data-path='path/to/data'`. Allows `.mgf`, `.mzml`, `.mzxml`, a directory, or a "
            "`.parquet` file. Glob notation is supported: eg.: `--data-path='./experiment/*.mgf'`."
        )

    # --- denovo / evaluation ---------------------------------------------------
    if denovo is not None:
        # Don't compute metrics in denovo mode
        config.denovo = denovo

    # --- output_path -----------------------------------------------------------
    if output_path is not None:
        if output_path.exists():
            logger.info(f"Output path '{output_path}' already exists and will be overwritten.")
        output_path.parent.mkdir(parents=True, exist_ok=True)
        config.output_path = str(output_path)
    if config.get("output_path", None) is None and config.get("denovo", False):
        raise ValueError(
            "Expected 'output_path' but found None in denovo mode. Please specify it in the "
            "`config/inference/<your_config>.yaml` configuration file or with the cli flag "
            "`--output-path=path/to/output_file`."
        )

    # --- instanovo_model -------------------------------------------------------
    if instanovo_model is not None:
        model_file = Path(instanovo_model)
        if model_file.is_file():
            if model_file.suffix != ".ckpt":
                raise ValueError(
                    f"Checkpoint file '{instanovo_model}' should end with extension '.ckpt'."
                )
        elif not instanovo_model.startswith("s3://"):
            # Not a local file and not an S3 URI: must be a known pretrained model ID.
            # The list is fetched once and reused for the error message.
            supported_ids = InstaNovo.get_pretrained()
            if instanovo_model not in supported_ids:
                raise ValueError(
                    f"InstaNovo model ID '{instanovo_model}' is not supported. "
                    "Currently supported value(s): "
                    f"""{", ".join(f"'{model_id}'" for model_id in supported_ids)}"""
                )
        config.instanovo_model = instanovo_model

    if not config.get("instanovo_model", None):
        # The flag is spelled with a hyphen (see the option declaration above).
        raise ValueError(
            "Expected 'instanovo_model' but found None. Please specify it in the "
            "`config/inference/<your_config>.yaml` configuration file or with the cli flag "
            "`instanovo transformer predict --instanovo-model=path/to/model.ckpt`."
        )

    # The initialization message is logged once at the top of the command; it was
    # previously repeated here word for word.
    predictor = TransformerPredictor(config)
    predictor.predict()