Coverage for instanovo/diffusion/predict.py: 84%
196 statements
« prev ^ index » next — coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
1from pathlib import Path
2from typing import Any, Tuple
4import numpy as np
5import polars as pl
6import torch
7import torch.nn as nn
8from accelerate.utils import broadcast_object_list
9from datasets import Dataset
10from omegaconf import DictConfig, OmegaConf
12from instanovo.__init__ import console
13from instanovo.common import AccelerateDeNovoPredictor, DataProcessor
14from instanovo.constants import (
15 DIFFUSION_START_STEP,
16 REFINEMENT_COLUMN,
17 REFINEMENT_PROBABILITY_COLUMN,
18 SpecialTokens,
19)
20from instanovo.diffusion.data import DiffusionDataProcessor
21from instanovo.diffusion.multinomial_diffusion import InstaNovoPlus
22from instanovo.inference.diffusion import DiffusionDecoder
23from instanovo.inference.interfaces import Decoder
24from instanovo.transformer.predict import TransformerPredictor
25from instanovo.utils.colorlogging import ColorLog
# Module-level logger routed through the package's shared rich console.
logger = ColorLog(console, __name__).logger

# Directory holding the package's config files: this file lives in
# instanovo/diffusion/, so parent.parent resolves to instanovo/configs.
CONFIG_PATH = Path(__file__).parent.parent / "configs"
class DiffusionPredictor(AccelerateDeNovoPredictor):
    """Predictor for the InstaNovo+ model."""

    def __init__(
        self,
        config: DictConfig,
    ):
        # NOTE(review): these refinement attributes are read by hooks such as
        # setup_data_processor / postprocess_dataset, so they are assigned
        # before super().__init__ — presumably the base class invokes those
        # hooks during construction; confirm against AccelerateDeNovoPredictor.
        self.refine = config.get("refine", False)
        self.refine_all = config.get("refine_all", True)
        # Stored in log space: compared against log-probabilities downstream.
        self.refine_threshold = np.log(config.get("refine_threshold", 0.9))
        # Precursor mass tolerance (ppm) for matches_precursor checks.
        self.precursor_tolerance = config.get("filter_precursor_ppm", 50)
        super().__init__(config)

    # Possibly merge with transformer load_model
    def load_model(self) -> Tuple[nn.Module, DictConfig]:
        """Setup the model.

        Loads either a named pretrained InstaNovo+ model or a checkpoint from
        the (possibly remote) path given by ``instanovo_plus_model`` in config.

        Returns:
            The loaded model and its configuration.
        """
        default_model = InstaNovoPlus.get_pretrained()[0]
        model_path = self.config.get("instanovo_plus_model", default_model)

        logger.info(f"Loading InstaNovo+ model {model_path}")
        if model_path in InstaNovoPlus.get_pretrained():
            # Using a pretrained model from models.json
            # NOTE(review): the override forces float32 peak embeddings when
            # running on Apple MPS — presumably MPS lacks support for the
            # default dtype; confirm.
            model, model_config = InstaNovoPlus.from_pretrained(
                model_path, override_config={"peak_embedding_dtype": "float32"} if self.config.get("mps", False) else None
            )
        else:
            # Treat the value as a checkpoint path; resolve S3 URIs to a local file.
            model_path = self.s3.get_local_path(model_path)
            assert model_path is not None
            model, model_config = InstaNovoPlus.load(
                model_path, override_config={"peak_embedding_dtype": "float32"} if self.config.get("mps", False) else None
            )

        return model, model_config

    def postprocess_dataset(self, dataset: Dataset) -> Dataset:
        """Load previous predictions for refinement.

        When ``refine`` is enabled, reads prior predictions from one or more
        CSV files and attaches them to each dataset row under
        REFINEMENT_COLUMN / REFINEMENT_PROBABILITY_COLUMN so the diffusion
        model can start from them.

        Args:
            dataset: The spectra dataset to augment.

        Returns:
            The dataset, unchanged unless ``refine`` is set.

        Raises:
            ValueError: If a refinement path or a required column is missing.
        """
        if not self.refine:
            return dataset

        data_path = self.config.get("data_path", None)
        prediction_paths = []
        # Resolve paths on the main process only (may download from S3);
        # the resolved local paths are broadcast to all workers below.
        if self.accelerator.is_main_process:
            if OmegaConf.is_list(data_path):
                # Grouped refinement
                for group in data_path:
                    path = group.get("refinement_path")
                    if path is None:
                        raise ValueError("refinement_path must be specified per group when `refine` is True in pipeline mode.")
                    path = self.s3.get_local_path(path)
                    prediction_paths.append(path)
            else:
                path = self.config.get("refinement_path", None)
                if path is None:
                    raise ValueError("refinement_path must be specified when `refine` is True.")
                path = self.s3.get_local_path(path)
                prediction_paths.append(path)

        if self.accelerator.num_processes > 1:
            logger.info(f"Broadcasting {len(prediction_paths)} refinement paths")

        # Collective call: every rank must reach this line (it is a no-op
        # with a single process), hence it sits outside the is_main_process guard.
        prediction_paths = broadcast_object_list([prediction_paths], from_process=0)[0]
        if self.accelerator.num_processes > 1:
            logger.info(f"Received {len(prediction_paths)} paths for refinement")

        predictions_id_col = self.config.get("refinement_id_col", "spectrum_id")
        dataset_id_col = self.config.get("dataset_id_col", "spectrum_id")
        predictions_refine_col = self.config.get("prediction_refine_col", "predictions_tokenised")
        prediction_confidence_col = self.config.get("prediction_confidence_col", None)

        # Validate required columns via a lazy schema scan before reading any data.
        for path in prediction_paths:
            columns = set(pl.scan_csv(path).collect_schema().keys())
            if predictions_id_col not in columns:
                raise ValueError(f"Column '{predictions_id_col}' does not exist in {path}.")
            if predictions_refine_col not in columns:
                raise ValueError(f"Column '{predictions_refine_col}' does not exist in {path}.")
            if prediction_confidence_col not in columns and prediction_confidence_col is not None:
                raise ValueError(f"Column '{prediction_confidence_col}' does not exist in {path}.")

        if dataset_id_col not in dataset.column_names:
            raise ValueError(f"Column '{dataset_id_col}' does not exist in dataset.")

        target_schema = {predictions_id_col: pl.String, predictions_refine_col: pl.String}

        if prediction_confidence_col is not None:
            target_schema.update({prediction_confidence_col: pl.Float64})

        logger.info(f"Reading {len(prediction_paths)} refinement file(s)")
        predictions_df = pl.concat(
            [pl.read_csv(path, columns=list(target_schema.keys()), schema_overrides=target_schema) for path in prediction_paths],
            how="vertical",
        )

        # Lookup tables: spectrum id -> predicted sequence (and, optionally, confidence).
        id_to_predictions = dict(
            zip(
                predictions_df[predictions_id_col],
                predictions_df[predictions_refine_col],
                strict=False,
            )
        )
        if prediction_confidence_col is not None:
            id_to_confidence = dict(
                zip(
                    predictions_df[predictions_id_col],
                    predictions_df[prediction_confidence_col],
                    strict=False,
                )
            )
        else:
            logger.warning("'prediction_confidence_col' not set. Setting all input confidence scores to 0.")

        def add_predictions_column(row: dict[str, Any]) -> dict[str, Any]:
            # Attach the prior prediction (cleaned of out-of-vocab tokens)
            # and its confidence to the row; ids absent from the CSVs yield None here.
            prediction = id_to_predictions.get(row[dataset_id_col], None)
            row[REFINEMENT_COLUMN] = self._clean_predictions(prediction)
            if prediction_confidence_col is not None:
                row[REFINEMENT_PROBABILITY_COLUMN] = id_to_confidence.get(row[dataset_id_col], None)
            else:
                row[REFINEMENT_PROBABILITY_COLUMN] = 0
            return row

        logger.info("Adding refinement columns to dataset")
        dataset = dataset.map(add_predictions_column)

        # NOTE(review): _clean_predictions maps a missing prediction to "",
        # not None, so this None-count may never trigger — confirm whether it
        # should count empty strings instead.
        num_none_refine = sum(1 for x in dataset[REFINEMENT_COLUMN] if x is None)
        if num_none_refine > 0:
            logger.info(f"Refinement is missing for {num_none_refine} / {len(dataset)} spectra ({((num_none_refine / len(dataset)) * 100):.2f}%)")

        return dataset

    def _clean_predictions(self, predictions: str | None) -> str:
        """Tokenize a predicted sequence, mapping out-of-vocabulary tokens to PAD.

        Returns an empty string when no prior prediction exists, otherwise a
        comma-separated token string.
        """
        if predictions is None:
            return ""
        # Replace invalid tokens with PAD token
        tokens = self.model.residue_set.tokenize(predictions)
        tokens = [token if token in self.model.residue_set.vocab else SpecialTokens.PAD_TOKEN.value for token in tokens]
        return ", ".join(tokens)

    def setup_data_processor(self) -> DataProcessor:
        """Setup the data processor.

        Peptides are padded to the model's fixed ``max_length`` (the diffusion
        model operates on fixed-length sequences) and are not reversed.
        """
        processor = DiffusionDataProcessor(
            self.residue_set,
            n_peaks=self.model_config.get("n_peaks", 200),
            min_mz=self.model_config.get("min_mz", 50.0),
            max_mz=self.model_config.get("max_mz", 2500.0),
            min_intensity=self.model_config.get("min_intensity", 0.01),
            remove_precursor_tol=self.model_config.get("remove_precursor_tol", 2.0),
            return_str=False,
            reverse_peptide=False,
            add_eos=False,
            peptide_pad_length=self.model_config.get("max_length", 40),
            peptide_pad_value=self.residue_set.PAD_INDEX,
            use_spectrum_utils=False,
            annotated=not self.denovo,
            metadata_columns=["group"],
        )

        if self.refine:
            # Carry the prior predictions and their confidence through batching.
            processor.add_metadata_columns([REFINEMENT_COLUMN, REFINEMENT_PROBABILITY_COLUMN])

        return processor

    def setup_decoder(self) -> Decoder:
        """Setup the decoder."""
        return DiffusionDecoder(model=self.model)  # type: ignore

    def get_predictions(self, batch: Any) -> dict[str, Any]:
        """Get the predictions for a batch.

        In refinement mode the reverse diffusion starts from the prior
        predictions part-way through the schedule. With ``refine_all`` off,
        the refined output is kept per spectrum only when it beats the
        unrefined prediction.
        """
        num_beams = self.config.get("num_beams", 1)
        batch_size = batch["spectra"].size(0)

        batch_results: dict[str, Any] = self.decoder.decode(
            initial_sequence=batch[REFINEMENT_COLUMN] if self.refine else None,
            spectra=batch["spectra"],
            precursors=batch["precursors"],
            spectra_padding_mask=batch["spectra_mask"],
            # Start mid-schedule when refining instead of from pure noise.
            start_step=DIFFUSION_START_STEP if self.refine else None,  # type: ignore
            beam_size=num_beams,
            return_encoder_output=self.save_encoder_outputs,
            encoder_output_reduction=self.encoder_output_reduction,
        )  # type: ignore

        if "peptides" in batch:
            targets = [self.residue_set.decode(seq, reverse=False) for seq in batch["peptides"]]
        else:
            # De novo mode: no ground-truth peptides available.
            targets = [None] * batch_size

        batch_results["targets"] = targets

        if self.refine and not self.refine_all:
            # Selective refinement: decide per spectrum whether to keep the
            # refined sequence or fall back to the original prediction.
            unrefined_preds = [self.residue_set.decode(seq, reverse=False) for seq in batch[REFINEMENT_COLUMN]]
            batch_results["unrefined_predictions"] = unrefined_preds
            unrefined_matches = []
            for i in range(batch_size):
                # NOTE(review): passes precursors[i][2] then [1] — presumably
                # (m/z, charge); confirm the precursor tensor layout against
                # metrics.matches_precursor.
                matches, _ = self.metrics.matches_precursor(
                    unrefined_preds[i],
                    batch["precursors"][i][2],
                    batch["precursors"][i][1],
                    prec_tol=self.precursor_tolerance,
                )
                unrefined_matches.append(matches)

            for i in range(batch_size):
                refine_prob = batch[REFINEMENT_PROBABILITY_COLUMN][i].item()
                # NOTE(review): np.log(...) in __init__ never yields None, so
                # this check always takes the first branch — confirm intent.
                if self.refine_threshold is not None:
                    # Revert to the unrefined prediction when its confidence is
                    # below the threshold yet it matches the precursor mass
                    # better AND scores higher than the refined output.
                    if (
                        refine_prob < self.refine_threshold
                        and unrefined_matches[i] > batch_results["meets_precursor"][i]
                        and refine_prob > batch_results["prediction_log_probability"][i]
                    ):
                        batch_results["meets_precursor"][i] = unrefined_matches[i]
                        batch_results["predictions"][i] = unrefined_preds[i]
                        batch_results["prediction_log_probability"][i] = refine_prob
                else:
                    # Ensemble based on ppm match
                    if unrefined_matches[i] > batch_results["meets_precursor"][i]:
                        batch_results["meets_precursor"][i] = unrefined_matches[i]
                        batch_results["predictions"][i] = unrefined_preds[i]
                        batch_results["prediction_log_probability"][i] = refine_prob

        # Internal selection signal only; drop it from the reported results.
        batch_results.pop("meets_precursor", None)

        return batch_results
class CombinedPredictor(TransformerPredictor):
    """Predictor for the combined InstaNovo+ model."""

    # Borrow the diffusion model's loading/prediction logic as plain functions
    # bound to this class, so one predictor can drive both models.
    diffusion_load_model = DiffusionPredictor.load_model
    diffusion_get_predictions = DiffusionPredictor.get_predictions

    def __init__(self, config: DictConfig):
        # Same refinement attributes as DiffusionPredictor; the borrowed
        # diffusion_get_predictions reads them from self.
        self.refine = config.get("refine", False)
        self.refine_all = config.get("refine_all", True)
        self.refine_threshold = np.log(config.get("refine_threshold", 0.9))
        self.precursor_tolerance = config.get("filter_precursor_ppm", 50)
        super().__init__(config)

        # Manually prepare the diffusion model since it is not prepared in the parent class
        if self.refine:
            logger.info("Running in refinement mode.")
            self.diffusion_model: nn.Module = self.accelerator.prepare(self.diffusion_model)

    def load_model(self) -> Tuple[nn.Module, DictConfig]:
        """Setup the model.

        Loads the transformer model, and when ``refine`` is set additionally
        loads the diffusion model and validates that the two are compatible.

        Returns:
            The transformer model and its configuration (the diffusion model
            is stored on ``self``).

        Raises:
            ValueError: If the two models' residue sets differ.
        """
        self.transformer_model, transformer_model_config = super().load_model()

        if not self.refine:
            return self.transformer_model, transformer_model_config

        self.diffusion_model, diffusion_model_config = self.diffusion_load_model()  # type: ignore

        self.diffusion_residue_set = self.diffusion_model.residue_set

        # Compare residue sets
        transformer_residue_set = self.transformer_model.residue_set.index_to_residue
        diffusion_residue_set = self.diffusion_residue_set.index_to_residue
        if transformer_residue_set != diffusion_residue_set:
            raise ValueError("Transformer and diffusion residue sets do not match")

        # Compare max length: mismatched lengths are tolerated (predictions
        # are truncated/padded in _tokenize_and_pad) but worth warning about.
        self.transformer_max_length = transformer_model_config.get("max_length", 40)
        self.diffusion_max_length = diffusion_model_config.get("max_length", 40)
        if self.transformer_max_length != self.diffusion_max_length:
            logger.warning(
                f"Transformer and diffusion max length do not match. "
                f"Transformer: {self.transformer_max_length}, "
                f"Diffusion: {self.diffusion_max_length}"
            )

        return self.transformer_model, transformer_model_config

    def setup_decoder(self) -> Decoder:
        """Setup the decoder.

        Builds both decoders; get_predictions swaps ``self.decoder`` between
        them per batch.
        """
        # Diffusion decoder
        if self.refine:
            self.diffusion_decoder = DiffusionDecoder(model=self.diffusion_model)
        else:
            self.diffusion_decoder = None  # type: ignore

        # Transformer decoder
        self.transformer_decoder = super().setup_decoder()
        return self.transformer_decoder  # type: ignore

    def _tokenize_and_pad(self, refinement: list[str]) -> torch.Tensor:
        """Tokenize and pad the transformer predictions.

        Encodes each predicted sequence with the diffusion residue set,
        truncates to the diffusion max length, and pads with PAD_INDEX to the
        fixed length the diffusion model expects.
        """
        encodings = []
        for refine in refinement:
            refine_tokenized = self.diffusion_residue_set.tokenize(refine)

            refine_encoding = self.diffusion_residue_set.encode(refine_tokenized, add_eos=False, return_tensor="pt")

            refine_encoding = refine_encoding[: self.diffusion_max_length]

            # Diffusion always padded to fixed length
            # NOTE(review): after the slice above the encoding length is at
            # most diffusion_max_length, so max() here looks redundant.
            refine_padded = torch.full(
                (max(self.diffusion_max_length, refine_encoding.shape[0]),),
                fill_value=self.diffusion_residue_set.PAD_INDEX,
                dtype=refine_encoding.dtype,
                device=refine_encoding.device,
            )
            refine_padded[: refine_encoding.shape[0]] = refine_encoding

            encodings.append(refine_padded)

        encodings, _ = DiffusionDataProcessor._pad_and_mask(encodings)
        return encodings

    def get_predictions(self, batch: Any) -> dict[str, Any]:
        """Get the predictions for a batch.

        Runs the transformer first; when refining, feeds its predictions into
        the diffusion model and merges both result sets, prefixing the raw
        per-model outputs with ``instanovo_`` / ``instanovoplus_``.
        """
        # set self.model to use the correct model
        self.decoder = self.transformer_decoder
        transformer_predictions = super().get_predictions(batch)

        if not self.refine:
            return transformer_predictions  # type: ignore

        # Inject the transformer outputs into the batch in the format the
        # diffusion predictor expects for refinement.
        batch[REFINEMENT_COLUMN] = self._tokenize_and_pad(transformer_predictions["predictions"]).to(self.accelerator.device)
        batch[REFINEMENT_PROBABILITY_COLUMN] = torch.tensor(transformer_predictions["prediction_log_probability"]).to(self.accelerator.device)

        self.decoder = self.diffusion_decoder  # type: ignore
        diffusion_predictions = self.diffusion_get_predictions(batch)  # type: ignore

        # Diffusion (refined) results are the primary outputs.
        predictions = {
            "predictions": diffusion_predictions["predictions"],
            "prediction_log_probability": diffusion_predictions["prediction_log_probability"],
            "prediction_token_log_probabilities": diffusion_predictions["prediction_token_log_probabilities"],
            "targets": transformer_predictions["targets"],
        }

        # Don't need to keep these
        transformer_predictions.pop("targets")
        diffusion_predictions.pop("targets")

        predictions.update({f"instanovo_{k}": v for k, v in transformer_predictions.items()})
        predictions.update({f"instanovoplus_{k}": v for k, v in diffusion_predictions.items()})

        return predictions