Coverage for instanovo/transformer/predict.py: 90%
67 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
1import os
2from pathlib import Path
3from typing import Any, Tuple
5import torch
6import torch.nn as nn
7from omegaconf import DictConfig
9from instanovo.__init__ import console
10from instanovo.common import AccelerateDeNovoPredictor, DataProcessor
11from instanovo.constants import MASS_SCALE, MAX_MASS
12from instanovo.inference import (
13 BeamSearchDecoder,
14 Decoder,
15 GreedyDecoder,
16 Knapsack,
17 KnapsackBeamSearchDecoder,
18)
19from instanovo.transformer.data import TransformerDataProcessor
20from instanovo.transformer.model import InstaNovo
21from instanovo.utils.colorlogging import ColorLog
# Location of the bundled configuration directory: a "configs" folder two
# directory levels above this module file.
CONFIG_PATH = Path(__file__).parent.parent.joinpath("configs")

# Module-level logger, built by ColorLog on top of the shared package console.
logger = ColorLog(console, __name__).logger
class TransformerPredictor(AccelerateDeNovoPredictor):
    """Predictor for the InstaNovo model.

    Plugs the InstaNovo transformer into ``AccelerateDeNovoPredictor`` by
    supplying the model loader, the spectrum data processor, the decoding
    strategy and the per-batch prediction step.
    """

    def __init__(
        self,
        config: DictConfig,
    ) -> None:
        """Read decoding options from ``config`` and initialise the base predictor.

        Args:
            config: Prediction configuration. ``num_beams`` (default 1) and
                ``save_beams`` (default False) are read here; everything else
                is handled by the base class.
        """
        # Set before ``super().__init__`` -- presumably because the base
        # constructor invokes hooks (e.g. ``setup_decoder``) that read these.
        # NOTE(review): confirm against AccelerateDeNovoPredictor.
        self.num_beams = config.get("num_beams", 1)
        # TODO: Change this to a generic save all outputs function
        self.save_beams = config.get("save_beams", False)
        super().__init__(config)

    def load_model(self) -> Tuple[nn.Module, DictConfig]:
        """Setup the model.

        ``config.instanovo_model`` is resolved either as the name of a
        pretrained model listed by ``InstaNovo.get_pretrained()`` or as a
        checkpoint path resolved through ``self.s3``. When it is unset (or
        explicitly ``None``) the first pretrained model is used.

        Returns:
            The loaded model and its model configuration.

        Raises:
            ValueError: If no model is configured and no pretrained models
                are available to fall back on.
            FileNotFoundError: If a checkpoint path cannot be resolved to a
                local file.
        """
        pretrained_models = InstaNovo.get_pretrained()

        # Resolve the default lazily so a custom model path never depends on
        # the pretrained list being non-empty.
        model_path = self.config.get("instanovo_model", None)
        if model_path is None:
            if not pretrained_models:
                raise ValueError("No InstaNovo model configured and no pretrained models are available.")
            model_path = pretrained_models[0]

        # With ``mps`` enabled, peak embeddings are forced to float32
        # (presumably because the MPS backend lacks float64 support).
        override_config = {"peak_embedding_dtype": "float32"} if self.config.get("mps", False) else None

        logger.info(f"Loading InstaNovo model {model_path}")
        if model_path in pretrained_models:
            # Using a pretrained model from models.json
            model, model_config = InstaNovo.from_pretrained(model_path, override_config=override_config)
            return model, model_config

        local_path = self.s3.get_local_path(model_path)
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if local_path is None:
            raise FileNotFoundError(f"Could not resolve InstaNovo model path: {model_path}")
        model, model_config = InstaNovo.load(local_path, override_config=override_config)
        return model, model_config

    def setup_data_processor(self) -> DataProcessor:
        """Setup the data processor.

        Spectrum preprocessing parameters are read from the model config,
        with defaults applied for missing keys. Peptide annotations are
        expected only when not running in de novo mode.
        """
        processor = TransformerDataProcessor(
            self.residue_set,
            n_peaks=self.model_config.get("n_peaks", 200),
            min_mz=self.model_config.get("min_mz", 50.0),
            max_mz=self.model_config.get("max_mz", 2500.0),
            min_intensity=self.model_config.get("min_intensity", 0.01),
            remove_precursor_tol=self.model_config.get("remove_precursor_tol", 2.0),
            return_str=False,
            use_spectrum_utils=False,
            annotated=not self.denovo,
            metadata_columns=["group"],
        )

        return processor

    def setup_decoder(self) -> Decoder:
        """Setup the decoder.

        Selection order:
          1. ``use_knapsack`` -> ``KnapsackBeamSearchDecoder``. The knapsack is
             loaded from ``knapsack_path`` if that path exists; otherwise it is
             generated and, when a path was given, saved there.
          2. ``num_beams > 1`` -> ``BeamSearchDecoder``.
          3. otherwise -> ``GreedyDecoder``.

        ``force_fp32`` selects float32 instead of the default float64.
        """
        float_dtype = torch.float32 if self.config.get("force_fp32", False) else torch.float64

        if self.config.get("use_knapsack", False):
            logger.info(f"Using Knapsack Beam Search with {self.num_beams} beam(s)")
            knapsack_path = self.config.get("knapsack_path", None)
            if knapsack_path is not None and os.path.exists(knapsack_path):
                logger.info("Knapsack path found. Loading...")
                return KnapsackBeamSearchDecoder.from_file(self.model, knapsack_path, float_dtype=float_dtype)

            logger.info("Knapsack path missing or not specified, generating...")
            knapsack = _setup_knapsack(self.model, self.config.get("max_isotope_error", 1))
            decoder: Decoder = KnapsackBeamSearchDecoder(self.model, knapsack, float_dtype=float_dtype)
            if knapsack_path is not None:
                # Persist the freshly built knapsack so later runs can load it.
                logger.info(f"Saving knapsack to {knapsack_path}")
                knapsack.save(knapsack_path)
            return decoder

        if self.num_beams > 1:
            logger.info(f"Using Beam Search with {self.num_beams} beam(s)")
            return BeamSearchDecoder(self.model, float_dtype=float_dtype)

        logger.info(f"Using Greedy Search with {self.num_beams} beam(s)")
        return GreedyDecoder(
            model=self.model,
            suppressed_residues=self.config.get("suppressed_residues", None),
            disable_terminal_residues_anywhere=self.config.get("disable_terminal_residues_anywhere", True),
            float_dtype=float_dtype,
        )

    def get_predictions(self, batch: Any) -> dict[str, Any]:
        """Get the predictions for a batch.

        Args:
            batch: Mapping with ``"spectra"`` (a tensor; its first dimension is
                used as the batch size), ``"precursors"`` and, for annotated
                data, ``"peptides"``.

        Returns:
            The decoder output dict, extended with a ``"targets"`` entry. It
            holds the decoded ground-truth peptides, or ``None`` per spectrum
            when the batch carries no ``"peptides"``.
        """
        batch_predictions: dict[str, Any] = self.decoder.decode(
            spectra=batch["spectra"],
            precursors=batch["precursors"],
            beam_size=self.num_beams,
            max_length=self.config.get("max_length", 40),
            return_beam=self.num_beams > 1,
            return_encoder_output=self.save_encoder_outputs,
            encoder_output_reduction=self.encoder_output_reduction,
        )

        if "peptides" in batch:
            # reverse=True presumably undoes a reversed token order used for
            # targets. NOTE(review): confirm against ResidueSet.decode.
            targets = [self.residue_set.decode(seq, reverse=True) for seq in batch["peptides"]]
        else:
            targets = [None] * batch["spectra"].size(0)

        batch_predictions["targets"] = targets

        return batch_predictions
def _setup_knapsack(model: InstaNovo, max_isotope: int = 2) -> Knapsack:
    """Build a knapsack over the model's residue vocabulary.

    The first three entries of ``model.residue_set.residue_to_index`` are
    treated as special tokens and given mass 0. All other residues keep their
    masses from ``model.residue_set.residue_masses``.

    Args:
        model: Model whose ``residue_set`` supplies residue masses and indices.
        max_isotope: Forwarded to ``Knapsack.construct_knapsack`` as
            ``max_isotope``.

    Returns:
        The knapsack built with ``MAX_MASS`` and ``MASS_SCALE``.
    """
    # NOTE(review): assumes the special (non-amino-acid) tokens occupy the
    # first three vocabulary slots -- confirm against ResidueSet.
    n_special_tokens = 3

    # ``dict(...)`` already makes a shallow copy; the original ``.copy()``
    # inside it was redundant.
    residue_masses = dict(model.residue_set.residue_masses)
    residue_indices = model.residue_set.residue_to_index
    for special_residue in list(residue_indices)[:n_special_tokens]:
        residue_masses[special_residue] = 0

    return Knapsack.construct_knapsack(
        residue_masses=residue_masses,
        residue_indices=residue_indices,
        max_mass=MAX_MASS,
        mass_scale=MASS_SCALE,
        max_isotope=max_isotope,
    )