Coverage for instanovo/cli.py: 83%

87 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

1import glob 

2import importlib.metadata as metadata 

3from pathlib import Path 

4from typing import List, Optional 

5 

6import typer 

7from rich.table import Table 

8from typing_extensions import Annotated 

9 

10from instanovo import __version__ 

11from instanovo.__init__ import console 

12from instanovo.constants import DEFAULT_INFERENCE_CONFIG_NAME, DEFAULT_INFERENCE_CONFIG_PATH 

13from instanovo.diffusion.cli import cli as diffusion_cli 

14from instanovo.scripts.convert_to_sdf import app as convert_to_sdf_app 

15from instanovo.transformer.cli import cli as transformer_cli 

16from instanovo.utils.cli_utils import compose_config 

17from instanovo.utils.colorlogging import ColorLog 

18 

# Module-wide logger routed through the shared Rich console.
logger = ColorLog(console, __name__).logger

# Root CLI application; the model-specific sub-apps are attached below.
# pretty_exceptions_enable=False keeps tracebacks plain (easier to read in logs).
combined_cli = typer.Typer(rich_markup_mode="rich", pretty_exceptions_enable=False)
combined_cli.add_typer(
    transformer_cli,
    name="transformer",
    help="Run predictions or train with only the transformer-based InstaNovo model.",
)
combined_cli.add_typer(
    diffusion_cli,
    name="diffusion",
    help="Run predictions or train with only the diffusion-based InstaNovo+ model.",
)
# No explicit name: the convert-to-sdf app registers under the name it declares itself.
combined_cli.add_typer(convert_to_sdf_app)

33 

34 

@combined_cli.callback(invoke_without_command=True)
def main(ctx: typer.Context) -> None:
    """Run predictions with InstaNovo and optionally with InstaNovo+.

    First with the transformer-based InstaNovo model and then optionally refine
    them with the diffusion based InstaNovo+ model.
    """
    # Bare `instanovo` invocation (no subcommand given): print usage instead of doing nothing.
    subcommand = ctx.invoked_subcommand
    if subcommand is None:
        typer.echo(ctx.get_help())

45 

46 

@combined_cli.command()
def predict(
    # All options default to None so that, when omitted, the value from the composed
    # config file is used instead (see precedence handling in the body).
    data_path: Annotated[
        Optional[str],
        typer.Option(
            "--data-path",
            "-d",
            help="Path to input data file",
        ),
    ] = None,
    output_path: Annotated[
        Optional[Path],
        typer.Option(
            "--output-path",
            "-o",
            help="Path to output file.",
            exists=False,
            file_okay=True,
            dir_okay=False,
        ),
    ] = None,
    instanovo_model: Annotated[
        Optional[str],
        typer.Option(
            "--instanovo-model",
            "-i",
            help=(
                "Either a model ID or a path to an Instanovo checkpoint file (.ckpt format)"
                # "Either a model ID (currently supported: "
                # f"""{", ".join(f"'{model_id}'" for model_id in InstaNovo.get_pretrained())}) """
                # "or a path to an Instanovo checkpoint file (.ckpt format)."
            ),
        ),
    ] = None,
    instanovo_plus_model: Annotated[
        Optional[str],
        typer.Option(
            "--instanovo-plus-model",
            "-p",
            help=(
                "Either a model ID or a path to an Instanovo+ checkpoint file (.ckpt format)"
                # "Either a model ID (currently supported: "
                # f"""{", ".join(f"'{model_id}'" for model_id in InstaNovoPlus.get_pretrained())}) """
                # "or a path to an Instanovo+ checkpoint file (.ckpt format)"
            ),
        ),
    ] = None,
    denovo: Annotated[
        Optional[bool],
        typer.Option(
            "--denovo/--evaluation",
            help="Do [i]de novo[/i] predictions or evaluate an annotated file with peptide sequences?",
        ),
    ] = None,
    refine: Annotated[
        Optional[bool],
        typer.Option(
            "--with-refinement/--no-refinement",
            help=("Refine the predictions of the transformer-based InstaNovo model with the diffusion-based InstaNovo+ model?"),
        ),
    ] = True,
    config_path: Annotated[
        Optional[str],
        typer.Option(
            "--config-path",
            "-cp",
            help="Relative path to config directory.",
        ),
    ] = None,
    config_name: Annotated[
        Optional[str],
        typer.Option(
            "--config-name",
            "-cn",
            help="The name of the config (usually the file name without the .yaml extension).",
        ),
    ] = None,
    # Hidden catch-all: extra positional arguments are forwarded as config overrides.
    overrides: Optional[List[str]] = typer.Argument(None, hidden=True),
) -> None:
    """Run predictions with InstaNovo and optionally refine with InstaNovo+.

    First with the transformer-based InstaNovo model and then optionally refine
    them with the diffusion based InstaNovo+ model.
    """
    # Defer imports to improve cli performance
    logger.info("Initializing InstaNovo inference.")
    from instanovo.diffusion.multinomial_diffusion import InstaNovoPlus
    from instanovo.diffusion.predict import CombinedPredictor
    from instanovo.transformer.model import InstaNovo

    # Fall back to the packaged default inference config when not specified on the CLI.
    if config_path is None:
        config_path = DEFAULT_INFERENCE_CONFIG_PATH
    if config_name is None:
        config_name = DEFAULT_INFERENCE_CONFIG_NAME

    # Compose the config; `overrides` are applied on top of the named config file.
    config = compose_config(
        config_path=config_path,
        config_name=config_name,
        overrides=overrides,
    )

    # Check config inputs
    if data_path is not None:
        if "*" in data_path or "?" in data_path or "[" in data_path:
            # Glob notation: path/to/data/*.parquet
            if not glob.glob(data_path):
                raise ValueError(f"The data_path '{data_path}' doesn't correspond to any file(s).")
        config.data_path = str(data_path)

    # data_path must come from either the CLI or the config file.
    if not config.get("data_path", None) and data_path is None:
        raise ValueError(
            "Expected 'data_path' but found None. Please specify it in the "
            "`config/inference/<your_config>.yaml` configuration file or with the cli flag "
            "`--data-path='path/to/data'`. Allows `.mgf`, `.mzml`, `.mzxml`, a directory, or a "
            "`.parquet` file. Glob notation is supported: eg.: `--data-path='./experiment/*.mgf'`."
        )

    # CLI flags take precedence over config-file values when explicitly provided.
    if denovo is not None:
        config.denovo = denovo
    # NOTE(review): `refine` defaults to True (not None), so this condition is always true
    # and the config-file `refine` value is overridden unless `--no-refinement` is passed —
    # asymmetric with `denovo` above; confirm this is intended.
    if refine is not None:
        config.refine = refine

    if output_path is not None:
        if output_path.exists():
            logger.info(f"Output path '{output_path}' already exists and will be overwritten.")
        # Ensure the parent directory exists before prediction writes the file.
        output_path.parent.mkdir(parents=True, exist_ok=True)
        config.output_path = str(output_path)
    # NOTE(review): direct attribute access here assumes `output_path` is a declared key in
    # the composed config (the later check uses `.get()` instead) — confirm the default
    # inference config always defines it.
    if config.output_path and not Path(config.output_path).parent.exists():
        Path(config.output_path).parent.mkdir(parents=True, exist_ok=True)
    # An output file is mandatory in de novo mode (there are no annotations to evaluate against).
    if config.get("output_path", None) is None and (config.get("denovo", False)):
        raise ValueError(
            "Expected 'output_path' but found None in denovo mode. Please specify it "
            "in the `config/inference/<your_config>.yaml` configuration file or with the cli flag "
            "`--output-path=path/to/output_file`."
        )

    if instanovo_model is not None:
        # A local checkpoint file must use the .ckpt extension.
        if Path(instanovo_model).is_file() and Path(instanovo_model).suffix != ".ckpt":
            raise ValueError(f"Checkpoint file '{instanovo_model}' should end with extension '.ckpt'.")
        # Anything that is neither a local file nor an s3:// URI must be a known pretrained model ID.
        if not Path(instanovo_model).is_file() and not instanovo_model.startswith("s3://") and instanovo_model not in InstaNovo.get_pretrained():
            raise ValueError(
                f"InstaNovo model ID '{instanovo_model}' is not supported. "
                "Currently supported value(s): "
                f"""{", ".join(f"'{model_id}'" for model_id in InstaNovo.get_pretrained())}"""
            )
        config.instanovo_model = instanovo_model

    # The transformer model is always required.
    if not config.get("instanovo_model", None):
        raise ValueError(
            "Expected 'instanovo_model' but found None. Please specify it in the "
            "`config/inference/<your_config>.yaml` configuration file or with the cli flag "
            "`instanovo predict --instanovo_model=path/to/model.ckpt`."
        )

    # InstaNovo+ validation only matters when refinement is enabled.
    if config.refine:
        if instanovo_plus_model is not None:
            if Path(instanovo_plus_model).is_dir():
                # A model directory must contain at least one file matching each pattern.
                required_files = ["*.ckpt", "*.yaml", "*.pt"]
                missing_files = [ext for ext in required_files if not list(Path(instanovo_plus_model).glob(ext))]
                if missing_files:
                    raise ValueError(f"The directory '{instanovo_plus_model}' is missing the following required file(s): {', '.join(missing_files)}.")
            elif (
                not Path(instanovo_plus_model).is_file()
                and not instanovo_plus_model.startswith("s3://")
                and instanovo_plus_model not in InstaNovoPlus.get_pretrained()
            ):
                raise ValueError(
                    f"InstaNovo+ model ID '{instanovo_plus_model}' is not supported. Currently "
                    "supported value(s): "
                    f"""{", ".join(f"'{model_id}'" for model_id in InstaNovoPlus.get_pretrained())}"""
                )
            config.instanovo_plus_model = instanovo_plus_model

        if config.get("instanovo_plus_model", None) is None:
            raise ValueError(
                "Expected 'instanovo_plus_model' when refining, but found None. Please specify it "
                "in the `config/inference/<your_config>.yaml` configuration file or with the cli "
                "flag `instanovo predict --instanovo-plus-model=path/to/model_dir`."
            )

    # Run the transformer prediction (and optional diffusion refinement) pipeline.
    predictor = CombinedPredictor(config)
    predictor.predict()

229 

230 

@combined_cli.command()
def version() -> None:
    """Display version information for InstaNovo, Instanovo+ and its dependencies."""
    # (package, version) pairs; both InstaNovo variants ship in the same distribution,
    # so they report the same version string.
    rows = [
        ("InstaNovo", __version__),
        ("InstaNovo+", __version__),
        ("NumPy", metadata.version("numpy")),
        ("PyTorch", metadata.version("torch")),
    ]
    table = Table("Package", "Version")
    for package_name, package_version in rows:
        table.add_row(package_name, package_version)
    console.print(table)

240 

241 

def instanovo_entrypoint() -> None:
    """Main entry point for the InstaNovo CLI application.

    Invokes the combined Typer app, which parses sys.argv and dispatches to the
    appropriate subcommand. Presumably referenced as the console-script entry
    point in the package metadata — confirm against pyproject.toml.
    """
    combined_cli()

245 

246 

# Allow running this module directly (e.g. `python -m instanovo.cli`).
if __name__ == "__main__":
    combined_cli()