Coverage for instanovo/inference/diffusion.py: 74%
101 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
1from __future__ import annotations
3from typing import Any, Literal, Optional
5import torch
6from jaxtyping import Bool, Float, Integer
7from torch.distributions import Categorical
9from instanovo.constants import CARBON_MASS_DELTA, DIFFUSION_EVAL_STEPS, DIFFUSION_START_STEP, H2O_MASS, PrecursorDimension
10from instanovo.diffusion.multinomial_diffusion import DiffusionLoss, InstaNovoPlus
11from instanovo.inference.interfaces import Decoder
12from instanovo.types import Peptide, PrecursorFeatures, Spectrum, SpectrumMask
class DiffusionDecoder(Decoder):
    """Class for decoding from a diffusion model by forward sampling."""

    def __init__(self, model: InstaNovoPlus) -> None:
        """Wrap a diffusion model and cache the attributes used during decoding.

        Args:
            model:
                The diffusion model to decode from. Its `time_steps` and `residue_set`
                are copied onto the decoder, and a `DiffusionLoss` is built around it
                to score sampled sequences.
        """
        super().__init__(model=model)
        # Override base class type annotation - this is actually InstaNovoPlus, not just Decodable
        self.model: InstaNovoPlus = model

        self.time_steps = self.model.time_steps
        self.residue_set = self.model.residue_set
        self.loss_function = DiffusionLoss(model=self.model)

    def decode(
        self,
        spectra: Float[Spectrum, " batch"],
        precursors: Float[PrecursorFeatures, " batch"],
        spectra_padding_mask: Bool[SpectrumMask, " batch"],
        initial_sequence: Optional[Integer[Peptide, " batch"]] = None,
        start_step: int = DIFFUSION_START_STEP,
        eval_steps: tuple[int, ...] = DIFFUSION_EVAL_STEPS,
        beam_size: int = 1,
        mass_tolerance: float = 5e-5,
        max_isotope: int = 1,
        return_encoder_output: bool = False,
        encoder_output_reduction: Literal["mean", "max", "sum", "full"] = "mean",
        return_beam: bool = False,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Decoding predictions from a diffusion model by forward sampling.

        Args:
            spectra:
                A batch of spectra to be decoded.

            precursors:
                Precursor mass, charge and m/z for a batch of spectra.

            spectra_padding_mask:
                Padding mask for a batch of variable length spectra.

            initial_sequence:
                An initial sequence for the model to refine. If no initial sequence is
                provided (the value is None), will sample a random sequence from a uniform unigram
                model. Defaults to None.

            start_step:
                The step at which to insert the initial sequence and start refinement. If
                `initial_sequence` is not provided, this will be set to `time_steps - 1`.

            eval_steps:
                The steps at which to evaluate the loss and compute the log-probabilities.

            beam_size:
                Number of independent samples drawn per spectrum. Every input is repeated
                `beam_size` times along the batch dimension before sampling.

            mass_tolerance:
                Relative tolerance used for the precursor check: a candidate passes when its
                residual mass is strictly within `mass_tolerance * precursor_mass` of an
                allowed isotope offset.

            max_isotope:
                Largest isotope index to consider. Offsets `0 .. max_isotope` times
                `CARBON_MASS_DELTA` are tested.

            return_encoder_output:
                Whether to return the encoder output.

            encoder_output_reduction:
                The reduction to apply to the encoder output.
                Valid values are "mean", "max", "sum", "full".
                Defaults to "mean".

            return_beam:
                Whether to additionally return every ranked sample under per-beam keys.

            **kwargs:
                Accepted for interface compatibility; not used by this method.

        Returns:
            dict[str, Any]:
                The decoded peptides and their log-probabilities for a batch of spectra.
                Keys always present (one entry per spectrum, taken from the best-ranked sample):
                    - "predictions": list[list[str]]
                    - "meets_precursor": whether the best sample matches the precursor mass
                    - "prediction_log_probability": list[float]
                    - "prediction_token_log_probabilities": list of `None` placeholders
                Optional keys:
                    - "encoder_output": reduced encoder output (if `return_encoder_output`)
                    - "predictions_beam_{i}", "prediction_log_probability_beam_{i}" and
                      "prediction_token_log_probabilities_beam_{i}" for each `i` in
                      `range(beam_size)` (if `return_beam`); the last is a list of `None`
                      placeholders.

        Raises:
            ValueError:
                If the encoder output cannot be read from the model cache, or if
                `encoder_output_reduction` is not a recognised value.
            NotImplementedError:
                If `encoder_output_reduction` is "full".
        """
        device = spectra.device
        sequence_length = self.model.config.max_length
        batch_size, num_classes = spectra.size(0), len(self.residue_set)
        effective_batch_size = batch_size * beam_size

        # Each spectrum gets `beam_size` consecutive rows, so row `idx` belongs to spectrum `idx // beam_size`.
        spectra_expanded = spectra.repeat_interleave(beam_size, dim=0)
        spectra_padding_mask_expanded = spectra_padding_mask.repeat_interleave(beam_size, dim=0)
        precursors_expanded = precursors.repeat_interleave(beam_size, dim=0)

        if initial_sequence is None:
            # Sample uniformly
            initial_distribution = Categorical(torch.ones(effective_batch_size, sequence_length, num_classes) / num_classes)
            sample = initial_distribution.sample().to(device)
            start_step = self.time_steps - 1
        else:
            sample = initial_sequence.repeat_interleave(beam_size, dim=0)

        # No peptide position is treated as padding during sampling or scoring.
        peptide_mask = torch.zeros(effective_batch_size, sequence_length, dtype=torch.bool, device=device)

        # Sample through reverse process
        for t in range(start_step, -1, -1):
            times = torch.full((effective_batch_size,), t, dtype=torch.long, device=device)
            distribution = Categorical(
                logits=self.model.reverse_distribution(
                    x_t=sample,
                    time=times,
                    spectra=spectra_expanded,
                    spectra_padding_mask=spectra_padding_mask_expanded,
                    precursors=precursors_expanded,
                    x_padding_mask=peptide_mask,
                )
            )
            sample = distribution.sample()

        # Calculate log-probabilities as average loss across `eval_steps`
        losses = []
        for t in eval_steps:
            times = torch.full((effective_batch_size,), t, dtype=torch.long, device=device)
            losses.append(
                self.loss_function._compute_loss(
                    x_0=sample,
                    t=times,
                    spectra=spectra_expanded,
                    spectra_padding_mask=spectra_padding_mask_expanded,
                    precursors=precursors_expanded,
                    x_padding_mask=peptide_mask,
                )
            )
        log_probs = (-torch.stack(losses).mean(dim=0).cpu()).tolist()
        sequences = self._extract_predictions(sample)

        # convert to batch_size, beam_size format
        completed_beams: list[list[dict[str, Any]]] = [[] for _ in range(batch_size)]

        for idx in range(effective_batch_size):
            batch_idx = idx // beam_size

            sequence = sequences[idx]
            sequence_mass = sum(self.model.residue_set.get_mass(residue) for residue in sequence)
            # Check precursor matching
            precursor_mass = precursors[batch_idx, PrecursorDimension.PRECURSOR_MASS.value]
            remaining_mass = precursor_mass - sequence_mass - H2O_MASS
            mass_target_delta = mass_tolerance * precursor_mass

            # Check if mass within range, including isotopes
            # NOTE(review): `precursor_mass` is indexed out of the `precursors` tensor, so the flag
            # below ends up as a 0-dim bool tensor rather than a Python bool - confirm callers expect this.
            remaining_meets_precursor = False
            for j in range(0, max_isotope + 1, 1):
                isotope = CARBON_MASS_DELTA * j
                remaining_lesser_isotope = remaining_mass - isotope < mass_target_delta
                remaining_greater_isotope = remaining_mass - isotope > -mass_target_delta
                remaining_meets_precursor = remaining_meets_precursor | (remaining_lesser_isotope & remaining_greater_isotope)

            completed_beams[batch_idx].append(
                {
                    "predictions": sequence,
                    "meets_precursor": remaining_meets_precursor,
                    "prediction_log_probability": log_probs[idx],
                    "prediction_token_log_probabilities": log_probs[idx],
                }
            )

        # Sort beams by meets_precursor and prediction_log_probability
        # Index 0 is the best beam
        for beams in completed_beams:
            beams.sort(key=lambda x: (x["meets_precursor"], x["prediction_log_probability"]), reverse=True)

        result: dict[str, Any] = {
            "predictions": [beams[0]["predictions"] for beams in completed_beams],
            "meets_precursor": [beams[0]["meets_precursor"] for beams in completed_beams],
            "prediction_log_probability": [beams[0]["prediction_log_probability"] for beams in completed_beams],
            "prediction_token_log_probabilities": [None] * batch_size,
        }
        if return_beam:
            for beam_idx in range(beam_size):
                result[f"predictions_beam_{beam_idx}"] = [beams[beam_idx]["predictions"] for beams in completed_beams]
                result[f"prediction_log_probability_beam_{beam_idx}"] = [
                    beams[beam_idx]["prediction_log_probability"] for beams in completed_beams
                ]
                # A plain list of placeholders, matching the non-beam key above
                # (a stray trailing comma previously wrapped this list in a 1-tuple).
                result[f"prediction_token_log_probabilities_beam_{beam_idx}"] = [None] * batch_size

        if return_encoder_output:
            # Extract encoder output from cache
            encoder_output = self.model.transition_model.cache_cond_emb
            encoder_mask = self.model.transition_model.cache_cond_padding_mask

            if encoder_output is None or encoder_mask is None:
                raise ValueError("Could not extract encoder output from model to return as encoder output.")

            # NOTE(review): the cache was presumably filled from the beam-expanded spectra, so with
            # beam_size > 1 this may hold batch_size * beam_size rows - confirm against the model.
            encoder_output = encoder_output.float().cpu()
            # Invert the padding mask so that kept positions are 1.0 and padded positions are 0.0.
            encoder_mask = (1 - encoder_mask.float()).cpu()
            encoder_output = encoder_output * encoder_mask.unsqueeze(-1)
            if encoder_output_reduction == "mean":
                # Clamp so a fully padded row divides by 1 instead of 0.
                count = encoder_mask.sum(dim=1).unsqueeze(-1).clamp(min=1)
                encoder_output = encoder_output.sum(dim=1) / count
            elif encoder_output_reduction == "max":
                # Exclude padded positions using the mask itself; testing the values for `== 0`
                # would also discard genuine zero activations at non-padded positions.
                padded_positions = encoder_mask.unsqueeze(-1) == 0
                encoder_output = encoder_output.masked_fill(padded_positions, -float("inf")).max(dim=1)[0]
            elif encoder_output_reduction == "sum":
                encoder_output = encoder_output.sum(dim=1)
            elif encoder_output_reduction == "full":
                raise NotImplementedError("Full encoder output reduction is not yet implemented")
            else:
                raise ValueError(f"Invalid encoder output reduction: {encoder_output_reduction}")
            result["encoder_output"] = list(encoder_output.numpy())

        return result

    def _extract_predictions(self, sample: Integer[Peptide, " batch"]) -> list[list[str]]:
        """Convert sampled token indices into decoded residue sequences.

        Each row of `sample` is cut at its first EOS token (when one is present) and the
        remaining indices are decoded with the residue set, without reversing.

        Args:
            sample:
                Sampled token indices, one row per sequence.

        Returns:
            list[list[str]]:
                The decoded sequence for every row of `sample`.
        """
        eos_index = self.residue_set.EOS_INDEX
        output = []
        for sequence in sample:
            tokens = sequence.tolist()
            # Keep only the tokens that precede the first EOS, if the sequence contains one.
            peptide = tokens[: tokens.index(eos_index)] if eos_index in tokens else tokens
            output.append(self.residue_set.decode(peptide, reverse=False))  # we do not reverse peptide for diffusion
        return output