Coverage for instanovo/diffusion/cli.py: 72%
68 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
1import glob
2from pathlib import Path
3from typing import List, Optional
5import typer
6from typing_extensions import Annotated
8from instanovo.__init__ import console
9from instanovo.constants import (
10 DEFAULT_INFERENCE_CONFIG_NAME,
11 DEFAULT_INFERENCE_CONFIG_PATH,
12 DEFAULT_TRAIN_CONFIG_PATH,
13)
14from instanovo.utils.cli_utils import compose_config
15from instanovo.utils.colorlogging import ColorLog
# Module-level logger, built by ColorLog around the package-level `console`
# object and named after this module.
logger = ColorLog(console, __name__).logger

# Typer application that the `train` and `predict` commands below register on.
# Rich markup is enabled for help strings; Typer's pretty exception rendering
# is switched off.
cli = typer.Typer(pretty_exceptions_enable=False, rich_markup_mode="rich")
# `train` sub-command.
#
# Composes a training configuration from `--config-path` / `--config-name`
# (plus any positional overrides) and hands it to `DiffusionTrainer`.
# When the path or name is not supplied, `DEFAULT_TRAIN_CONFIG_PATH` and
# "instanovoplus" are used respectively.
#
# NOTE: the docstring below is deliberately kept short; Typer uses a command's
# docstring as its help text, so editing it changes the `--help` output.
@cli.command("train")
def diffusion_train(
    config_path: Annotated[
        Optional[str],
        typer.Option(
            "--config-path",
            "-cp",
            help="Relative path to config directory.",
        ),
    ] = None,
    config_name: Annotated[
        Optional[str],
        typer.Option(
            "--config-name",
            "-cn",
            help="The name of the config (usually the file name without the .yaml extension).",
        ),
    ] = None,
    overrides: Optional[List[str]] = typer.Argument(None, hidden=True),
) -> None:
    """Train the InstaNovo+ model."""
    logger.info("Initializing InstaNovo+ training.")
    # Imported inside the function rather than at module level
    # (presumably to keep CLI start-up cheap — confirm).
    from instanovo.diffusion.train import DiffusionTrainer

    # Fall back to the packaged defaults for anything the caller left unset.
    train_config = compose_config(
        config_path=DEFAULT_TRAIN_CONFIG_PATH if config_path is None else config_path,
        config_name="instanovoplus" if config_name is None else config_name,
        overrides=overrides,
    )

    logger.info("Initializing diffusion training.")
    DiffusionTrainer(train_config).train()
# `predict` sub-command.
#
# Composes an inference configuration, overlays the values given on the
# command line (data path, output path, model, denovo/evaluation mode,
# refinement switch), validates the result and runs `DiffusionPredictor`.
#
# Raises ValueError when:
#   * `--data-path` is a glob pattern that matches nothing;
#   * no data path is available from either the CLI or the config;
#   * denovo mode is on and no output path is available;
#   * `--instanovo-plus-model` is a directory lacking a .ckpt/.yaml/.pt file,
#     or is neither an existing file, an `s3://` URI, nor a known model ID;
#   * no model is available from either the CLI or the config;
#   * refinement is on and `refinement_path` equals `output_path`.
#
# NOTE: the docstring below is deliberately kept short; Typer uses a command's
# docstring as its help text, so editing it changes the `--help` output.
@cli.command("predict")
def diffusion_predict(
    data_path: Annotated[
        Optional[str],
        typer.Option(
            "--data-path",
            "-d",
            help="Path to input data file",
        ),
    ] = None,
    output_path: Annotated[
        Optional[Path],
        typer.Option(
            "--output-path",
            "-o",
            help="Path to output file.",
            exists=False,
            file_okay=True,
            dir_okay=False,
        ),
    ] = None,
    instanovo_plus_model: Annotated[
        Optional[str],
        typer.Option(
            "--instanovo-plus-model",
            "-p",
            # The supported model IDs are not enumerated in this help text:
            # doing so needs `InstaNovoPlus.get_pretrained()`, and
            # `InstaNovoPlus` is only imported inside the function body.
            help="Either a model ID or a path to an Instanovo+ checkpoint file (.ckpt format)",
        ),
    ] = None,
    denovo: Annotated[
        Optional[bool],
        typer.Option(
            "--denovo/--evaluation",
            help="Do [i]de novo[/i] predictions or evaluate an annotated file with peptide sequences?",
        ),
    ] = None,
    refine: Annotated[
        Optional[bool],
        typer.Option(
            "--with-refinement/--no-refinement",
            help="Refine the predictions of the transformer-based InstaNovo model with the diffusion-based InstaNovo+ model?",
        ),
    ] = None,
    config_path: Annotated[
        Optional[str],
        typer.Option(
            "--config-path",
            "-cp",
            help="Relative path to config directory.",
        ),
    ] = None,
    config_name: Annotated[
        Optional[str],
        typer.Option(
            "--config-name",
            "-cn",
            help="The name of the config (usually the file name without the .yaml extension).",
        ),
    ] = None,
    overrides: Optional[List[str]] = typer.Argument(None, hidden=True),
) -> None:
    """Run predictions with InstaNovo+."""
    logger.info("Initializing InstaNovo+ inference.")
    # Defer imports to improve cli performance
    from instanovo.diffusion.multinomial_diffusion import InstaNovoPlus
    from instanovo.diffusion.predict import DiffusionPredictor

    if config_path is None:
        config_path = DEFAULT_INFERENCE_CONFIG_PATH
    if config_name is None:
        config_name = DEFAULT_INFERENCE_CONFIG_NAME

    config = compose_config(
        config_path=config_path,
        config_name=config_name,
        overrides=overrides,
    )

    # --- data path -----------------------------------------------------
    if data_path is not None:
        if "*" in data_path or "?" in data_path or "[" in data_path:
            # Glob notation: path/to/data/*.parquet
            if not glob.glob(data_path):
                raise ValueError(f"The data_path '{data_path}' doesn't correspond to any file(s).")
        config.data_path = str(data_path)

    if not config.get("data_path", None) and data_path is None:
        raise ValueError(
            "Expected 'data_path' but found None. Please specify it in the "
            "`config/inference/<your_config>.yaml` configuration file or with the cli flag "
            "`--data-path='path/to/data'`. Allows `.mgf`, `.mzml`, `.mzxml`, a directory, or a "
            "`.parquet` file. Glob notation is supported: eg.: `--data-path='./experiment/*.mgf'`."
        )

    # --- mode switches: only override the config when given on the CLI ---
    if denovo is not None:
        config.denovo = denovo
    if refine is not None:
        config.refine = refine

    # --- output path ---------------------------------------------------
    if output_path is not None:
        if output_path.exists():
            logger.info(f"Output path '{output_path}' already exists and will be overwritten.")
        output_path.parent.mkdir(parents=True, exist_ok=True)
        config.output_path = str(output_path)

    # Read the effective output path once via `.get`, matching how every other
    # optional key is read here. The previous attribute access
    # (`config.output_path`) could raise on an absent key before the explicit
    # ValueError below was reached.
    # NOTE(review): assumes `config` is an OmegaConf DictConfig-like mapping
    # returned by `compose_config` — confirm.
    effective_output_path = config.get("output_path", None)
    if effective_output_path and not Path(effective_output_path).parent.exists():
        Path(effective_output_path).parent.mkdir(parents=True, exist_ok=True)
    if effective_output_path is None and config.get("denovo", False):
        raise ValueError(
            "Expected 'output_path' but found None in denovo mode. Please specify it "
            "in the `config/inference/<your_config>.yaml` configuration file or with the cli flag "
            "`--output-path=path/to/output_file`."
        )

    # --- model ---------------------------------------------------------
    if instanovo_plus_model is not None:
        model_location = Path(instanovo_plus_model)
        if model_location.is_dir():
            # A model directory must contain at least one file of each kind.
            required_files = ["*.ckpt", "*.yaml", "*.pt"]
            missing_files = [
                pattern for pattern in required_files if not list(model_location.glob(pattern))
            ]
            if missing_files:
                raise ValueError(
                    f"The directory '{instanovo_plus_model}' is missing the following required file(s): {', '.join(missing_files)}."
                )
        elif (
            not model_location.is_file()
            and not instanovo_plus_model.startswith("s3://")
            and instanovo_plus_model not in InstaNovoPlus.get_pretrained()
        ):
            # Not a local file, not an S3 URI and not a known pretrained ID.
            raise ValueError(
                f"InstaNovo+ model ID '{instanovo_plus_model}' is not supported. Currently "
                "supported value(s): "
                f"""{", ".join(f"'{model_id}'" for model_id in InstaNovoPlus.get_pretrained())}"""
            )
        config.instanovo_plus_model = instanovo_plus_model

    if not config.get("instanovo_plus_model", None):
        raise ValueError(
            "Expected 'instanovo_plus_model' but found None. Please specify it in the "
            "`config/inference/<your_config>.yaml` configuration file or with the cli flag "
            "`instanovo diffusion predict --instanovo-plus-model=path/to/model_dir`."
        )

    # --- refinement must not overwrite the primary predictions ----------
    if (
        config.get("refine", False)
        and config.get("refinement_path", None) == config.get("output_path", None)
        and config.get("refinement_path", None) is not None
    ):
        raise ValueError("The 'refinement_path' should be different from the 'output_path' to avoid overwriting the original predictions.")

    predictor = DiffusionPredictor(config)
    predictor.predict()