Coverage for instanovo/inference/greedy_search.py: 85%
137 statements

from __future__ import annotations

from typing import Any, Literal

import torch
from jaxtyping import Float

from instanovo.__init__ import console
from instanovo.constants import CARBON_MASS_DELTA, H2O_MASS, MASS_SCALE, PrecursorDimension
from instanovo.inference.interfaces import Decodable, Decoder
from instanovo.types import PrecursorFeatures, Spectrum
from instanovo.utils.colorlogging import ColorLog

logger = ColorLog(console, __name__).logger


class GreedyDecoder(Decoder):
    """A class for decoding from de novo sequence models using greedy search.

    This class conforms to the `Decoder` interface and decodes from
    models that conform to the `Decodable` interface.
    """
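
    # A minimal usage sketch (hypothetical; `my_model` stands in for any
    # `Decodable` implementation exposing a `residue_set` attribute, and the
    # suppressed residue shown is an arbitrary example):
    #
    #   decoder = GreedyDecoder(model=my_model, suppressed_residues=["C[UNIMOD:4]"])
    #   out = decoder.decode(spectra, precursors, max_length=30)
    #   out["predictions"][0]  # e.g. ["P", "E", "P", "T", "I", "D", "E"]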

    def __init__(
        self,
        model: Decodable,
        suppressed_residues: list[str] | None = None,
        mass_scale: int = MASS_SCALE,
        disable_terminal_residues_anywhere: bool = True,
        float_dtype: torch.dtype = torch.float64,
    ):
        super().__init__(model=model)
        self.mass_scale = mass_scale
        self.disable_terminal_residues_anywhere = disable_terminal_residues_anywhere
        self.float_dtype = float_dtype

        suppressed_residues = suppressed_residues or []

        # NOTE: Greedy search requires the model to expose a `residue_set`
        # attribute; keep all methods consistent with this assumption.
        if not hasattr(model, "residue_set"):
            raise AttributeError("The model is missing the required attribute: residue_set")

        # TODO: Check if this can be replaced with model.get_residue_masses(mass_scale=10000)/10000
        # We would need to divide the scaled masses as we use floating point masses.
        # These residue masses are per amino acid and include special tokens;
        # special tokens have a mass of 0.
        self.residue_masses = torch.zeros((len(self.model.residue_set),), dtype=self.float_dtype)
        terminal_residues_idx: list[int] = []
        suppressed_residues_idx: list[int] = []

        # residue_target_offsets supports negative masses (which overshoot the remaining mass).
        # This fixes a bug where the residue prior to a negative-mass residue is always invalid.
        residue_target_offsets: list[float] = [0.0]

        for i, residue in enumerate(model.residue_set.vocab):
            if residue in self.model.residue_set.special_tokens:
                continue
            self.residue_masses[i] = self.model.residue_set.get_mass(residue)
            # If the token does not start with a letter (i.e. no residue is attached),
            # assume it is an N-terminal modification.
            if not residue[0].isalpha():
                terminal_residues_idx.append(i)
            if self.residue_masses[i] < 0:
                residue_target_offsets.append(self.residue_masses[i])

            # Check if residue is suppressed
            if residue in suppressed_residues:
                suppressed_residues_idx.append(i)
                suppressed_residues.remove(residue)

        if len(suppressed_residues) > 0:
            logger.warning(f"Some suppressed residues not found in vocabulary: {suppressed_residues}")

        self.terminal_residue_indices = torch.tensor(terminal_residues_idx, dtype=torch.long)
        self.suppressed_residue_indices = torch.tensor(suppressed_residues_idx, dtype=torch.long)
        self.residue_target_offsets = torch.tensor(residue_target_offsets, dtype=self.float_dtype)

        self.vocab_size = len(self.model.residue_set)
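
        # Illustrative example (hypothetical vocabulary): for
        # ["[PAD]", "[SOS]", "[EOS]", "G", "A", "[+42.01]"], residue_masses would be
        # roughly [0.0, 0.0, 0.0, 57.02146, 71.03711, 42.01057] (monoisotopic Da),
        # terminal_residue_indices would be [5] (the token starts with a non-letter),
        # and residue_target_offsets would stay [0.0] since no mass is negative.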

    def decode(  # type:ignore
        self,
        spectra: Float[Spectrum, " batch"],
        precursors: Float[PrecursorFeatures, " batch"],
        max_length: int,
        mass_tolerance: float = 5e-5,
        max_isotope: int = 1,
        min_log_prob: float = -float("inf"),
        return_encoder_output: bool = False,
        encoder_output_reduction: Literal["mean", "max", "sum", "full"] = "mean",
        **kwargs,
    ) -> dict[str, Any]:
        """Decode predicted residue sequence for a batch of spectra using greedy search.

        Args:
            spectra (torch.FloatTensor):
                The spectra to be sequenced.

            precursors (torch.FloatTensor[batch size, 3]):
                The precursor mass, charge and mass-to-charge ratio.

            max_length (int):
                The maximum length of a residue sequence.

            mass_tolerance (float):
                The maximum relative error for which a predicted sequence
                is still considered to have matched the precursor mass.

            max_isotope (int):
                The maximum number of additional neutrons to consider for
                isotope peaks when comparing a predicted sequence's mass
                to the precursor mass.

                All additional nucleon numbers from 1 to `max_isotope` inclusive
                are considered.

            min_log_prob (float):
                Minimum log probability threshold for early stopping. If a
                sequence's log probability falls below this value, it is
                marked as complete. Defaults to -inf.

            return_encoder_output:
                Whether to return the encoder output.

            encoder_output_reduction:
                The reduction to apply to the encoder output.
                Valid values are "mean", "max", "sum", "full".
                Defaults to "mean".

        Returns:
            dict[str, Any]:
                Required keys:
                - "predictions": list[list[str]]
                - "mass_error": list[float]
                - "prediction_log_probability": list[float]
                - "prediction_token_log_probabilities": list[list[float]]
                - "encoder_output": list[float] (optional)
                Example additional keys:
                - "prediction_beam_0": list[str]
        """
        # Greedy search with precursor mass termination condition
        batch_size = spectra.shape[0]
        device = spectra.device

        # Masses of all residues in vocabulary, 0 for special tokens
        self.residue_masses = self.residue_masses.to(spectra.device)  # float64 (vocab_size, )

        # ppm equivalent of mass tolerance
        delta_ppm_tol = mass_tolerance * 10**6  # float (1, )
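        # For example, the default mass_tolerance of 5e-5 corresponds to
        # 5e-5 * 1e6 = 50 ppm.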

        # Residue masses expanded (repeated) across batch_size.
        # This is used to quickly compute all possible remaining masses per vocab entry.
        residue_mass_delta = self.residue_masses.expand(batch_size, self.residue_masses.shape[0])  # float64 (batch_size, vocab_size)

        with torch.no_grad():
            # 1. Compute spectrum encoding and masks
            # Encoder is only run once.
            (spectrum_encoding, spectrum_mask), _ = self.model.init(spectra, precursors)

            # 2. Initialise beams and other variables
            # The sequences decoded so far; grows on index 1 with every decoding pass.
            # sequence_length is variable!
            sequences = torch.zeros((batch_size, 0), device=device, dtype=torch.long)  # long (batch_size, sequence_length)

            # Log probabilities of the sequences decoded so far;
            # token probabilities are added at each step.
            log_probabilities = torch.zeros((batch_size, 1), device=device, dtype=torch.float32)  # float32 (batch_size, 1)

            # Keeps track of which beams are completed; this allows the model to skip them.
            complete_beams = torch.zeros((batch_size), device=device, dtype=bool)  # bool (batch_size, )

            # Keeps track of which beams stopped early or terminated with a bad stop condition.
            # These predictions will be deleted.
            # bad_stop_condition =
            #     torch.zeros((batch_size), device=device, dtype=bool)  # bool (batch_size, )
            # Extract precursor mass from `precursors`
            precursor_mass = precursors[:, PrecursorDimension.PRECURSOR_MASS.value]  # float32 (batch_size, )

            # Target mass delta; remaining mass x must be within `target > x > -target`.
            # This target can shift with isotopes.
            # Mass targets = error_ppm * m_prec * 1e-6
            mass_target_delta = delta_ppm_tol * precursor_mass.to(self.float_dtype) * 1e-6  # float_dtype (batch_size, )
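            # e.g. at 50 ppm tolerance and a 1000 Da precursor:
            # 50 * 1000 * 1e-6 = 0.05 Da.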

            # This keeps track of the remaining mass budget for the currently decoding
            # sequence; it starts at the precursor mass minus H2O.
            remaining_mass = precursor_mass.to(self.float_dtype) - H2O_MASS  # float_dtype (batch_size, )
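            # (A peptide's monoisotopic mass is the sum of its residue masses plus one
            # water, ~18.01056 Da, for the terminal H and OH; subtracting H2O up front
            # lets the budget be compared directly against summed residue masses.)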

            # TODO: only check when close to precursor mass? Might not be worth the overhead.
            # Idea: if remaining < check_zone, we do the valid-mass and complete checks.
            # check_zone = self.residue_masses.max().expand(batch_size) + mass_target_delta

            # Store token probabilities
            token_log_probabilities = []  # list of float32 (batch_size, ) tensors, one per step

            # Start decoding
            for _ in range(max_length):
                # If all beams are complete, we can stop early.
                if complete_beams.all():
                    break

                # We only run the model on incomplete beams; note: we have to expand to
                # the full batch size afterwards.
                minibatch = (x[~complete_beams] for x in (sequences, precursors, spectrum_encoding, spectrum_mask))
                # Keep track of how large the minibatch is
                sub_batch_size = (~complete_beams).sum()

                # Step 3: Score the next tokens.
                # NOTE: The SOS token is appended automatically in `score_candidates`;
                # we do not have to add it.
                next_token_probabilities = self.model.score_candidates(*minibatch)

                # Step 4: Filter probabilities.
                # If the remaining mass is within tolerance, we force an EOS token.
                # All tokens that would push the remaining mass below the minimum
                # cutoff `-mass_target_delta` (including isotopes) are set to -inf.

                # Step 4.1: Check if remaining mass is within tolerance.
                # To keep it efficient we compute some of the indexed variables first:
                remaining_mass_incomplete = remaining_mass[~complete_beams]  # float64 (sub_batch_size, )
                mass_target_incomplete = mass_target_delta[~complete_beams]  # float64 (sub_batch_size, )

                # remaining_meets_precursor =
                #     (remaining_mass[~complete_beams] < mass_target_delta[~complete_beams])
                remaining_meets_precursor = torch.zeros((sub_batch_size,), device=device, dtype=bool)  # bool (sub_batch_size, )
                # This loop checks if the mass is within tolerance for 0 to max_isotope (inclusive).
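                # (CARBON_MASS_DELTA is presumably the 13C-12C mass difference,
                # ~1.00336 Da, so each step j tests whether the sequence matches a
                # j-neutron isotope peak.)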
                for j in range(0, max_isotope + 1, 1):
                    # TODO: Use a vectorized approach for this
                    isotope = CARBON_MASS_DELTA * j  # float
                    remaining_lesser_isotope = remaining_mass_incomplete - isotope < mass_target_incomplete  # bool (sub_batch_size, )
                    remaining_greater_isotope = remaining_mass_incomplete - isotope > -mass_target_incomplete  # bool (sub_batch_size, )

                    # remaining mass is within the target tolerance
                    remaining_within_range = remaining_lesser_isotope & remaining_greater_isotope  # bool (sub_batch_size, )
                    remaining_meets_precursor = remaining_meets_precursor | remaining_within_range  # bool (sub_batch_size, )
                    if remaining_within_range.any() and j > 0:
                        # If we did hit an isotope, correct the remaining mass accordingly
                        # TODO: check this
                        remaining_mass_incomplete[remaining_within_range] = remaining_mass_incomplete[remaining_within_range] - isotope

                # Step 4.2: Check which residues are valid
                # Expand incomplete remaining mass across vocabulary size
                remaining_mass_incomplete_expanded = remaining_mass_incomplete[:, None].expand(
                    sub_batch_size, self.vocab_size
                )  # float64 (sub_batch_size, vocab_size)
                mass_target_incomplete_expanded = mass_target_incomplete[:, None].expand(
                    sub_batch_size, self.vocab_size
                )  # float64 (sub_batch_size, vocab_size)
                residue_mass_delta_incomplete = residue_mass_delta[~complete_beams]  # float64 (sub_batch_size, vocab_size)

                valid_mass = (
                    remaining_mass_incomplete_expanded - residue_mass_delta_incomplete > -mass_target_incomplete_expanded
                )  # bool (sub_batch_size, vocab_size)
                # Check all isotopes for valid masses
                # TODO: Use a vectorized approach for this
                for mass_offset in self.residue_target_offsets:
                    for j in range(0, max_isotope + 1, 1):
                        isotope = CARBON_MASS_DELTA * j  # float
                        mass_lesser_isotope = (
                            remaining_mass_incomplete_expanded - residue_mass_delta_incomplete
                            < mass_target_incomplete_expanded + isotope + mass_offset
                        )  # bool (sub_batch_size, vocab_size)
                        mass_greater_isotope = (
                            remaining_mass_incomplete_expanded - residue_mass_delta_incomplete
                            > -mass_target_incomplete_expanded + isotope + mass_offset
                        )  # bool (sub_batch_size, vocab_size)
                        valid_mass = valid_mass | (mass_lesser_isotope & mass_greater_isotope)  # bool (sub_batch_size, vocab_size)

                # Filtered probabilities:
                next_token_probabilities_filtered = next_token_probabilities.clone()  # float32 (sub_batch_size, vocab_size)
                # If mass is invalid, set log_prob to -inf
                next_token_probabilities_filtered[~valid_mass] = -float("inf")
                next_token_probabilities_filtered[:, self.model.residue_set.EOS_INDEX] = -float("inf")
                # Allow the model to predict PAD when all residues are -inf
                # next_token_probabilities_filtered[
                #     :, self.model.residue_set.PAD_INDEX
                # ] = -float("inf")
                next_token_probabilities_filtered[:, self.model.residue_set.SOS_INDEX] = -float("inf")
                next_token_probabilities_filtered[:, self.suppressed_residue_indices] = -float("inf")
                # Set probability of N-terminal modifications to -inf when i > 0
                if self.disable_terminal_residues_anywhere:
                    # Check if adding terminal residues would result in a complete sequence.
                    # First generate the remaining mass matrix with isotopes.
                    remaining_mass_incomplete_isotope = remaining_mass_incomplete[:, None].expand(
                        sub_batch_size, max_isotope + 1
                    ) - CARBON_MASS_DELTA * (torch.arange(max_isotope + 1, device=device))
                    # Expand with terminal residues and subtract
                    remaining_mass_incomplete_isotope_delta = (
                        remaining_mass_incomplete_isotope[:, :, None].expand(
                            sub_batch_size,
                            max_isotope + 1,
                            self.terminal_residue_indices.shape[0],
                        )
                        - self.residue_masses[self.terminal_residue_indices]
                    )
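
                    # Resulting shape: (sub_batch_size, max_isotope + 1, n_terminal_residues);
                    # each entry is the remaining mass if that terminal residue were
                    # appended under that isotope assumption.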
                    # If within target delta, allow these residues to be predicted,
                    # otherwise set probability to -inf
                    allow_terminal = (remaining_mass_incomplete_isotope_delta.abs() < mass_target_incomplete[:, None, None]).any(dim=1)
                    allow_terminal_full = torch.ones(
                        (sub_batch_size, self.vocab_size),
                        device=spectra.device,
                        dtype=bool,
                    )
                    allow_terminal_full[:, self.terminal_residue_indices] = allow_terminal

                    # Set to -inf
                    next_token_probabilities_filtered[~allow_terminal_full] = -float("inf")

                # Step 5: Select next token:
                next_token = next_token_probabilities_filtered.argmax(-1).unsqueeze(1)  # long (sub_batch_size, 1)
                next_token[remaining_meets_precursor] = self.model.residue_set.EOS_INDEX

                # Update sequences
                next_token_full = torch.full(
                    (batch_size, 1),
                    fill_value=self.model.residue_set.PAD_INDEX,
                    device=spectra.device,
                    dtype=sequences.dtype,
                )  # long (batch_size, 1)
                next_token_full[~complete_beams] = next_token
                sequences = torch.concat([sequences, next_token_full], axis=1)  # long (batch_size, sequence_length + 1)

                # Expand and update masses
                next_masses = self.residue_masses[next_token].squeeze()  # float64 (sub_batch_size, )
                next_masses_full = torch.zeros((batch_size), device=spectra.device, dtype=remaining_mass.dtype)  # float64 (batch_size, )
                next_masses_full[~complete_beams] = next_masses
                remaining_mass = remaining_mass - next_masses_full  # float64 (batch_size, )
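                # (Completed beams receive a PAD token with zero mass, so their
                # remaining_mass is unchanged by this update.)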

                # Expand and update probabilities
                next_probabilities = torch.gather(next_token_probabilities, 1, next_token)
                next_probabilities_full = torch.zeros(
                    (batch_size, 1),
                    device=spectra.device,
                    dtype=log_probabilities.dtype,
                )
                next_probabilities_full[~complete_beams] = next_probabilities
                log_probabilities = log_probabilities + next_probabilities_full
                token_log_probabilities.append(next_probabilities_full[:, 0])

                # Step 6: Terminate complete beams

                # Check if complete:
                # Early stopping if beam log probability below threshold
                beam_confidence_filter = log_probabilities[~complete_beams, 0] < min_log_prob
                # Stop if beam is forced to output an EOS
                next_token_is_eos = next_token[:, 0] == self.model.residue_set.EOS_INDEX
                next_is_complete = next_token_is_eos | beam_confidence_filter

                # Check for a bad stop
                # bad_stop_condition = beam_confidence_filter
                # bad_stop_condition_full = torch.zeros((batch_size,),
                #     device=spectra.device, dtype=bad_stop_condition.dtype)
                # bad_stop_condition_full[~complete_beams] = bad_stop_condition
                # bad_stop_condition = bad_stop_condition | bad_stop_condition_full

                # Expand and update complete beams
                next_is_complete_full = torch.zeros((batch_size,), device=spectra.device, dtype=complete_beams.dtype)
                next_is_complete_full[~complete_beams] = next_is_complete
                complete_beams = complete_beams | next_is_complete_full
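                # complete_beams is monotonic: once a beam is marked complete it can
                # never become incomplete again, so the minibatch only ever shrinks.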

                # Repeat from step 3.

            all_log_probabilities = torch.stack(token_log_probabilities, axis=1)

            # Example of new output format
            result: dict[str, Any] = {
                "predictions": [],
                # "mass_error": [],
                "prediction_log_probability": [],
                "prediction_token_log_probabilities": [],
            }

            for i in range(batch_size):
                sequence = self.model.decode(sequences[i])
                result["predictions"].append(sequence)
                # result["mass_error"].append(remaining_mass[i].item())
                result["prediction_log_probability"].append(log_probabilities[i, 0].item())
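                # The per-token probabilities are reversed below, presumably to match
                # the residue order returned by `model.decode` (InstaNovo models decode
                # sequences in reverse).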
                result["prediction_token_log_probabilities"].append([x.cpu().item() for x in all_log_probabilities[i, : len(sequence)]][::-1])

            if return_encoder_output:
                # TODO: Refactor to use the same logic across all decoders
                # Reduce along sequence length dimension
                encoder_output = spectrum_encoding.float().cpu()
                encoder_mask = (1 - spectrum_mask.float()).cpu()
                encoder_output = encoder_output * encoder_mask.unsqueeze(-1)
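                # (The `1 - mask` flip assumes spectrum_mask is 1/True at padded
                # positions, so encoder_mask is 1 at valid peaks.)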
                if encoder_output_reduction == "mean":
                    count = encoder_mask.sum(dim=1).unsqueeze(-1).clamp(min=1)
                    encoder_output = encoder_output.sum(dim=1) / count
                elif encoder_output_reduction == "max":
                    encoder_output[encoder_output == 0] = -float("inf")
                    encoder_output = encoder_output.max(dim=1)[0]
                elif encoder_output_reduction == "sum":
                    encoder_output = encoder_output.sum(dim=1)
                elif encoder_output_reduction == "full":
                    raise NotImplementedError("Full encoder output reduction is not yet implemented")
                else:
                    raise ValueError(f"Invalid encoder output reduction: {encoder_output_reduction}")
                result["encoder_output"] = list(encoder_output.numpy())

            return result