Coverage for instanovo/inference/knapsack_beam_search.py: 86%
247 statements
coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
1from __future__ import annotations
3from typing import Any, Literal
5import torch
6from jaxtyping import Float
8from instanovo.__init__ import console
9from instanovo.constants import CARBON_MASS_DELTA, H2O_MASS, PrecursorDimension
10from instanovo.inference.interfaces import Decodable, Decoder
11from instanovo.inference.knapsack import Knapsack
12from instanovo.types import PrecursorFeatures, Spectrum
13from instanovo.utils.colorlogging import ColorLog
# Module-level logger built by ColorLog from the package-level `console`;
# used below to warn about suppressed residues missing from the vocabulary.
logger = ColorLog(console, __name__).logger
class KnapsackBeamSearchDecoder(Decoder):
    """A class for decoding from de novo sequence models using beam search.

    This class conforms to the `Decoder` interface and decodes from
    models that conform to the `Decodable` interface. At every step the
    candidate residues are restricted so that the remaining precursor mass
    budget stays reachable according to a precomputed `Knapsack` chart.
    """

    def __init__(
        self,
        model: Decodable,
        knapsack: Knapsack,
        suppressed_residues: list[str] | None = None,
        disable_terminal_residues_anywhere: bool = True,
        keep_invalid_mass_sequences: bool = True,
        float_dtype: torch.dtype = torch.float64,
    ):
        """Initialise the decoder.

        Args:
            model (Decodable): The model to decode from. Must expose a ``residue_set``.
            knapsack (Knapsack): Precomputed knapsack whose ``chart`` and ``mass_scale``
                are used to filter candidate residues.
            suppressed_residues (list[str] | None): Residues that may never be predicted.
                The caller's list is not modified.
            disable_terminal_residues_anywhere (bool): If True, tokens that do not start
                with a letter are only allowed when adding them would complete the
                precursor mass (within tolerance, including isotopes).
            keep_invalid_mass_sequences (bool): If True, beams with no valid residue may
                still emit EOS, and all final beams are kept as candidates even when they
                do not match the precursor mass.
            float_dtype (torch.dtype): Floating point dtype used for mass bookkeeping.

        Raises:
            AttributeError: If ``model`` has no ``residue_set`` attribute.
        """
        super().__init__(model=model)
        self.knapsack = knapsack
        self.chart = torch.tensor(self.knapsack.chart)
        self.mass_scale = knapsack.mass_scale
        self.disable_terminal_residues_anywhere = disable_terminal_residues_anywhere
        self.keep_invalid_mass_sequences = keep_invalid_mass_sequences
        self.float_dtype = float_dtype

        # Work on a copy: matched residues are removed below so only unknown ones
        # remain for the warning, and the caller's list must not be mutated.
        unmatched_suppressed = list(suppressed_residues or [])

        # NOTE: Beam search requires the `residue_set` attribute on the model,
        # update all methods accordingly.
        if not hasattr(model, "residue_set"):
            raise AttributeError("The model is missing the required attribute: residue_set")

        # TODO: Check if this can be replaced with model.get_residue_masses(mass_scale=10000)/10000
        # We would need to divide the scaled masses as we use floating point masses.
        # These residue masses are per amino acid and include special tokens,
        # special tokens have a mass of 0.
        self.residue_masses = torch.zeros((len(self.model.residue_set),), dtype=self.float_dtype)
        terminal_residues_idx: list[int] = []
        suppressed_residues_idx: list[int] = []

        # residue_target_offsets supports negative masses (overshoot the remaining mass).
        # This fixes a bug where the residue prior to a negative mass residue is always invalid.
        residue_target_offsets: list[float] = [0.0]

        for i, residue in enumerate(model.residue_set.vocab):
            if residue in self.model.residue_set.special_tokens:
                continue
            self.residue_masses[i] = self.model.residue_set.get_mass(residue)
            # If no residue letter leads the token, assume it is an n-terminal modification
            if not residue[0].isalpha():
                terminal_residues_idx.append(i)
            if self.residue_masses[i] < 0:
                residue_target_offsets.append(float(self.residue_masses[i]))

            # Check if residue is suppressed
            if residue in unmatched_suppressed:
                suppressed_residues_idx.append(i)
                unmatched_suppressed.remove(residue)

        if len(unmatched_suppressed) > 0:
            logger.warning(f"Some suppressed residues not found in vocabulary: {unmatched_suppressed}")

        self.terminal_residue_indices = torch.tensor(terminal_residues_idx, dtype=torch.long)
        self.suppressed_residue_indices = torch.tensor(suppressed_residues_idx, dtype=torch.long)
        self.residue_target_offsets = torch.tensor(residue_target_offsets, dtype=self.float_dtype)

        self.vocab_size = len(self.model.residue_set)

    @classmethod
    def from_file(
        cls,
        model: Decodable,
        path: str,
        float_dtype: torch.dtype = torch.float64,
        **kwargs: Any,
    ) -> KnapsackBeamSearchDecoder:
        """Initialize a decoder by loading a saved knapsack.

        Args:
            model (Decodable): The model to be decoded from.
            path (str): The path to the directory where the knapsack
                was saved to.
            float_dtype (torch.dtype): The floating point dtype to use.
            **kwargs: Further keyword arguments forwarded to the constructor
                (e.g. ``suppressed_residues`` or ``keep_invalid_mass_sequences``).

        Returns:
            KnapsackBeamSearchDecoder: The decoder.
        """
        knapsack = Knapsack.from_file(path=path)
        return cls(model=model, knapsack=knapsack, float_dtype=float_dtype, **kwargs)

    def _filter_knapsack(
        self,
        remaining_mass: torch.Tensor,
        mass_target_delta: torch.Tensor,
        mass_target_expanded: torch.Tensor,
    ) -> torch.Tensor:
        """Return which residues the knapsack chart allows for each beam.

        For each beam the window ``[remaining_mass - delta, remaining_mass + delta]``
        is converted to chart row indices (scaled by ``self.mass_scale``); a residue
        is allowed when its chart column has any truthy entry inside that window.

        Args:
            remaining_mass: Remaining mass budget per beam, shape ``(batch,)``.
            mass_target_delta: Mass tolerance per beam, shape ``(batch,)``.
            mass_target_expanded: Unused; accepted for call compatibility.

        Returns:
            Tensor of shape ``(batch, chart vocab size)``, True where allowed.
        """
        self.chart = self.chart.to(remaining_mass.device)
        num_masses, vocab_size = self.chart.shape[0], self.chart.shape[1]
        max_index = num_masses - 1

        if remaining_mass.numel() == 0:
            return torch.zeros((0, vocab_size), dtype=torch.bool, device=self.chart.device)

        # Step 1: tolerance window around the remaining mass, in chart index units. Shape (batch, 1).
        lower_bound = ((remaining_mass - mass_target_delta).unsqueeze(1) * self.mass_scale).round().long()
        upper_bound = ((remaining_mass + mass_target_delta).unsqueeze(1) * self.mass_scale).round().long()

        # Step 2: clamp both bounds into the chart. A window lying entirely beyond the
        # chart collapses onto its last row; clamping lower_bound too keeps every span
        # length >= 1, so `torch.arange` below never receives a negative size.
        lower_bound = lower_bound.clamp(min=0, max=max_index)
        upper_bound = upper_bound.clamp(min=0, max=max_index)

        # Step 3: per-beam window widths and the widest one.
        span_lengths = (upper_bound - lower_bound + 1).unsqueeze(-1)  # (batch, 1, 1)
        max_span = max(int(span_lengths.max().item()), 1)

        # Step 4: candidate chart rows per beam, padded to `max_span`.
        span_offsets = torch.arange(max_span, device=self.chart.device).view(1, 1, -1)  # (1, 1, max_span)
        mass_indices = (lower_bound.unsqueeze(-1) + span_offsets).clamp(0, max_index)  # (batch, 1, max_span)

        # Step 5: mask out the padding beyond each beam's own window. The offset inside
        # the window (not the absolute chart index) is compared with the window length,
        # so narrower windows are not widened to `max_span`.
        span_mask = span_offsets < span_lengths  # (batch, 1, max_span)

        # Step 6: gather chart[mass, residue] for every residue column.
        residue_indices = torch.arange(vocab_size, device=self.chart.device).view(1, -1, 1)
        chart_vals = self.chart[mass_indices, residue_indices]  # (batch, vocab_size, max_span)

        # Step 7: a residue is allowed if any row inside its window permits it.
        return (chart_vals & span_mask).any(dim=-1)

    @staticmethod
    def _check_precursor_match(
        remaining_mass: torch.Tensor,
        mass_target_delta: torch.Tensor,
        max_isotope: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return which beams match the precursor and the isotope-corrected remaining mass.

        A beam matches when ``remaining_mass - j * CARBON_MASS_DELTA`` lies strictly
        inside ``(-mass_target_delta, mass_target_delta)`` for some ``j`` in
        ``0..max_isotope``. Beams matching with ``j > 0`` get that isotope offset
        subtracted from their remaining mass. The input tensor is not modified.

        Args:
            remaining_mass: Remaining mass budget per beam, shape ``(batch,)``.
            mass_target_delta: Mass tolerance per beam, shape ``(batch,)``.
            max_isotope: Highest isotope index checked (inclusive).

        Returns:
            ``(meets_precursor, corrected_remaining_mass)``, both shaped ``(batch,)``.
        """
        corrected = remaining_mass.clone()
        meets_precursor = torch.zeros(corrected.shape, device=corrected.device, dtype=torch.bool)
        # TODO: Use vectorized approach for this
        for j in range(max_isotope + 1):
            isotope = CARBON_MASS_DELTA * j
            within_range = (corrected - isotope < mass_target_delta) & (corrected - isotope > -mass_target_delta)
            meets_precursor = meets_precursor | within_range
            if j > 0 and within_range.any():
                # The sequence matched an isotope peak; shift its remaining mass accordingly.
                corrected[within_range] = corrected[within_range] - isotope
        return meets_precursor, corrected

    @staticmethod
    def _sequence_key(spectrum_index: int, token_ids: list[int]) -> str:
        """Return the ``"<spectrum>-<id>.<id>..."`` key under which token log probabilities are stored."""
        return str(spectrum_index) + "-" + ".".join(str(token_id) for token_id in token_ids)

    def _record_completed_beam(
        self,
        completed_beams: list[list[dict[str, Any]]],
        beam_index: int,
        beam_size: int,
        sequences: torch.Tensor,
        log_probabilities: torch.Tensor,
        remaining_mass: torch.Tensor,
        meets_precursor: torch.Tensor,
        token_log_probabilities: dict[str, list[float]],
    ) -> None:
        """Append beam ``beam_index`` to its spectrum's completed list unless already present.

        Duplicates are detected by comparing the decoded, joined residue strings.
        """
        spectrum_index = beam_index // beam_size
        token_ids = sequences[beam_index]
        sequence_key = self._sequence_key(spectrum_index, token_ids.cpu().tolist())
        prediction = self.model.decode(token_ids)
        prediction_str = "".join(prediction)
        if any("".join(beam["predictions"]) == prediction_str for beam in completed_beams[spectrum_index]):
            return
        completed_beams[spectrum_index].append(
            {
                "predictions": prediction,
                "mass_error": remaining_mass[beam_index].item(),
                "meets_precursor": meets_precursor[beam_index].item(),
                "prediction_log_probability": log_probabilities[beam_index, 0].item(),
                # Stored in decoding order; truncate to the decoded residues and reverse.
                "prediction_token_log_probabilities": token_log_probabilities.get(sequence_key, [])[: len(prediction)][::-1],
            }
        )

    @staticmethod
    def _reduce_encoder_output(
        spectrum_encoding: torch.Tensor,
        spectrum_mask: torch.Tensor,
        encoder_output_reduction: str,
    ) -> list[Any]:
        """Reduce the encoder output over its sequence dimension, ignoring masked positions.

        Raises:
            NotImplementedError: For ``"full"``.
            ValueError: For an unknown reduction.
        """
        encoder_output = spectrum_encoding.float().cpu()
        # 1 for valid positions, 0 where `spectrum_mask` is set.
        encoder_mask = (1 - spectrum_mask.float()).cpu()
        encoder_output = encoder_output * encoder_mask.unsqueeze(-1)
        if encoder_output_reduction == "mean":
            count = encoder_mask.sum(dim=1).unsqueeze(-1).clamp(min=1)
            encoder_output = encoder_output.sum(dim=1) / count
        elif encoder_output_reduction == "max":
            # Exclude masked positions via the mask itself, so genuine zero activations still count.
            encoder_output = encoder_output.masked_fill(encoder_mask.unsqueeze(-1) == 0, -float("inf"))
            encoder_output = encoder_output.max(dim=1)[0]
        elif encoder_output_reduction == "sum":
            encoder_output = encoder_output.sum(dim=1)
        elif encoder_output_reduction == "full":
            raise NotImplementedError("Full encoder output reduction is not yet implemented")
        else:
            raise ValueError(f"Invalid encoder output reduction: {encoder_output_reduction}")
        return list(encoder_output.numpy())

    def decode(  # type:ignore
        self,
        spectra: Float[Spectrum, " batch"],
        precursors: Float[PrecursorFeatures, " batch"],
        beam_size: int,
        max_length: int,
        mass_tolerance: float = 5e-5,
        max_isotope: int = 1,
        min_log_prob: float = -float("inf"),
        return_encoder_output: bool = False,
        encoder_output_reduction: Literal["mean", "max", "sum", "full"] = "mean",
        return_beam: bool = False,
        **kwargs,
    ) -> dict[str, Any]:
        """Decode predicted residue sequences for a batch of spectra using knapsack beam search.

        Args:
            spectra (torch.FloatTensor):
                The spectra to be sequenced.

            precursors (torch.FloatTensor[batch size, 3]):
                The precursor mass, charge and mass-to-charge ratio.

            beam_size (int):
                The number of beams kept per spectrum; also the number of
                ranked candidates returned per spectrum when `return_beam` is set.

            max_length (int):
                The maximum number of decoding steps (tokens) per sequence.

            mass_tolerance (float):
                The maximum relative error for which a predicted sequence
                is still considered to have matched the precursor mass.

            max_isotope (int):
                The maximum number of additional neutrons for isotopes
                whose mass a predicted sequence's mass is considered
                when comparing to the precursor mass.

                All additional nucleon numbers from 1 to `max_isotope` inclusive
                are considered.

            min_log_prob (float):
                Minimum log probability to stop decoding early. If a sequence
                log probability is less than this value it is marked as complete.
                Defaults to -inf.

            return_encoder_output (bool):
                If True, add the reduced encoder output per spectrum under
                ``"encoder_output"``.

            encoder_output_reduction (str):
                Reduction over unmasked encoder positions: ``"mean"``, ``"max"``
                or ``"sum"``. ``"full"`` is not implemented.

            return_beam (bool):
                Optionally return all ranked beam-search candidates per spectrum.

            **kwargs: Ignored.

        Returns:
            dict[str, Any]:
                ``"predictions"`` (best candidate's residue token list per spectrum),
                ``"prediction_log_probability"`` and ``"prediction_token_log_probabilities"``.
                With `return_beam`, also ``"predictions_beam_{k}"`` (joined string),
                ``"predictions_log_probability_beam_{k}"`` and
                ``"predictions_token_log_probabilities_beam_{k}"`` for ``k < beam_size``.
                With `return_encoder_output`, also ``"encoder_output"``.
                Candidates are ranked by precursor match first, then log probability.
                Missing candidates are an empty prediction with log probability -inf.

        Raises:
            NotImplementedError: If `return_encoder_output` and the reduction is ``"full"``.
            ValueError: If `return_encoder_output` and the reduction is unknown.
        """
        batch_size = spectra.shape[0]
        effective_batch_size = batch_size * beam_size
        device = spectra.device
        residue_set = self.model.residue_set

        # Masses of all residues in vocabulary, 0 for special tokens
        self.residue_masses = self.residue_masses.to(device)  # float_dtype (vocab_size,)

        # ppm equivalent of mass tolerance
        delta_ppm_tol = mass_tolerance * 10**6

        # Residue masses broadcast over every beam, so the remaining mass after each
        # candidate token can be computed in one operation.
        residue_mass_delta = self.residue_masses.expand(effective_batch_size, self.residue_masses.shape[0])

        completed_beams: list[list[dict[str, Any]]] = [[] for _ in range(batch_size)]

        with torch.no_grad():
            # 1. Encode the spectra once and repeat the encoding for every beam.
            (spectrum_encoding, spectrum_mask), _ = self.model.init(spectra, precursors)
            spectrum_encoding_expanded = spectrum_encoding.repeat_interleave(beam_size, dim=0)
            spectrum_mask_expanded = spectrum_mask.repeat_interleave(beam_size, dim=0)

            # 2. Initialise beam state.
            # Decoded token ids; grows by one column per decoding step.
            sequences = torch.zeros((effective_batch_size, 0), device=device, dtype=torch.long)
            # Accumulated log probability per beam, float32 (effective_batch_size, 1).
            log_probabilities = torch.zeros((effective_batch_size, 1), device=device, dtype=torch.float32)
            complete_beams = torch.zeros((effective_batch_size,), device=device, dtype=torch.bool)
            # Beams that completed in the previous step and still need recording.
            is_first_complete = torch.zeros((effective_batch_size,), device=device, dtype=torch.bool)

            precursors_expanded = precursors.repeat_interleave(beam_size, dim=0)
            precursor_mass = precursors_expanded[:, PrecursorDimension.PRECURSOR_MASS.value]

            # Target mass delta: the remaining mass x must satisfy -target < x < target,
            # possibly after an isotope shift. Mass target = error_ppm * m_prec * 1e-6.
            mass_target_delta = delta_ppm_tol * precursor_mass.to(self.float_dtype) * 1e-6

            # Remaining mass budget per beam, starting at the precursor mass minus H2O.
            remaining_mass = precursor_mass.to(self.float_dtype) - H2O_MASS

            # Beam slots used to keep beams in place when every candidate is -inf.
            constant_beam_indices = torch.arange(beam_size, device=device)[None, :].repeat_interleave(batch_size, dim=0)
            batch_offsets = torch.arange(batch_size, device=device)[:, None] * beam_size

            # Per-token log probabilities keyed by `_sequence_key`.
            token_log_probabilities: dict[str, list[float]] = {}

            for i in range(max_length):
                # If all beams are complete, we can stop early.
                if complete_beams.all():
                    break

                # Step 3: score the next tokens. The SOS token is prepended by `score_candidates`.
                next_token_probabilities = self.model.score_candidates(
                    sequences, precursors_expanded, spectrum_encoding_expanded, spectrum_mask_expanded
                )

                # Step 4.1: which beams already match the precursor mass (including isotopes).
                remaining_meets_precursor, remaining_mass = self._check_precursor_match(
                    remaining_mass, mass_target_delta, max_isotope
                )

                # Step 4.2: which residues keep the remaining mass valid.
                remaining_mass_expanded = remaining_mass[:, None].expand(effective_batch_size, self.vocab_size)
                mass_target_expanded = mass_target_delta[:, None].expand(effective_batch_size, self.vocab_size)
                mass_after_residue = remaining_mass_expanded - residue_mass_delta

                valid_mass = mass_after_residue > -mass_target_expanded
                # TODO: Use vectorized approach for this
                for mass_offset in self.residue_target_offsets:
                    for j in range(max_isotope + 1):
                        isotope = CARBON_MASS_DELTA * j
                        mass_lesser_isotope = mass_after_residue < mass_target_expanded + isotope + mass_offset
                        mass_greater_isotope = mass_after_residue > -mass_target_expanded + isotope + mass_offset
                        valid_mass = valid_mass | (mass_lesser_isotope & mass_greater_isotope)

                knapsack_valid_mass = self._filter_knapsack(
                    remaining_mass=remaining_mass,
                    mass_target_delta=mass_target_delta,
                    mass_target_expanded=mass_target_expanded,
                )
                valid_mass = valid_mass & knapsack_valid_mass

                # Filtered log probabilities: invalid masses and special/suppressed tokens get -inf.
                # EOS is re-enabled below where allowed; PAD is only produced by the -inf fallback.
                next_token_probabilities_filtered = next_token_probabilities.clone()
                next_token_probabilities_filtered[~valid_mass] = -float("inf")
                next_token_probabilities_filtered[:, residue_set.EOS_INDEX] = -float("inf")
                next_token_probabilities_filtered[:, residue_set.PAD_INDEX] = -float("inf")
                next_token_probabilities_filtered[:, residue_set.SOS_INDEX] = -float("inf")
                next_token_probabilities_filtered[:, self.suppressed_residue_indices] = -float("inf")

                if self.disable_terminal_residues_anywhere:
                    # Terminal residues are only allowed when they would complete the sequence:
                    # remaining - isotope - residue mass must fall within the tolerance.
                    isotope_offsets = CARBON_MASS_DELTA * torch.arange(max_isotope + 1, device=device)
                    remaining_mass_isotope = remaining_mass[:, None].expand(effective_batch_size, max_isotope + 1) - isotope_offsets
                    remaining_mass_isotope_delta = (
                        remaining_mass_isotope[:, :, None].expand(
                            effective_batch_size,
                            max_isotope + 1,
                            self.terminal_residue_indices.shape[0],
                        )
                        - self.residue_masses[self.terminal_residue_indices]
                    )
                    allow_terminal = (remaining_mass_isotope_delta.abs() < mass_target_delta[:, None, None]).any(dim=1)
                    allow_terminal_full = torch.ones((effective_batch_size, self.vocab_size), device=device, dtype=torch.bool)
                    allow_terminal_full[:, self.terminal_residue_indices] = allow_terminal
                    next_token_probabilities_filtered[~allow_terminal_full] = -float("inf")

                # Record beams that completed in the previous step.
                if is_first_complete.any():
                    for beam_index in is_first_complete.nonzero().squeeze(-1).tolist():
                        self._record_completed_beam(
                            completed_beams,
                            beam_index,
                            beam_size,
                            sequences,
                            log_probabilities,
                            remaining_mass,
                            remaining_meets_precursor,
                            token_log_probabilities,
                        )

                # Beams that already meet the precursor may only emit EOS.
                next_token_probabilities_filtered[remaining_meets_precursor, :] = -float("inf")
                if self.keep_invalid_mass_sequences:
                    # Also allow EOS on beams that have no valid continuation left.
                    next_beam_no_predictions = next_token_probabilities_filtered.isinf().all(-1)
                    allow_eos = (remaining_meets_precursor | next_beam_no_predictions) & ~complete_beams
                else:
                    allow_eos = remaining_meets_precursor & ~complete_beams
                next_eos_probs = next_token_probabilities[allow_eos, residue_set.EOS_INDEX]
                next_token_probabilities_filtered[allow_eos, residue_set.EOS_INDEX] = next_eos_probs

                # Step 5: select the top `beam_size` (beam, token) pairs per spectrum.
                log_probabilities_expanded = log_probabilities.repeat_interleave(self.vocab_size, dim=1)
                log_probabilities_expanded = log_probabilities_expanded + next_token_probabilities_filtered

                log_probabilities_beams = log_probabilities_expanded.view(-1, beam_size, self.vocab_size)
                if i == 0 and beam_size > 1:
                    # All beams are identical at the start; keep only the first to force divergence.
                    log_probabilities_beams[:, 1:] = -float("inf")
                log_probabilities_beams = log_probabilities_beams.view(-1, beam_size * self.vocab_size)

                topk_values, topk_indices = log_probabilities_beams.topk(beam_size, dim=-1)
                topk_is_inf = topk_values.isinf()
                beam_indices = topk_indices // self.vocab_size
                # Retain beams on failed decoding (when all candidates are -inf) and pad them.
                beam_indices[topk_is_inf] = constant_beam_indices[topk_is_inf]
                beam_indices_full = (beam_indices + batch_offsets).view(-1)

                next_token = topk_indices % self.vocab_size
                next_token[topk_is_inf] = residue_set.PAD_INDEX
                next_token = next_token.view(-1, 1)  # long (effective_batch_size, 1)

                # Reorder beam state by the selected parents.
                sequences = sequences[beam_indices_full]
                log_probabilities = log_probabilities[beam_indices_full]
                next_token_probabilities = next_token_probabilities[beam_indices_full]
                remaining_mass = remaining_mass[beam_indices_full]
                complete_beams = complete_beams[beam_indices_full]

                sequences = torch.cat([sequences, next_token], dim=1)
                remaining_mass = remaining_mass - self.residue_masses[next_token[:, 0]]

                # PAD tokens and already-complete beams add no log probability.
                next_token_probabilities[:, residue_set.PAD_INDEX] = 0
                next_probabilities = torch.gather(next_token_probabilities, 1, next_token)
                next_probabilities[complete_beams] = 0
                log_probabilities = log_probabilities + next_probabilities

                # Store per-token log probabilities keyed by sequence; one host transfer per step.
                sequence_rows = sequences.cpu().tolist()
                step_probabilities = next_probabilities[:, 0].float().cpu().tolist()
                for batch_index, (token_ids, step_probability) in enumerate(zip(sequence_rows, step_probabilities)):
                    spectrum_index = batch_index // beam_size
                    sequence_key = self._sequence_key(spectrum_index, token_ids)
                    previous_key = self._sequence_key(spectrum_index, token_ids[:-1])
                    history = list(token_log_probabilities.get(previous_key, []))
                    history.append(step_probability)
                    token_log_probabilities[sequence_key] = history

                # Step 6: terminate complete beams (EOS emitted or log probability below threshold).
                beam_confidence_filter = log_probabilities[:, 0] < min_log_prob
                next_token_is_eos = next_token[:, 0] == residue_set.EOS_INDEX
                next_token_is_pad = next_token[:, 0] == residue_set.PAD_INDEX
                next_is_complete = next_token_is_eos | beam_confidence_filter

                complete_beams = complete_beams | next_is_complete
                is_first_complete = next_is_complete

                if next_token_is_pad.all():
                    break

            # Refresh the precursor match for the final beam state: the in-loop flags were
            # computed before the last token was appended (and do not exist if max_length == 0).
            remaining_meets_precursor, remaining_mass = self._check_precursor_match(
                remaining_mass, mass_target_delta, max_isotope
            )

            # Record beams that completed in the final step.
            if is_first_complete.any():
                for beam_index in is_first_complete.nonzero().squeeze(-1).tolist():
                    self._record_completed_beam(
                        completed_beams,
                        beam_index,
                        beam_size,
                        sequences,
                        log_probabilities,
                        remaining_mass,
                        remaining_meets_precursor,
                        token_log_probabilities,
                    )

            # Optionally add every remaining beam, complete or not.
            if self.keep_invalid_mass_sequences:
                for beam_index in range(effective_batch_size):
                    self._record_completed_beam(
                        completed_beams,
                        beam_index,
                        beam_size,
                        sequences,
                        log_probabilities,
                        remaining_mass,
                        remaining_meets_precursor,
                        token_log_probabilities,
                    )

        # Top `beam_size` candidates per spectrum, ranked by precursor match then log probability.
        top_completed_beams = self._get_top_n_beams(completed_beams, beam_size)

        result: dict[str, Any] = {
            "predictions": [],
            "prediction_log_probability": [],
            "prediction_token_log_probabilities": [],
        }
        if return_beam:
            for beam_rank in range(beam_size):
                result[f"predictions_beam_{beam_rank}"] = []
                result[f"predictions_log_probability_beam_{beam_rank}"] = []
                result[f"predictions_token_log_probabilities_beam_{beam_rank}"] = []

        for spectrum_beams in top_completed_beams:
            if return_beam:
                for beam_rank, beam in enumerate(spectrum_beams):
                    result[f"predictions_beam_{beam_rank}"].append("".join(beam["predictions"]))
                    result[f"predictions_log_probability_beam_{beam_rank}"].append(beam["prediction_log_probability"])
                    result[f"predictions_token_log_probabilities_beam_{beam_rank}"].append(
                        beam["prediction_token_log_probabilities"]
                    )

            # Best beam is the main result.
            best_beam = spectrum_beams[0]
            result["predictions"].append(best_beam["predictions"])
            result["prediction_log_probability"].append(best_beam["prediction_log_probability"])
            result["prediction_token_log_probabilities"].append(best_beam["prediction_token_log_probabilities"])

        if return_encoder_output:
            result["encoder_output"] = self._reduce_encoder_output(spectrum_encoding, spectrum_mask, encoder_output_reduction)

        return result

    def _get_top_n_beams(self, completed_beams: list[list[dict[str, Any]]], beam_size: int) -> list[list[dict[str, Any]]]:
        """Get the top n beams from the completed beams.

        Args:
            completed_beams: The completed beams to get the top n beams from.
                Each beam is a dictionary with the following keys:
                - predictions: The predictions of the beam.
                - mass_error: The mass error of the beam.
                - meets_precursor: Whether the beam meets the precursor mass.
                - prediction_log_probability: The log probability of the beam.
                - prediction_token_log_probabilities: The log probabilities of the tokens in the beam.
            beam_size: The number of beams to keep per batch.

        Returns:
            A list of lists, each containing exactly `beam_size` beams for a batch entry.
            Missing slots are filled with fresh placeholder beams (empty prediction,
            -inf log probability).
        """
        top_beams_per_row = []
        for beams in completed_beams:
            # Precursor-matching beams first, then by descending log probability.
            ranked = sorted(
                beams,
                key=lambda beam: (beam["meets_precursor"], beam["prediction_log_probability"]),
                reverse=True,
            )
            top_beams = ranked[:beam_size]

            # Pad with independent placeholder dicts so no list objects are shared.
            while len(top_beams) < beam_size:
                top_beams.append(
                    {
                        "predictions": [],
                        "mass_error": -float("inf"),
                        "prediction_log_probability": -float("inf"),
                        "prediction_token_log_probabilities": [],
                    }
                )

            top_beams_per_row.append(top_beams)

        return top_beams_per_row