Coverage for instanovo/transformer/cli.py: 79%

61 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

1import glob 

2from pathlib import Path 

3from typing import List, Optional 

4 

5import typer 

6from omegaconf import DictConfig 

7from typing_extensions import Annotated 

8 

9from instanovo.__init__ import console 

10from instanovo.constants import ( 

11 DEFAULT_INFERENCE_CONFIG_NAME, 

12 DEFAULT_INFERENCE_CONFIG_PATH, 

13 DEFAULT_TRAIN_CONFIG_PATH, 

14) 

15from instanovo.utils.cli_utils import compose_config 

16from instanovo.utils.colorlogging import ColorLog 

17 

# Module logger routed through the shared rich console.
logger = ColorLog(console, __name__).logger

# Typer application on which the transformer sub-commands below are registered.
# Rich markup is enabled for help text; Typer's pretty exception rendering is off.
cli = typer.Typer(
    pretty_exceptions_enable=False,
    rich_markup_mode="rich",
)

21 

22 

@cli.command("train")
def transformer_train(
    config_path: Annotated[
        Optional[str],
        typer.Option(
            "--config-path",
            "-cp",
            help="Relative path to config directory.",
        ),
    ] = None,
    config_name: Annotated[
        Optional[str],
        typer.Option(
            "--config-name",
            "-cn",
            help="The name of the config (usually the file name without the .yaml extension).",
        ),
    ] = None,
    overrides: Optional[List[str]] = typer.Argument(None, hidden=True),
) -> None:
    """Train the InstaNovo model."""
    # NOTE: Typer uses the docstring as this command's --help text, so developer
    # notes are kept in comments rather than in the docstring.
    #
    # Composes the training config from `config_path` / `config_name` (defaults:
    # DEFAULT_TRAIN_CONFIG_PATH and "instanovo"), forwards the hidden positional
    # `overrides` verbatim to `compose_config`, then builds a TransformerTrainer
    # from the composed config and runs training.
    logger.info("Initializing InstaNovo training.")

    # Defer imports to improve CLI performance: the trainer module is only needed
    # once this command actually runs.
    from instanovo.transformer.train import TransformerTrainer

    config = compose_config(
        config_path=DEFAULT_TRAIN_CONFIG_PATH if config_path is None else config_path,
        config_name="instanovo" if config_name is None else config_name,
        overrides=overrides,
    )

    trainer = TransformerTrainer(config)
    trainer.train()

63 

64 

@cli.command("predict")
def transformer_predict(
    data_path: Annotated[
        Optional[str],
        typer.Option(
            "--data-path",
            "-d",
            help="Path to input data file",
        ),
    ] = None,
    output_path: Annotated[
        Optional[Path],
        typer.Option(
            "--output-path",
            "-o",
            help="Path to output file.",
            exists=False,
            file_okay=True,
            dir_okay=False,
        ),
    ] = None,
    instanovo_model: Annotated[
        Optional[str],
        typer.Option(
            "--instanovo-model",
            "-i",
            # TODO: consider listing the supported model IDs here again; calling
            # InstaNovo.get_pretrained() while building the CLI was too expensive.
            help="Either a model ID or a path to an Instanovo checkpoint file (.ckpt format).",
        ),
    ] = None,
    denovo: Annotated[
        Optional[bool],
        typer.Option(
            "--denovo/--evaluation",
            help="Do [i]de novo[/i] predictions or evaluate an annotated file with peptide sequences?",
        ),
    ] = None,
    config_path: Annotated[
        Optional[str],
        typer.Option(
            "--config-path",
            "-cp",
            help="Relative path to config directory.",
        ),
    ] = None,
    config_name: Annotated[
        Optional[str],
        typer.Option(
            "--config-name",
            "-cn",
            help="The name of the config (usually the file name without the .yaml extension).",
        ),
    ] = None,
    overrides: Optional[List[str]] = typer.Argument(None, hidden=True),
) -> None:
    """Run predictions with InstaNovo."""
    # NOTE: Typer uses the docstring as this command's --help text, so developer
    # notes are kept in comments rather than in the docstring.
    #
    # Flow: compose the inference config (defaults DEFAULT_INFERENCE_CONFIG_PATH /
    # DEFAULT_INFERENCE_CONFIG_NAME, plus hidden positional `overrides`), let each
    # non-None CLI flag override the matching config key, validate the result,
    # then run TransformerPredictor.predict().
    #
    # Raises ValueError when:
    #   * a glob-style data path matches no files, or no data path is configured;
    #   * no output path is configured in denovo mode;
    #   * the model is an existing file without the `.ckpt` extension, a missing
    #     `.ckpt` path, or neither an existing file, an `s3://` URI nor a
    #     supported pretrained model ID;
    #   * no model is configured at all.
    logger.info("Initializing InstaNovo inference.")

    # Defer imports to improve CLI performance.
    from instanovo.transformer.model import InstaNovo
    from instanovo.transformer.predict import TransformerPredictor

    # Compose config with overrides; the CLI flags handled below take precedence.
    config = compose_config(
        config_path=DEFAULT_INFERENCE_CONFIG_PATH if config_path is None else config_path,
        config_name=DEFAULT_INFERENCE_CONFIG_NAME if config_name is None else config_name,
        overrides=overrides,
    )

    # --- Input data -------------------------------------------------------
    if data_path is not None:
        # Glob notation (e.g. path/to/data/*.parquet) must match at least one file.
        if any(char in data_path for char in "*?[") and not glob.glob(data_path):
            raise ValueError(f"The data_path '{data_path}' doesn't correspond to any file(s).")
        config.data_path = str(data_path)

    # Covers both a missing config value and an empty `--data-path ""`.
    if not config.get("data_path", None):
        raise ValueError(
            "Expected 'data_path' but found None. Please specify it in the "
            "`config/inference/<your_config>.yaml` configuration file or with the cli flag "
            "`--data-path='path/to/data'`. Allows `.mgf`, `.mzml`, `.mzxml`, a directory, or a "
            "`.parquet` file. Glob notation is supported: eg.: `--data-path='./experiment/*.mgf'`."
        )

    # --- Mode ---------------------------------------------------------------
    if denovo is not None:
        # Don't compute metrics in denovo mode.
        config.denovo = denovo

    # --- Output ---------------------------------------------------------------
    if output_path is not None:
        if output_path.exists():
            logger.info(f"Output path '{output_path}' already exists and will be overwritten.")
        output_path.parent.mkdir(parents=True, exist_ok=True)
        config.output_path = str(output_path)
    if config.get("output_path", None) is None and config.get("denovo", False):
        raise ValueError(
            "Expected 'output_path' but found None in denovo mode. Please specify it in the "
            "`config/inference/<your_config>.yaml` configuration file or with the cli flag "
            "`--output-path=path/to/output_file`."
        )

    # --- Model ----------------------------------------------------------------
    if instanovo_model is not None:
        model_path = Path(instanovo_model)
        if model_path.is_file():
            if model_path.suffix != ".ckpt":
                raise ValueError(
                    f"Checkpoint file '{instanovo_model}' should end with extension '.ckpt'."
                )
        elif not instanovo_model.startswith("s3://"):
            # Query the pretrained IDs once (the lookup was noted as expensive) and
            # materialise them so membership test and listing see the same values.
            pretrained_ids = list(InstaNovo.get_pretrained())
            if instanovo_model not in pretrained_ids:
                if model_path.suffix == ".ckpt":
                    # Looks like a checkpoint path, not a model ID: say so explicitly.
                    raise ValueError(f"Checkpoint file '{instanovo_model}' does not exist.")
                raise ValueError(
                    f"InstaNovo model ID '{instanovo_model}' is not supported. "
                    "Currently supported value(s): "
                    + ", ".join(f"'{model_id}'" for model_id in pretrained_ids)
                )
        config.instanovo_model = instanovo_model

    if not config.get("instanovo_model", None):
        raise ValueError(
            "Expected 'instanovo_model' but found None. Please specify it in the "
            "`config/inference/<your_config>.yaml` configuration file or with the cli flag "
            "`instanovo transformer predict --instanovo-model=path/to/model.ckpt`."
        )

    predictor = TransformerPredictor(config)
    predictor.predict()