Coverage for instanovo/diffusion/cli.py: 72%

68 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

1import glob 

2from pathlib import Path 

3from typing import List, Optional 

4 

5import typer 

6from typing_extensions import Annotated 

7 

8from instanovo.__init__ import console 

9from instanovo.constants import ( 

10 DEFAULT_INFERENCE_CONFIG_NAME, 

11 DEFAULT_INFERENCE_CONFIG_PATH, 

12 DEFAULT_TRAIN_CONFIG_PATH, 

13) 

14from instanovo.utils.cli_utils import compose_config 

15from instanovo.utils.colorlogging import ColorLog 

16 

# Module-level logger that writes through the shared `console` imported from instanovo.
logger = ColorLog(console, __name__).logger

# Typer application on which the `train` and `predict` commands of this module are
# registered. Rich markup is enabled for help strings; Typer's pretty exception
# rendering is switched off.
cli = typer.Typer(
    pretty_exceptions_enable=False,
    rich_markup_mode="rich",
)

20 

21 

@cli.command("train")
def diffusion_train(
    config_path: Annotated[
        Optional[str],
        typer.Option(
            "--config-path",
            "-cp",
            help="Relative path to config directory.",
        ),
    ] = None,
    config_name: Annotated[
        Optional[str],
        typer.Option(
            "--config-name",
            "-cn",
            help="The name of the config (usually the file name without the .yaml extension).",
        ),
    ] = None,
    overrides: Optional[List[str]] = typer.Argument(None, hidden=True),
) -> None:
    """Train the InstaNovo+ model."""
    # NOTE: the docstring above is the command's `--help` text, so it is kept verbatim.
    # Flow: compose the training config (falling back to DEFAULT_TRAIN_CONFIG_PATH and
    # the "instanovoplus" config name when the options are omitted), forward any
    # hidden positional `overrides` to compose_config, then run DiffusionTrainer.train().
    logger.info("Initializing InstaNovo+ training.")
    # Imported lazily so invoking other CLI commands does not pay for loading the trainer.
    from instanovo.diffusion.train import DiffusionTrainer

    resolved_config_path = DEFAULT_TRAIN_CONFIG_PATH if config_path is None else config_path
    resolved_config_name = "instanovoplus" if config_name is None else config_name

    training_config = compose_config(
        config_path=resolved_config_path,
        config_name=resolved_config_name,
        overrides=overrides,
    )

    logger.info("Initializing diffusion training.")
    DiffusionTrainer(training_config).train()

60 

61 

@cli.command("predict")
def diffusion_predict(
    data_path: Annotated[
        Optional[str],
        typer.Option(
            "--data-path",
            "-d",
            help="Path to input data file",
        ),
    ] = None,
    output_path: Annotated[
        Optional[Path],
        typer.Option(
            "--output-path",
            "-o",
            help="Path to output file.",
            exists=False,
            file_okay=True,
            dir_okay=False,
        ),
    ] = None,
    instanovo_plus_model: Annotated[
        Optional[str],
        typer.Option(
            "--instanovo-plus-model",
            "-p",
            help="Either a model ID or a path to an Instanovo+ checkpoint file (.ckpt format)",
        ),
    ] = None,
    denovo: Annotated[
        Optional[bool],
        typer.Option(
            "--denovo/--evaluation",
            help="Do [i]de novo[/i] predictions or evaluate an annotated file with peptide sequences?",
        ),
    ] = None,
    refine: Annotated[
        Optional[bool],
        typer.Option(
            "--with-refinement/--no-refinement",
            help="Refine the predictions of the transformer-based InstaNovo model with the diffusion-based InstaNovo+ model?",
        ),
    ] = None,
    config_path: Annotated[
        Optional[str],
        typer.Option(
            "--config-path",
            "-cp",
            help="Relative path to config directory.",
        ),
    ] = None,
    config_name: Annotated[
        Optional[str],
        typer.Option(
            "--config-name",
            "-cn",
            help="The name of the config (usually the file name without the .yaml extension).",
        ),
    ] = None,
    overrides: Optional[List[str]] = typer.Argument(None, hidden=True),
) -> None:
    """Run predictions with InstaNovo+.

    \f
    Everything after the form feed above is hidden from `--help` (Click/Typer
    truncate help text at `\\f`), so the CLI help stays the one-line summary.

    Options given on the command line override the matching entries of the
    composed inference config; the merged config is validated and then passed
    to ``DiffusionPredictor``.

    Args:
        data_path: Input data location. A glob pattern must match at least one file.
        output_path: Prediction output file. Its parent directory is created if
            needed. Required in de novo mode.
        instanovo_plus_model: A pretrained model ID, a checkpoint file, an ``s3://``
            URI, or a directory containing ``*.ckpt``, ``*.yaml`` and ``*.pt`` files.
        denovo: ``True`` for de novo prediction, ``False`` for evaluation.
        refine: Whether to refine InstaNovo predictions with InstaNovo+.
        config_path: Config directory; defaults to ``DEFAULT_INFERENCE_CONFIG_PATH``.
        config_name: Config name; defaults to ``DEFAULT_INFERENCE_CONFIG_NAME``.
        overrides: Extra config overrides forwarded unchanged to ``compose_config``.

    Raises:
        ValueError: If the data path, output path, model, or refinement path is
            missing or invalid.
    """
    logger.info("Initializing InstaNovo+ inference.")
    # Defer imports to improve cli performance
    from instanovo.diffusion.multinomial_diffusion import InstaNovoPlus
    from instanovo.diffusion.predict import DiffusionPredictor

    if config_path is None:
        config_path = DEFAULT_INFERENCE_CONFIG_PATH
    if config_name is None:
        config_name = DEFAULT_INFERENCE_CONFIG_NAME

    config = compose_config(
        config_path=config_path,
        config_name=config_name,
        overrides=overrides,
    )

    # NOTE(review): config is presumably an OmegaConf DictConfig. Optional keys are
    # read with .get() throughout, because attribute access on a missing key can
    # raise instead of returning None.

    # --- Input data -------------------------------------------------------------
    if data_path is not None:
        # Glob notation (e.g. path/to/data/*.parquet) must match at least one file.
        if any(char in data_path for char in "*?[") and not glob.glob(data_path):
            raise ValueError(f"The data_path '{data_path}' doesn't correspond to any file(s).")
        config.data_path = str(data_path)

    # Also rejects an empty string, not only a missing/None value.
    if not config.get("data_path", None):
        raise ValueError(
            "Expected 'data_path' but found None. Please specify it in the "
            "`config/inference/<your_config>.yaml` configuration file or with the cli flag "
            "`--data-path='path/to/data'`. Allows `.mgf`, `.mzml`, `.mzxml`, a directory, or a "
            "`.parquet` file. Glob notation is supported: eg.: `--data-path='./experiment/*.mgf'`."
        )

    # --- Mode flags -------------------------------------------------------------
    if denovo is not None:
        config.denovo = denovo
    if refine is not None:
        config.refine = refine

    # --- Output path ------------------------------------------------------------
    if output_path is not None:
        if output_path.exists():
            logger.info(f"Output path '{output_path}' already exists and will be overwritten.")
        config.output_path = str(output_path)

    # Read via .get(): evaluation mode may legitimately run without an output_path key.
    resolved_output_path = config.get("output_path", None)
    if resolved_output_path:
        # Single place that makes sure the output's parent directory exists,
        # whether the path came from the CLI or from the config file.
        Path(resolved_output_path).parent.mkdir(parents=True, exist_ok=True)
    elif resolved_output_path is None and config.get("denovo", False):
        raise ValueError(
            "Expected 'output_path' but found None in denovo mode. Please specify it "
            "in the `config/inference/<your_config>.yaml` configuration file or with the cli flag "
            "`--output-path=path/to/output_file`."
        )

    # --- Model ------------------------------------------------------------------
    if instanovo_plus_model is not None:
        model_location = Path(instanovo_plus_model)
        if model_location.is_dir():
            required_files = ["*.ckpt", "*.yaml", "*.pt"]
            missing_files = [
                pattern for pattern in required_files if not any(model_location.glob(pattern))
            ]
            if missing_files:
                raise ValueError(
                    f"The directory '{instanovo_plus_model}' is missing the following "
                    f"required file(s): {', '.join(missing_files)}."
                )
        elif not model_location.is_file() and not instanovo_plus_model.startswith("s3://"):
            # Neither a local file/dir nor a remote URI: it must be a known model ID.
            pretrained = InstaNovoPlus.get_pretrained()
            if instanovo_plus_model not in pretrained:
                supported = ", ".join(f"'{model_id}'" for model_id in pretrained)
                raise ValueError(
                    f"InstaNovo+ model ID '{instanovo_plus_model}' is not supported. Currently "
                    "supported value(s): "
                    f"{supported}"
                )
        config.instanovo_plus_model = instanovo_plus_model

    if not config.get("instanovo_plus_model", None):
        raise ValueError(
            "Expected 'instanovo_plus_model' but found None. Please specify it in the "
            "`config/inference/<your_config>.yaml` configuration file or with the cli flag "
            "`instanovo diffusion predict --instanovo-plus-model=path/to/model_dir`."
        )

    # --- Refinement -------------------------------------------------------------
    # Writing refined predictions to output_path would overwrite the original predictions.
    refinement_path = config.get("refinement_path", None)
    if (
        config.get("refine", False)
        and refinement_path is not None
        and refinement_path == config.get("output_path", None)
    ):
        raise ValueError("The 'refinement_path' should be different from the 'output_path' to avoid overwriting the original predictions.")

    predictor = DiffusionPredictor(config)
    predictor.predict()