Coverage for instanovo/cli.py: 83%
87 statements
coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
import glob
import importlib.metadata as metadata
from pathlib import Path
from typing import List, Optional

import typer
from rich.table import Table
from typing_extensions import Annotated

from instanovo import __version__
from instanovo.__init__ import console
from instanovo.constants import DEFAULT_INFERENCE_CONFIG_NAME, DEFAULT_INFERENCE_CONFIG_PATH
from instanovo.diffusion.cli import cli as diffusion_cli
from instanovo.scripts.convert_to_sdf import app as convert_to_sdf_app
from instanovo.transformer.cli import cli as transformer_cli
from instanovo.utils.cli_utils import compose_config
from instanovo.utils.colorlogging import ColorLog

logger = ColorLog(console, __name__).logger

combined_cli = typer.Typer(rich_markup_mode="rich", pretty_exceptions_enable=False)
combined_cli.add_typer(
    transformer_cli,
    name="transformer",
    help="Run predictions or train with only the transformer-based InstaNovo model.",
)
combined_cli.add_typer(
    diffusion_cli,
    name="diffusion",
    help="Run predictions or train with only the diffusion-based InstaNovo+ model.",
)
combined_cli.add_typer(convert_to_sdf_app)
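
# Sub-command layout (illustrative invocations; assumes the package installs an `instanovo`
# console script, see `instanovo_entrypoint` below):
#   instanovo transformer ...   # transformer-only InstaNovo commands
#   instanovo diffusion ...     # diffusion-only InstaNovo+ commands
# The dataset-conversion app registered above exposes its command under whatever name its
# own Typer definition declares.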


@combined_cli.callback(invoke_without_command=True)
def main(ctx: typer.Context) -> None:
    """Run predictions with InstaNovo and optionally with InstaNovo+.

    Predictions are first made with the transformer-based InstaNovo model and can then
    optionally be refined with the diffusion-based InstaNovo+ model.
    """
    # If you just run `instanovo` on the command line, show the help
    if ctx.invoked_subcommand is None:
        typer.echo(ctx.get_help())


@combined_cli.command()
def predict(
    data_path: Annotated[
        Optional[str],
        typer.Option(
            "--data-path",
            "-d",
            help="Path to the input data file.",
        ),
    ] = None,
    output_path: Annotated[
        Optional[Path],
        typer.Option(
            "--output-path",
            "-o",
            help="Path to the output file.",
            exists=False,
            file_okay=True,
            dir_okay=False,
        ),
    ] = None,
    instanovo_model: Annotated[
        Optional[str],
        typer.Option(
            "--instanovo-model",
            "-i",
            help=(
                "Either a model ID or a path to an InstaNovo checkpoint file (.ckpt format)."
                # "Either a model ID (currently supported: "
                # f"""{", ".join(f"'{model_id}'" for model_id in InstaNovo.get_pretrained())}) """
                # "or a path to an InstaNovo checkpoint file (.ckpt format)."
            ),
        ),
    ] = None,
    instanovo_plus_model: Annotated[
        Optional[str],
        typer.Option(
            "--instanovo-plus-model",
            "-p",
            help=(
                "Either a model ID or a path to an InstaNovo+ checkpoint file (.ckpt format)."
                # "Either a model ID (currently supported: "
                # f"""{", ".join(f"'{model_id}'" for model_id in InstaNovoPlus.get_pretrained())}) """
                # "or a path to an InstaNovo+ checkpoint file (.ckpt format)"
            ),
        ),
    ] = None,
    denovo: Annotated[
        Optional[bool],
        typer.Option(
            "--denovo/--evaluation",
            help="Do [i]de novo[/i] predictions or evaluate an annotated file with peptide sequences?",
        ),
    ] = None,
    refine: Annotated[
        Optional[bool],
        typer.Option(
            "--with-refinement/--no-refinement",
            help=(
                "Refine the predictions of the transformer-based InstaNovo model "
                "with the diffusion-based InstaNovo+ model?"
            ),
        ),
    ] = True,
    config_path: Annotated[
        Optional[str],
        typer.Option(
            "--config-path",
            "-cp",
            help="Relative path to the config directory.",
        ),
    ] = None,
    config_name: Annotated[
        Optional[str],
        typer.Option(
            "--config-name",
            "-cn",
            help="The name of the config (usually the file name without the .yaml extension).",
        ),
    ] = None,
    overrides: Optional[List[str]] = typer.Argument(None, hidden=True),
) -> None:
    """Run predictions with InstaNovo and optionally refine them with InstaNovo+.

    Predictions are first made with the transformer-based InstaNovo model and can then
    optionally be refined with the diffusion-based InstaNovo+ model.
    """
    # Defer heavy imports to keep CLI start-up fast
    logger.info("Initializing InstaNovo inference.")
    from instanovo.diffusion.multinomial_diffusion import InstaNovoPlus
    from instanovo.diffusion.predict import CombinedPredictor
    from instanovo.transformer.model import InstaNovo

    if config_path is None:
        config_path = DEFAULT_INFERENCE_CONFIG_PATH
    if config_name is None:
        config_name = DEFAULT_INFERENCE_CONFIG_NAME

    config = compose_config(
        config_path=config_path,
        config_name=config_name,
        overrides=overrides,
    )
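    # Any trailing positional arguments arrive as `overrides` (hidden above) and are handed to
    # `compose_config`, presumably as Hydra-style `key=value` pairs such as `data_path=...`.
    # The explicit flags below are applied to `config` afterwards, so they take precedence over
    # both the YAML config and any overrides.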

    # Check config inputs
    if data_path is not None:
        if "*" in data_path or "?" in data_path or "[" in data_path:
            # Glob notation: path/to/data/*.parquet
            if not glob.glob(data_path):
                raise ValueError(f"The data_path '{data_path}' doesn't correspond to any file(s).")
        config.data_path = str(data_path)

    if not config.get("data_path", None) and data_path is None:
        raise ValueError(
            "Expected 'data_path' but found None. Please specify it in the "
            "`config/inference/<your_config>.yaml` configuration file or with the CLI flag "
            "`--data-path='path/to/data'`. Accepts `.mgf`, `.mzml`, `.mzxml`, a directory, or a "
            "`.parquet` file. Glob notation is supported, e.g. `--data-path='./experiment/*.mgf'`."
        )

    if denovo is not None:
        config.denovo = denovo
    if refine is not None:
        config.refine = refine

    if output_path is not None:
        if output_path.exists():
            logger.info(f"Output path '{output_path}' already exists and will be overwritten.")
        output_path.parent.mkdir(parents=True, exist_ok=True)
        config.output_path = str(output_path)
    if config.output_path and not Path(config.output_path).parent.exists():
        Path(config.output_path).parent.mkdir(parents=True, exist_ok=True)
    if config.get("output_path", None) is None and config.get("denovo", False):
        raise ValueError(
            "Expected 'output_path' but found None in denovo mode. Please specify it "
            "in the `config/inference/<your_config>.yaml` configuration file or with the CLI flag "
            "`--output-path=path/to/output_file`."
        )

    if instanovo_model is not None:
        if Path(instanovo_model).is_file() and Path(instanovo_model).suffix != ".ckpt":
            raise ValueError(f"Checkpoint file '{instanovo_model}' should end with extension '.ckpt'.")
        if (
            not Path(instanovo_model).is_file()
            and not instanovo_model.startswith("s3://")
            and instanovo_model not in InstaNovo.get_pretrained()
        ):
            raise ValueError(
                f"InstaNovo model ID '{instanovo_model}' is not supported. "
                "Currently supported value(s): "
                f"""{", ".join(f"'{model_id}'" for model_id in InstaNovo.get_pretrained())}"""
            )
        config.instanovo_model = instanovo_model

    if not config.get("instanovo_model", None):
        raise ValueError(
            "Expected 'instanovo_model' but found None. Please specify it in the "
            "`config/inference/<your_config>.yaml` configuration file or with the CLI flag "
            "`instanovo predict --instanovo-model=path/to/model.ckpt`."
        )

    if config.refine:
        if instanovo_plus_model is not None:
            if Path(instanovo_plus_model).is_dir():
                required_files = ["*.ckpt", "*.yaml", "*.pt"]
                missing_files = [
                    ext for ext in required_files if not list(Path(instanovo_plus_model).glob(ext))
                ]
                if missing_files:
                    raise ValueError(
                        f"The directory '{instanovo_plus_model}' is missing the following "
                        f"required file(s): {', '.join(missing_files)}."
                    )
            elif (
                not Path(instanovo_plus_model).is_file()
                and not instanovo_plus_model.startswith("s3://")
                and instanovo_plus_model not in InstaNovoPlus.get_pretrained()
            ):
                raise ValueError(
                    f"InstaNovo+ model ID '{instanovo_plus_model}' is not supported. Currently "
                    "supported value(s): "
                    f"""{", ".join(f"'{model_id}'" for model_id in InstaNovoPlus.get_pretrained())}"""
                )
            config.instanovo_plus_model = instanovo_plus_model

        if config.get("instanovo_plus_model", None) is None:
            raise ValueError(
                "Expected 'instanovo_plus_model' when refining, but found None. Please specify it "
                "in the `config/inference/<your_config>.yaml` configuration file or with the CLI "
                "flag `instanovo predict --instanovo-plus-model=path/to/model_dir`."
            )

    predictor = CombinedPredictor(config)
    predictor.predict()
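
# Illustrative end-to-end invocation (hypothetical paths; each flag maps onto an option of
# `predict` above, and config overrides may be appended as trailing arguments):
#   instanovo predict \
#       --data-path='./experiment/*.mgf' \
#       --output-path=path/to/output_file \
#       --instanovo-model=path/to/model.ckpt \
#       --instanovo-plus-model=path/to/model_dir \
#       --denovo --with-refinement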


@combined_cli.command()
def version() -> None:
    """Display version information for InstaNovo, InstaNovo+ and their dependencies."""
    table = Table("Package", "Version")
    table.add_row("InstaNovo", __version__)
    table.add_row("InstaNovo+", __version__)
    table.add_row("NumPy", metadata.version("numpy"))
    table.add_row("PyTorch", metadata.version("torch"))
    console.print(table)


def instanovo_entrypoint() -> None:
    """Main entry point for the InstaNovo CLI application."""
    combined_cli()
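
# The `instanovo` command used throughout the help and error messages presumably resolves to
# this entrypoint via the project's packaging metadata; a sketch of what that stanza might look
# like (an assumption, not taken from this repository):
#   [project.scripts]
#   instanovo = "instanovo.cli:instanovo_entrypoint"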


if __name__ == "__main__":
    combined_cli()