Coverage for instanovo/diffusion/predict.py: 84%

196 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

1from pathlib import Path 

2from typing import Any, Tuple 

3 

4import numpy as np 

5import polars as pl 

6import torch 

7import torch.nn as nn 

8from accelerate.utils import broadcast_object_list 

9from datasets import Dataset 

10from omegaconf import DictConfig, OmegaConf 

11 

12from instanovo.__init__ import console 

13from instanovo.common import AccelerateDeNovoPredictor, DataProcessor 

14from instanovo.constants import ( 

15 DIFFUSION_START_STEP, 

16 REFINEMENT_COLUMN, 

17 REFINEMENT_PROBABILITY_COLUMN, 

18 SpecialTokens, 

19) 

20from instanovo.diffusion.data import DiffusionDataProcessor 

21from instanovo.diffusion.multinomial_diffusion import InstaNovoPlus 

22from instanovo.inference.diffusion import DiffusionDecoder 

23from instanovo.inference.interfaces import Decoder 

24from instanovo.transformer.predict import TransformerPredictor 

25from instanovo.utils.colorlogging import ColorLog 

26 

# Module-level logger, routed through the package's shared Rich console.
logger = ColorLog(console, __name__).logger

# Root directory of the packaged config files (instanovo/configs).
CONFIG_PATH = Path(__file__).parent.parent / "configs"

30 

31 

class DiffusionPredictor(AccelerateDeNovoPredictor):
    """Predictor for the InstaNovo+ (multinomial diffusion) model.

    Supports plain de novo prediction as well as *refinement* mode, where
    previous predictions (e.g. from the transformer model) are loaded from CSV
    and used as the starting sequence for the diffusion process.
    """

    def __init__(
        self,
        config: DictConfig,
    ):
        """Store refinement settings before running the base predictor setup.

        Args:
            config: Prediction config. Keys used here: ``refine``,
                ``refine_all``, ``refine_threshold``, ``filter_precursor_ppm``.
        """
        self.refine = config.get("refine", False)
        self.refine_all = config.get("refine_all", True)
        # A threshold of None selects the precursor-match ensemble strategy in
        # get_predictions; previously np.log(None) would raise and the
        # ensemble branch was unreachable. Confidences are compared in log
        # space, hence np.log here.
        refine_threshold = config.get("refine_threshold", 0.9)
        self.refine_threshold = np.log(refine_threshold) if refine_threshold is not None else None
        self.precursor_tolerance = config.get("filter_precursor_ppm", 50)
        super().__init__(config)

    # Possibly merge with transformer load_model
    def load_model(self) -> Tuple[nn.Module, DictConfig]:
        """Load the InstaNovo+ model from a pretrained model ID or a path.

        Returns:
            The loaded model and its model config.

        Raises:
            ValueError: If a non-pretrained model path cannot be resolved locally.
        """
        default_model = InstaNovoPlus.get_pretrained()[0]
        model_path = self.config.get("instanovo_plus_model", default_model)

        # When running on MPS, force float32 peak embeddings via the load-time
        # config override.
        override_config = {"peak_embedding_dtype": "float32"} if self.config.get("mps", False) else None

        logger.info(f"Loading InstaNovo+ model {model_path}")
        if model_path in InstaNovoPlus.get_pretrained():
            # Using a pretrained model from models.json
            model, model_config = InstaNovoPlus.from_pretrained(model_path, override_config=override_config)
        else:
            # Treat as a (possibly remote) checkpoint path; resolve to a local file.
            model_path = self.s3.get_local_path(model_path)
            if model_path is None:
                raise ValueError("Could not resolve a local path for the InstaNovo+ model checkpoint.")
            model, model_config = InstaNovoPlus.load(model_path, override_config=override_config)

        return model, model_config

    def postprocess_dataset(self, dataset: Dataset) -> Dataset:
        """Load previous predictions and attach them to the dataset for refinement.

        Resolves the refinement CSV path(s) on the main process, broadcasts
        them to all ranks, validates the expected columns, then maps each
        spectrum ID to its prior prediction and (optionally) confidence.

        Args:
            dataset: The spectra dataset to annotate.

        Returns:
            The dataset, with refinement columns added when ``refine`` is set.

        Raises:
            ValueError: If a refinement path or a required column is missing.
        """
        if not self.refine:
            return dataset

        data_path = self.config.get("data_path", None)
        prediction_paths = []
        # Only the main process resolves (and possibly downloads) the paths;
        # they are broadcast to the other ranks below.
        if self.accelerator.is_main_process:
            if OmegaConf.is_list(data_path):
                # Grouped refinement: one refinement file per data group.
                for group in data_path:
                    path = group.get("refinement_path")
                    if path is None:
                        raise ValueError("refinement_path must be specified per group when `refine` is True in pipeline mode.")
                    path = self.s3.get_local_path(path)
                    prediction_paths.append(path)
            else:
                path = self.config.get("refinement_path", None)
                if path is None:
                    raise ValueError("refinement_path must be specified when `refine` is True.")
                path = self.s3.get_local_path(path)
                prediction_paths.append(path)

        if self.accelerator.num_processes > 1:
            logger.info(f"Broadcasting {len(prediction_paths)} refinement paths")

        prediction_paths = broadcast_object_list([prediction_paths], from_process=0)[0]
        if self.accelerator.num_processes > 1:
            logger.info(f"Received {len(prediction_paths)} paths for refinement")

        predictions_id_col = self.config.get("refinement_id_col", "spectrum_id")
        dataset_id_col = self.config.get("dataset_id_col", "spectrum_id")
        predictions_refine_col = self.config.get("prediction_refine_col", "predictions_tokenised")
        prediction_confidence_col = self.config.get("prediction_confidence_col", None)

        # Validate every refinement file's schema (lazily) before reading fully.
        for path in prediction_paths:
            columns = set(pl.scan_csv(path).collect_schema().keys())
            if predictions_id_col not in columns:
                raise ValueError(f"Column '{predictions_id_col}' does not exist in {path}.")
            if predictions_refine_col not in columns:
                raise ValueError(f"Column '{predictions_refine_col}' does not exist in {path}.")
            # Check for None first; the original relied on short-circuit order.
            if prediction_confidence_col is not None and prediction_confidence_col not in columns:
                raise ValueError(f"Column '{prediction_confidence_col}' does not exist in {path}.")

        if dataset_id_col not in dataset.column_names:
            raise ValueError(f"Column '{dataset_id_col}' does not exist in dataset.")

        target_schema = {predictions_id_col: pl.String, predictions_refine_col: pl.String}

        if prediction_confidence_col is not None:
            target_schema.update({prediction_confidence_col: pl.Float64})

        logger.info(f"Reading {len(prediction_paths)} refinement file(s)")
        predictions_df = pl.concat(
            [pl.read_csv(path, columns=list(target_schema.keys()), schema_overrides=target_schema) for path in prediction_paths],
            how="vertical",
        )

        id_to_predictions = dict(
            zip(
                predictions_df[predictions_id_col],
                predictions_df[predictions_refine_col],
                strict=False,
            )
        )
        if prediction_confidence_col is not None:
            id_to_confidence = dict(
                zip(
                    predictions_df[predictions_id_col],
                    predictions_df[prediction_confidence_col],
                    strict=False,
                )
            )
        else:
            logger.warning("'prediction_confidence_col' not set. Setting all input confidence scores to 0.")

        def add_predictions_column(row: dict[str, Any]) -> dict[str, Any]:
            # IDs without a prior prediction get an empty refinement string
            # (see _clean_predictions).
            prediction = id_to_predictions.get(row[dataset_id_col], None)
            row[REFINEMENT_COLUMN] = self._clean_predictions(prediction)
            if prediction_confidence_col is not None:
                row[REFINEMENT_PROBABILITY_COLUMN] = id_to_confidence.get(row[dataset_id_col], None)
            else:
                row[REFINEMENT_PROBABILITY_COLUMN] = 0
            return row

        logger.info("Adding refinement columns to dataset")
        dataset = dataset.map(add_predictions_column)

        # _clean_predictions maps missing predictions to "" (never None), so
        # count falsy values; the original `x is None` check could never fire.
        num_none_refine = sum(1 for x in dataset[REFINEMENT_COLUMN] if not x)
        if num_none_refine > 0:
            logger.info(f"Refinement is missing for {num_none_refine} / {len(dataset)} spectra ({((num_none_refine / len(dataset)) * 100):.2f}%)")

        return dataset

    def _clean_predictions(self, predictions: str | None) -> str:
        """Tokenize a prior prediction, masking out-of-vocabulary tokens with PAD.

        Returns an empty string when no prior prediction exists.
        """
        if predictions is None:
            return ""
        # Replace invalid tokens with PAD token
        tokens = self.model.residue_set.tokenize(predictions)
        tokens = [token if token in self.model.residue_set.vocab else SpecialTokens.PAD_TOKEN.value for token in tokens]
        return ", ".join(tokens)

    def setup_data_processor(self) -> DataProcessor:
        """Setup the diffusion data processor (fixed-length, unreversed peptides)."""
        processor = DiffusionDataProcessor(
            self.residue_set,
            n_peaks=self.model_config.get("n_peaks", 200),
            min_mz=self.model_config.get("min_mz", 50.0),
            max_mz=self.model_config.get("max_mz", 2500.0),
            min_intensity=self.model_config.get("min_intensity", 0.01),
            remove_precursor_tol=self.model_config.get("remove_precursor_tol", 2.0),
            return_str=False,
            reverse_peptide=False,
            add_eos=False,
            peptide_pad_length=self.model_config.get("max_length", 40),
            peptide_pad_value=self.residue_set.PAD_INDEX,
            use_spectrum_utils=False,
            annotated=not self.denovo,
            metadata_columns=["group"],
        )

        # Refinement needs the prior prediction and its confidence per spectrum.
        if self.refine:
            processor.add_metadata_columns([REFINEMENT_COLUMN, REFINEMENT_PROBABILITY_COLUMN])

        return processor

    def setup_decoder(self) -> Decoder:
        """Setup the decoder."""
        return DiffusionDecoder(model=self.model)  # type: ignore

    def get_predictions(self, batch: Any) -> dict[str, Any]:
        """Run diffusion decoding for a batch, optionally ensembling with the input.

        In refinement mode with ``refine_all`` disabled, an unrefined input
        prediction is kept over the refined one when it matches the precursor
        mass better (and, when a threshold is set, its confidence is below the
        threshold yet above the refined log-probability). When
        ``refine_threshold`` is None, only the precursor-match criterion is used.
        """
        num_beams = self.config.get("num_beams", 1)
        batch_size = batch["spectra"].size(0)

        batch_results: dict[str, Any] = self.decoder.decode(
            initial_sequence=batch[REFINEMENT_COLUMN] if self.refine else None,
            spectra=batch["spectra"],
            precursors=batch["precursors"],
            spectra_padding_mask=batch["spectra_mask"],
            start_step=DIFFUSION_START_STEP if self.refine else None,  # type: ignore
            beam_size=num_beams,
            return_encoder_output=self.save_encoder_outputs,
            encoder_output_reduction=self.encoder_output_reduction,
        )  # type: ignore

        if "peptides" in batch:
            targets = [self.residue_set.decode(seq, reverse=False) for seq in batch["peptides"]]
        else:
            targets = [None] * batch_size

        batch_results["targets"] = targets

        if self.refine and not self.refine_all:
            unrefined_preds = [self.residue_set.decode(seq, reverse=False) for seq in batch[REFINEMENT_COLUMN]]
            batch_results["unrefined_predictions"] = unrefined_preds
            unrefined_matches = []
            for i in range(batch_size):
                # NOTE(review): precursors[i][1] / [2] assumed to be charge and
                # m/z respectively — confirm against the collator.
                matches, _ = self.metrics.matches_precursor(
                    unrefined_preds[i],
                    batch["precursors"][i][2],
                    batch["precursors"][i][1],
                    prec_tol=self.precursor_tolerance,
                )
                unrefined_matches.append(matches)

            for i in range(batch_size):
                refine_prob = batch[REFINEMENT_PROBABILITY_COLUMN][i].item()
                if self.refine_threshold is not None:
                    if (
                        refine_prob < self.refine_threshold
                        and unrefined_matches[i] > batch_results["meets_precursor"][i]
                        and refine_prob > batch_results["prediction_log_probability"][i]
                    ):
                        batch_results["meets_precursor"][i] = unrefined_matches[i]
                        batch_results["predictions"][i] = unrefined_preds[i]
                        batch_results["prediction_log_probability"][i] = refine_prob
                else:
                    # Ensemble based on ppm match
                    if unrefined_matches[i] > batch_results["meets_precursor"][i]:
                        batch_results["meets_precursor"][i] = unrefined_matches[i]
                        batch_results["predictions"][i] = unrefined_preds[i]
                        batch_results["prediction_log_probability"][i] = refine_prob

        # Internal bookkeeping only; drop before returning (pop is a no-op if
        # the decoder did not return it).
        batch_results.pop("meets_precursor", None)

        return batch_results

253 

254 

class CombinedPredictor(TransformerPredictor):
    """Predictor that runs the transformer model, then refines with InstaNovo+.

    When ``refine`` is enabled, transformer predictions are re-encoded with the
    diffusion residue set and handed to the diffusion decoder as the starting
    sequence. Both models' raw outputs are returned under ``instanovo_`` /
    ``instanovoplus_`` key prefixes alongside the combined result.
    """

    # Borrow the diffusion model-loading and decoding logic unchanged.
    diffusion_load_model = DiffusionPredictor.load_model
    diffusion_get_predictions = DiffusionPredictor.get_predictions

    def __init__(self, config: DictConfig):
        """Store refinement settings, run base setup, then prepare the diffusion model.

        Args:
            config: Prediction config. Keys used here: ``refine``,
                ``refine_all``, ``refine_threshold``, ``filter_precursor_ppm``.
        """
        self.refine = config.get("refine", False)
        self.refine_all = config.get("refine_all", True)
        # None selects the precursor-match ensemble branch in the borrowed
        # diffusion get_predictions; previously np.log(None) would raise.
        refine_threshold = config.get("refine_threshold", 0.9)
        self.refine_threshold = np.log(refine_threshold) if refine_threshold is not None else None
        self.precursor_tolerance = config.get("filter_precursor_ppm", 50)
        super().__init__(config)

        # Manually prepare the diffusion model since it is not prepared in the parent class
        if self.refine:
            logger.info("Running in refinement mode.")
            self.diffusion_model: nn.Module = self.accelerator.prepare(self.diffusion_model)

    def load_model(self) -> Tuple[nn.Module, DictConfig]:
        """Load the transformer model, plus the diffusion model when refining.

        Validates that both models share a residue vocabulary and warns when
        their maximum peptide lengths differ.

        Returns:
            The transformer model and its config (the diffusion model is stored
            on ``self.diffusion_model``).

        Raises:
            ValueError: If the two models' residue sets differ.
        """
        self.transformer_model, transformer_model_config = super().load_model()

        if not self.refine:
            return self.transformer_model, transformer_model_config

        self.diffusion_model, diffusion_model_config = self.diffusion_load_model()  # type: ignore

        self.diffusion_residue_set = self.diffusion_model.residue_set

        # The models must agree on the vocabulary, otherwise the re-encoded
        # transformer predictions would be meaningless to the diffusion model.
        transformer_residue_set = self.transformer_model.residue_set.index_to_residue
        diffusion_residue_set = self.diffusion_residue_set.index_to_residue
        if transformer_residue_set != diffusion_residue_set:
            raise ValueError("Transformer and diffusion residue sets do not match")

        # Compare max length
        self.transformer_max_length = transformer_model_config.get("max_length", 40)
        self.diffusion_max_length = diffusion_model_config.get("max_length", 40)
        if self.transformer_max_length != self.diffusion_max_length:
            logger.warning(
                f"Transformer and diffusion max length do not match. "
                f"Transformer: {self.transformer_max_length}, "
                f"Diffusion: {self.diffusion_max_length}"
            )

        return self.transformer_model, transformer_model_config

    def setup_decoder(self) -> Decoder:
        """Setup the transformer decoder, plus the diffusion decoder when refining."""
        if self.refine:
            self.diffusion_decoder = DiffusionDecoder(model=self.diffusion_model)
        else:
            self.diffusion_decoder = None  # type: ignore

        self.transformer_decoder = super().setup_decoder()
        return self.transformer_decoder  # type: ignore

    def _tokenize_and_pad(self, refinement: list[str]) -> torch.Tensor:
        """Encode transformer predictions and pad each to the diffusion input length."""
        encodings = []
        for refine in refinement:
            refine_tokenized = self.diffusion_residue_set.tokenize(refine)

            refine_encoding = self.diffusion_residue_set.encode(refine_tokenized, add_eos=False, return_tensor="pt")

            # Truncate, then pad with PAD_INDEX so every sequence is exactly
            # diffusion_max_length (the diffusion model expects fixed length).
            refine_encoding = refine_encoding[: self.diffusion_max_length]
            refine_padded = torch.full(
                (self.diffusion_max_length,),
                fill_value=self.diffusion_residue_set.PAD_INDEX,
                dtype=refine_encoding.dtype,
                device=refine_encoding.device,
            )
            refine_padded[: refine_encoding.shape[0]] = refine_encoding

            encodings.append(refine_padded)

        encodings, _ = DiffusionDataProcessor._pad_and_mask(encodings)
        return encodings

    def get_predictions(self, batch: Any) -> dict[str, Any]:
        """Run the transformer pass, then (optionally) the diffusion refinement pass.

        Returns the diffusion predictions as the primary results, with the raw
        per-model outputs included under ``instanovo_`` and ``instanovoplus_``
        prefixes.
        """
        # Point self.decoder at the transformer decoder for the parent call.
        self.decoder = self.transformer_decoder
        transformer_predictions = super().get_predictions(batch)

        if not self.refine:
            return transformer_predictions  # type: ignore

        # Feed the transformer output to the diffusion model as its start sequence.
        batch[REFINEMENT_COLUMN] = self._tokenize_and_pad(transformer_predictions["predictions"]).to(self.accelerator.device)
        batch[REFINEMENT_PROBABILITY_COLUMN] = torch.tensor(transformer_predictions["prediction_log_probability"]).to(self.accelerator.device)

        self.decoder = self.diffusion_decoder  # type: ignore
        diffusion_predictions = self.diffusion_get_predictions(batch)  # type: ignore

        predictions = {
            "predictions": diffusion_predictions["predictions"],
            "prediction_log_probability": diffusion_predictions["prediction_log_probability"],
            "prediction_token_log_probabilities": diffusion_predictions["prediction_token_log_probabilities"],
            "targets": transformer_predictions["targets"],
        }

        # Targets are identical for both passes; drop the duplicates before prefixing.
        transformer_predictions.pop("targets")
        diffusion_predictions.pop("targets")

        predictions.update({f"instanovo_{k}": v for k, v in transformer_predictions.items()})
        predictions.update({f"instanovoplus_{k}": v for k, v in diffusion_predictions.items()})

        return predictions