Coverage for instanovo/inference/beam_search.py: 85%
220 statements
coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

from __future__ import annotations

from typing import Any, Literal

import torch
from jaxtyping import Float

from instanovo.__init__ import console
from instanovo.constants import CARBON_MASS_DELTA, H2O_MASS, MASS_SCALE, PrecursorDimension
from instanovo.inference.interfaces import Decodable, Decoder
from instanovo.types import PrecursorFeatures, Spectrum
from instanovo.utils.colorlogging import ColorLog

logger = ColorLog(console, __name__).logger


class BeamSearchDecoder(Decoder):
    """A class for decoding from de novo sequence models using beam search.

    This class conforms to the `Decoder` interface and decodes from
    models that conform to the `Decodable` interface.
    """
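
    # Illustrative usage (a minimal sketch; assumes `model` implements `Decodable`
    # and that `spectra` and `precursors` are already batched tensors on the
    # correct device):
    #
    #     decoder = BeamSearchDecoder(model=model)
    #     results = decoder.decode(spectra, precursors, beam_size=5, max_length=40)
    #     best_peptides = results["predictions"]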

    def __init__(
        self,
        model: Decodable,
        suppressed_residues: list[str] | None = None,
        mass_scale: int = MASS_SCALE,
        disable_terminal_residues_anywhere: bool = True,
        keep_invalid_mass_sequences: bool = True,
        float_dtype: torch.dtype = torch.float64,
    ):
        super().__init__(model=model)
        self.mass_scale = mass_scale
        self.disable_terminal_residues_anywhere = disable_terminal_residues_anywhere
        self.keep_invalid_mass_sequences = keep_invalid_mass_sequences
        self.float_dtype = float_dtype

        suppressed_residues = suppressed_residues or []

        # NOTE: Beam search requires the `residue_set` class in the model,
        # update all methods accordingly.
        if not hasattr(model, "residue_set"):
            raise AttributeError("The model is missing the required attribute: residue_set")

        # TODO: Check if this can be replaced with model.get_residue_masses(mass_scale=10000)/10000
        # We would need to divide the scaled masses as we use floating point masses.
        # These residue masses are per amino acid and include special tokens;
        # special tokens have a mass of 0.
        self.residue_masses = torch.zeros((len(self.model.residue_set),), dtype=self.float_dtype)
        terminal_residues_idx: list[int] = []
        suppressed_residues_idx: list[int] = []

        # residue_target_offsets supports negative masses (overshoot the remaining mass).
        # This fixes a bug where the residue prior to a negative-mass residue is always invalid.
        residue_target_offsets: list[float] = [0.0]

        for i, residue in enumerate(model.residue_set.vocab):
            if residue in self.model.residue_set.special_tokens:
                continue
            self.residue_masses[i] = self.model.residue_set.get_mass(residue)
            # If no residue is attached, assume it is an N-terminal residue
            if not residue[0].isalpha():
                terminal_residues_idx.append(i)
            if self.residue_masses[i] < 0:
                residue_target_offsets.append(self.residue_masses[i])

            # Check if residue is suppressed
            if residue in suppressed_residues:
                suppressed_residues_idx.append(i)
                suppressed_residues.remove(residue)

        if len(suppressed_residues) > 0:
            logger.warning(f"Some suppressed residues not found in vocabulary: {suppressed_residues}")

        self.terminal_residue_indices = torch.tensor(terminal_residues_idx, dtype=torch.long)
        self.suppressed_residue_indices = torch.tensor(suppressed_residues_idx, dtype=torch.long)
        self.residue_target_offsets = torch.tensor(residue_target_offsets, dtype=self.float_dtype)

        self.vocab_size = len(self.model.residue_set)

    def decode(  # type:ignore
        self,
        spectra: Float[Spectrum, " batch"],
        precursors: Float[PrecursorFeatures, " batch"],
        beam_size: int,
        max_length: int,
        mass_tolerance: float = 5e-5,
        max_isotope: int = 1,
        min_log_prob: float = -float("inf"),
        return_encoder_output: bool = False,
        encoder_output_reduction: Literal["mean", "max", "sum", "full"] = "mean",
        return_beam: bool = False,
        **kwargs,
    ) -> dict[str, Any]:
        """Decode predicted residue sequences for a batch of spectra using beam search.

        Args:
            spectra (torch.FloatTensor):
                The spectra to be sequenced.

            precursors (torch.FloatTensor[batch size, 3]):
                The precursor mass, charge and mass-to-charge ratio.

            beam_size (int):
                The maximum size of the beam.

            max_length (int):
                The maximum length of a residue sequence.

            mass_tolerance (float):
                The maximum relative error for which a predicted sequence
                is still considered to have matched the precursor mass.

            max_isotope (int):
                The maximum number of additional neutrons to consider when
                comparing a predicted sequence's mass to the precursor mass.
                All isotopes from 1 to `max_isotope` additional neutrons
                (inclusive) are considered.

            min_log_prob (float):
                Minimum log probability for early stopping. If a sequence's
                log probability falls below this value it is marked as complete.
                Defaults to -inf.

            return_encoder_output (bool):
                Optionally include the (reduced) encoder output in the result.

            encoder_output_reduction (Literal["mean", "max", "sum", "full"]):
                Reduction applied to the encoder output along the sequence
                length dimension. Defaults to "mean".

            return_beam (bool):
                Optionally return the per-beam results in addition to the best beam.

        Returns:
            dict[str, Any]:
                A dictionary with the predicted residue sequences ("predictions"),
                their sequence log probabilities ("prediction_log_probability") and
                per-token log probabilities ("prediction_token_log_probabilities"),
                plus per-beam and encoder-output entries when requested.
                Each prediction is a list of residue tokens; an empty list is
                returned for a spectrum where decoding fails, i.e. no sequence
                that fits the precursor mass to within the tolerance is found.
        """
        # Beam search with precursor mass termination condition
        batch_size = spectra.shape[0]
        effective_batch_size = batch_size * beam_size
        device = spectra.device

        # Masses of all residues in the vocabulary, 0 for special tokens
        self.residue_masses = self.residue_masses.to(spectra.device)  # float_dtype (vocab_size, )

        # ppm equivalent of the mass tolerance
        delta_ppm_tol = mass_tolerance * 10**6  # float (1, )

        # Residue masses expanded (repeated) across the effective batch size.
        # This is used to quickly compute all possible remaining masses per vocab entry.
        residue_mass_delta = self.residue_masses.expand(effective_batch_size, self.residue_masses.shape[0])  # float_dtype (effective_batch_size, vocab_size)

        # completed_items: list[list[ScoredSequence]] = [[] for _ in range(batch_size)]
        completed_beams: list[list[dict[str, Any]]] = [[] for _ in range(batch_size)]

        with torch.no_grad():
            # 1. Compute spectrum encoding and masks.
            # The encoder is only run once.
            (spectrum_encoding, spectrum_mask), _ = self.model.init(spectra, precursors)

            # EXPAND FOR BEAM SIZE
            spectrum_encoding_expanded = spectrum_encoding.repeat_interleave(beam_size, dim=0)
            spectrum_mask_expanded = spectrum_mask.repeat_interleave(beam_size, dim=0)

            # 2. Initialise beams and other variables.
            # The sequences decoded so far; grows on index 1 for every decoding pass.
            # sequence_length is variable!
            sequences = torch.zeros((effective_batch_size, 0), device=device, dtype=torch.long)  # long (effective_batch_size, sequence_length)

            # Log probabilities of the sequences decoded so far,
            # token probabilities are added at each step.
            log_probabilities = torch.zeros((effective_batch_size, 1), device=device, dtype=torch.float32)  # float32 (effective_batch_size, 1)

            # Keeps track of which beams are completed; this allows the model to skip these.
            complete_beams = torch.zeros((effective_batch_size), device=device, dtype=bool)  # bool (effective_batch_size, )
            is_first_complete = torch.zeros((effective_batch_size), device=device, dtype=bool)  # bool (effective_batch_size, )

            # Extract precursor mass from `precursors`
            precursors_expanded = precursors.repeat_interleave(beam_size, dim=0)

            precursor_mass = precursors_expanded[:, PrecursorDimension.PRECURSOR_MASS.value]  # float32 (effective_batch_size, )

            # Target mass delta; the remaining mass x must be within `target > x > -target`.
            # This target can shift with isotopes.
            # Mass targets = error_ppm * m_prec * 1e-6
            mass_target_delta = delta_ppm_tol * precursor_mass.to(self.float_dtype) * 1e-6  # float_dtype (effective_batch_size, )
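
            # Worked example (illustrative, not taken from the source): with the default
            # mass_tolerance of 5e-5 (i.e. 50 ppm), a 1000 Da precursor gets an absolute
            # tolerance of 1000 * 50 * 1e-6 = 0.05 Da around the expected residual mass.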

            # This keeps track of the remaining mass budget for the currently decoding sequence;
            # it starts at the precursor mass minus H2O.
            remaining_mass = precursor_mass.to(self.float_dtype) - H2O_MASS  # float_dtype (effective_batch_size, )

            # TODO: only check when close to precursor mass? Might not be worth the overhead.
            # Idea is if remaining < check_zone, we do the valid mass and complete checks.
            # check_zone = self.residue_masses.max().expand(batch_size) + mass_target_delta

            # Constant beam indices for retaining beams on failed decoding
            constant_beam_indices = torch.arange(beam_size, device=device)[None, :].repeat_interleave(batch_size, dim=0)

            # Store token probabilities
            token_log_probabilities: dict[str, list[float]] = {}  # dict[str, list[float]]

            # Start decoding
            for i in range(max_length):
                # If all beams are complete, we can stop early.
                if complete_beams.all():
                    break

                # Step 3: Score the next tokens.
                # NOTE: The SOS token is appended automatically in `score_candidates`;
                # we do not have to add it.
                batch = (sequences, precursors_expanded, spectrum_encoding_expanded, spectrum_mask_expanded)
                next_token_probabilities = self.model.score_candidates(*batch)

                # Step 4: Filter probabilities.
                # If the remaining mass is within tolerance, we force an EOS token.
                # All tokens that would push the remaining mass below the minimum
                # cutoff `-mass_target_delta` (including isotopes) are set to -inf.

                # Step 4.1: Check if the remaining mass is within tolerance.
                # To keep it efficient we compute some of the indexed variables first:
                remaining_meets_precursor = torch.zeros((effective_batch_size,), device=device, dtype=bool)  # bool (sub_batch_size, )
                # This loop checks if the mass is within tolerance for 0 to max_isotope (inclusive)
                for j in range(0, max_isotope + 1, 1):
                    # TODO: Use vectorized approach for this
                    isotope = CARBON_MASS_DELTA * j  # float
                    remaining_lesser_isotope = remaining_mass - isotope < mass_target_delta  # bool (sub_batch_size, )
                    remaining_greater_isotope = remaining_mass - isotope > -mass_target_delta  # bool (sub_batch_size, )

                    # The remaining mass is within the target tolerance
                    remaining_within_range = remaining_lesser_isotope & remaining_greater_isotope  # bool (sub_batch_size, )
                    remaining_meets_precursor = remaining_meets_precursor | remaining_within_range  # bool (sub_batch_size, )
                    if remaining_within_range.any() and j > 0:
                        # If we did hit an isotope, correct the remaining mass accordingly
                        # TODO: check this
                        remaining_mass[remaining_within_range] = remaining_mass[remaining_within_range] - isotope

                # Step 4.2: Check which residues are valid
                # Expand incomplete remaining mass across vocabulary size
                remaining_mass_expanded = remaining_mass[:, None].expand(
                    effective_batch_size, self.vocab_size
                )  # float64 (effective_batch_size, vocab_size)
                mass_target_expanded = mass_target_delta[:, None].expand(
                    effective_batch_size, self.vocab_size
                )  # float64 (effective_batch_size, vocab_size)

                valid_mass = remaining_mass_expanded - residue_mass_delta > -mass_target_expanded  # bool (effective_batch_size, vocab_size)
                # Check all isotopes for valid masses
                for mass_offset in self.residue_target_offsets:
                    for j in range(0, max_isotope + 1, 1):
                        isotope = CARBON_MASS_DELTA * j  # float
                        mass_lesser_isotope = (
                            remaining_mass_expanded - residue_mass_delta < mass_target_expanded + isotope + mass_offset
                        )  # bool (effective_batch_size, vocab_size)
                        mass_greater_isotope = (
                            remaining_mass_expanded - residue_mass_delta > -mass_target_expanded + isotope + mass_offset
                        )  # bool (effective_batch_size, vocab_size)
                        valid_mass = valid_mass | (mass_lesser_isotope & mass_greater_isotope)  # bool (effective_batch_size, vocab_size)

                # Filtered probabilities:
                next_token_probabilities_filtered = next_token_probabilities.clone()  # float32 (effective_batch_size, vocab_size)
                # If mass is invalid, set log_prob to -inf
                next_token_probabilities_filtered[~valid_mass] = -float("inf")

                # Disable special tokens and suppressed residues here; EOS is selectively
                # re-enabled below for beams that may stop, and PAD is forced later when
                # every candidate token is -inf.
                next_token_probabilities_filtered[:, self.model.residue_set.EOS_INDEX] = -float("inf")
                next_token_probabilities_filtered[:, self.model.residue_set.PAD_INDEX] = -float("inf")
                next_token_probabilities_filtered[:, self.model.residue_set.SOS_INDEX] = -float("inf")
                next_token_probabilities_filtered[:, self.suppressed_residue_indices] = -float("inf")
                # Set probability of n-terminal modifications to -inf when i > 0
                if self.disable_terminal_residues_anywhere:
                    # Check if adding terminal residues would result in a complete sequence
                    # First generate remaining mass matrix with isotopes
                    remaining_mass_isotope = remaining_mass[:, None].expand(effective_batch_size, max_isotope + 1) - CARBON_MASS_DELTA * (
                        torch.arange(max_isotope + 1, device=device)
                    )
                    # Expand with terminal residues and subtract
                    remaining_mass_isotope_delta = (
                        remaining_mass_isotope[:, :, None].expand(
                            effective_batch_size,
                            max_isotope + 1,
                            self.terminal_residue_indices.shape[0],
                        )
                        - self.residue_masses[self.terminal_residue_indices]
                    )

                    # If within target delta, allow these residues to be predicted,
                    # otherwise set probability to -inf
                    allow_terminal = (remaining_mass_isotope_delta.abs() < mass_target_delta[:, None, None]).any(dim=1)
                    allow_terminal_full = torch.ones(
                        (effective_batch_size, self.vocab_size),
                        device=spectra.device,
                        dtype=bool,
                    )
                    allow_terminal_full[:, self.terminal_residue_indices] = allow_terminal

                    # Set to -inf
                    next_token_probabilities_filtered[~allow_terminal_full] = -float("inf")
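                    # Net effect: terminal residues are only permitted where adding them would
                    # bring the sequence to within the precursor mass tolerance, i.e. effectively
                    # as the final residue of the peptide.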

                # Set to -inf for newly completed beams, only allow EOS
                # NEW WAY TO FORCE EOS
                # for beam_idx in remaining_meets_precursor:
                next_beam_no_predictions = next_token_probabilities_filtered.isinf().all(-1)
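                # `next_beam_no_predictions` flags beams for which every candidate token has been
                # masked to -inf (no mass-valid continuation); when keep_invalid_mass_sequences is
                # set, such beams are allowed to emit EOS below instead of stalling.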

                if is_first_complete.any():
                    completed_idxs = is_first_complete.nonzero().squeeze(-1)
                    for beam_idx in completed_idxs:
                        sequence_probability = (
                            log_probabilities[beam_idx]  # + next_token_probabilities[beam_idx,
                        )
                        sequence_str = str((beam_idx // beam_size).item()) + "-" + ".".join([str(x) for x in sequences[beam_idx].cpu().tolist()])
                        sequence = self.model.decode(sequences[beam_idx])
                        seen_completed_sequences = {"".join(x["predictions"]) for x in completed_beams[beam_idx // beam_size]}
                        if "".join(sequence) in seen_completed_sequences:
                            continue
                        completed_beams[beam_idx // beam_size].append(
                            {
                                "predictions": sequence,
                                "mass_error": remaining_mass[beam_idx].item(),
                                "meets_precursor": remaining_meets_precursor[beam_idx].item(),
                                "prediction_log_probability": sequence_probability.item(),
                                "prediction_token_log_probabilities": token_log_probabilities[sequence_str][: len(sequence)][::-1],
                            }
                        )

                # print(sequences[:5])

                # For beams that already meet the precursor, -inf them and force an EOS
                next_token_probabilities_filtered[remaining_meets_precursor, :] = -float("inf")
                if self.keep_invalid_mass_sequences:
                    # Allow EOS on beams that don't fit the precursor
                    allow_eos = (remaining_meets_precursor | next_beam_no_predictions) & ~complete_beams
                else:
                    allow_eos = (remaining_meets_precursor) & ~complete_beams
                next_eos_probs = next_token_probabilities[allow_eos, self.model.residue_set.EOS_INDEX]
                next_token_probabilities_filtered[allow_eos, self.model.residue_set.EOS_INDEX] = next_eos_probs

                # Step 5: Select next token:
                log_probabilities_expanded = log_probabilities.repeat_interleave(self.vocab_size, dim=1)
                log_probabilities_expanded = log_probabilities_expanded + next_token_probabilities_filtered

                log_probabilities_beams = log_probabilities_expanded.view(-1, beam_size, self.vocab_size)
                if i == 0 and beam_size > 1:
                    # Nullify all beams except the first one
                    log_probabilities_beams[:, 1:] = -float("inf")

                log_probabilities_beams = log_probabilities_beams.view(-1, beam_size * self.vocab_size)
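
                # The flattened top-k below selects beam_size (source beam, next token) pairs per
                # spectrum: index // vocab_size recovers the source beam and index % vocab_size
                # the token appended to it. On the first step all beams are identical, hence the
                # masking above so that duplicate candidates are not selected.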

                topk_values, topk_indices = log_probabilities_beams.topk(beam_size, dim=-1)
                topk_is_inf = topk_values.isinf()

                beam_indices = topk_indices // self.vocab_size
                # Retain beams on failed decoding (when all beams are -inf)
                beam_indices[topk_is_inf] = constant_beam_indices[topk_is_inf]
                beam_indices_full = (beam_indices + torch.arange(batch_size, device=beam_indices.device)[:, None] * beam_size).view(-1)

                next_token = topk_indices % self.vocab_size
                next_token[topk_is_inf] = self.model.residue_set.PAD_INDEX
                next_token = next_token.view(-1, 1)  # long (sub_batch_size, 1)

                # Update beams by indices
                sequences = sequences[beam_indices_full]
                log_probabilities = log_probabilities[beam_indices_full]
                next_token_probabilities = next_token_probabilities[beam_indices_full]
                remaining_mass = remaining_mass[beam_indices_full]
                complete_beams = complete_beams[beam_indices_full]

                sequences = torch.concat([sequences, next_token], axis=1)  # long (sub_batch_size, sequence_length + 1)

                # Expand and update masses
                next_masses = self.residue_masses[next_token].squeeze()  # float64 (sub_batch_size, )
                remaining_mass = remaining_mass - next_masses  # float64 (sub_batch_size, )

                # Expand and update probabilities
                next_token_probabilities[:, self.model.residue_set.PAD_INDEX] = 0
                next_probabilities = torch.gather(next_token_probabilities, 1, next_token)
                next_probabilities[complete_beams] = 0
                log_probabilities = log_probabilities + next_probabilities

                for batch_index in range(effective_batch_size):
                    # Create unique ID for the sequence
                    # Store beam token probabilities in a hash table
                    spectrum_index = batch_index // beam_size
                    sequence = [str(x) for x in sequences[batch_index].cpu().tolist()]
                    sequence_str = str(spectrum_index) + "-" + ".".join(sequence)
                    sequence_prev_str = str(spectrum_index) + "-" + ".".join(sequence[:-1])

                    if sequence_prev_str in token_log_probabilities:
                        previous_probabilities = list(token_log_probabilities[sequence_prev_str])
                    else:
                        previous_probabilities = []

                    previous_probabilities.append(next_probabilities[batch_index, 0].float().item())

                    token_log_probabilities[sequence_str] = previous_probabilities

                # Step 6: Terminate complete beams

                # Check if complete:
                # Early stopping if beam log probability below threshold
                beam_confidence_filter = log_probabilities[:, 0] < min_log_prob
                # Stop if beam is forced to output an EOS
                next_token_is_eos = next_token[:, 0] == self.model.residue_set.EOS_INDEX
                next_token_is_pad = next_token[:, 0] == self.model.residue_set.PAD_INDEX
                next_is_complete = next_token_is_eos | beam_confidence_filter  # | next_token_is_pad

                complete_beams = complete_beams | next_is_complete
                is_first_complete = next_is_complete
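                # `is_first_complete` marks beams that finished on this exact step; such a beam is
                # added to `completed_beams` (deduplicated on its decoded sequence) either at the
                # start of the next iteration or in the wrap-up after the decoding loop.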

                if next_token_is_pad.all():
                    break

                # Repeat from step 3.

            # Check if any beams are complete at the end of the loop
            if is_first_complete.any():
                completed_idxs = is_first_complete.nonzero().squeeze(-1)
                for beam_idx in completed_idxs:
                    sequence_probability = (
                        log_probabilities[beam_idx]  # + next_token_probabilities[beam_idx,
                        # self.model.residue_set.EOS_INDEX]
                    )
                    sequence_str = str((beam_idx // beam_size).item()) + "-" + ".".join([str(x) for x in sequences[beam_idx].cpu().tolist()])
                    sequence = self.model.decode(sequences[beam_idx])
                    seen_completed_sequences = {"".join(x["predictions"]) for x in completed_beams[beam_idx // beam_size]}
                    if "".join(sequence) in seen_completed_sequences:
                        continue
                    completed_beams[beam_idx // beam_size].append(
                        {
                            "predictions": sequence,
                            "mass_error": remaining_mass[beam_idx].item(),
                            "meets_precursor": remaining_meets_precursor[beam_idx].item(),
                            "prediction_log_probability": sequence_probability.item(),
                            "prediction_token_log_probabilities": token_log_probabilities[sequence_str][: len(sequence)][::-1],
                        }
                    )

            # This loop forcefully adds all beams at the end, whether they are complete or not
            if self.keep_invalid_mass_sequences:
                for batch_idx in range(effective_batch_size):
                    sequence_str = str(batch_idx // beam_size) + "-" + ".".join([str(x) for x in sequences[batch_idx].cpu().tolist()])
                    sequence = self.model.decode(sequences[batch_idx])
                    seen_completed_sequences = {"".join(x["predictions"]) for x in completed_beams[batch_idx // beam_size]}
                    if "".join(sequence) in seen_completed_sequences:
                        # print(f"Skipping {sequence_str} because it is added")
                        continue
                    completed_beams[batch_idx // beam_size].append(
                        {
                            "predictions": sequence,
                            "mass_error": remaining_mass[batch_idx].item(),
                            "meets_precursor": remaining_meets_precursor[batch_idx].item(),
                            "prediction_log_probability": log_probabilities[batch_idx, 0].item(),
                            "prediction_token_log_probabilities": token_log_probabilities[sequence_str][: len(sequence)][::-1],
                        }
                    )

            # Get top n beams per batch,
            # filtered on meets_precursor and log_probability
            top_completed_beams = self._get_top_n_beams(completed_beams, beam_size)

            # Prepare result dictionary
            result: dict[str, Any] = {
                "predictions": [],
                # "mass_error": [],
                "prediction_log_probability": [],
                "prediction_token_log_probabilities": [],
            }
            if return_beam:
                for i in range(beam_size):
                    result[f"predictions_beam_{i}"] = []
                    # result[f"mass_error_beam_{i}"] = []
                    result[f"predictions_log_probability_beam_{i}"] = []
                    result[f"predictions_token_log_probabilities_beam_{i}"] = []

            for batch_idx in range(batch_size):
                if return_beam:
                    for beam_idx in range(beam_size):
                        result[f"predictions_beam_{beam_idx}"].append("".join(top_completed_beams[batch_idx][beam_idx]["predictions"]))
                        # result[f"mass_error_beam_{beam_idx}"].append(top_completed_beams[batch_idx][beam_idx]["mass_error"])
                        result[f"predictions_log_probability_beam_{beam_idx}"].append(
                            top_completed_beams[batch_idx][beam_idx]["prediction_log_probability"]
                        )
                        result[f"predictions_token_log_probabilities_beam_{beam_idx}"].append(
                            top_completed_beams[batch_idx][beam_idx]["prediction_token_log_probabilities"]
                        )

                # Save best beam as main result
                result["predictions"].append(top_completed_beams[batch_idx][0]["predictions"])
                # result[f"mass_error"].append(top_completed_beams[batch_idx][0]["mass_error"])
                result["prediction_log_probability"].append(top_completed_beams[batch_idx][0]["prediction_log_probability"])
                result["prediction_token_log_probabilities"].append(top_completed_beams[batch_idx][0]["prediction_token_log_probabilities"])

            # Optionally include encoder output
            if return_encoder_output:
                # Reduce along sequence length dimension
                encoder_output = spectrum_encoding.float().cpu()
                encoder_mask = (1 - spectrum_mask.float()).cpu()
                encoder_output = encoder_output * encoder_mask.unsqueeze(-1)
                if encoder_output_reduction == "mean":
                    count = encoder_mask.sum(dim=1).unsqueeze(-1).clamp(min=1)
                    encoder_output = encoder_output.sum(dim=1) / count
                elif encoder_output_reduction == "max":
                    encoder_output[encoder_output == 0] = -float("inf")
                    encoder_output = encoder_output.max(dim=1)[0]
                elif encoder_output_reduction == "sum":
                    encoder_output = encoder_output.sum(dim=1)
                elif encoder_output_reduction == "full":
                    raise NotImplementedError("Full encoder output reduction is not yet implemented")
                else:
                    raise ValueError(f"Invalid encoder output reduction: {encoder_output_reduction}")
                result["encoder_output"] = list(encoder_output.numpy())

        return result

    def _get_top_n_beams(self, completed_beams: list[list[dict[str, Any]]], beam_size: int) -> list[list[dict[str, Any]]]:
        """Get the top n beams from the completed beams.

        Args:
            completed_beams: The completed beams to get the top n beams from.
                Each beam is a dictionary with the following keys:
                - predictions: The predictions of the beam.
                - mass_error: The mass error of the beam.
                - meets_precursor: Whether the beam meets the precursor mass.
                - prediction_log_probability: The log probability of the beam.
                - prediction_token_log_probabilities: The log probabilities of the tokens in the beam.
            beam_size: The number of beams to keep per batch.

        Returns:
            A list of lists, each containing the top n beams for a batch.
        """
        default_beam = {
            "predictions": [],
            "mass_error": -float("inf"),
            "prediction_log_probability": -float("inf"),
            "prediction_token_log_probabilities": [],
        }

        top_beams_per_row = []
        for beams in completed_beams:
            # Sort by meets_precursor first (True before False), then by log probability descending
            beams.sort(key=lambda x: (x["meets_precursor"], x["prediction_log_probability"]), reverse=True)
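            # Example (illustrative): with reverse=True the key tuples sort as
            #     (True, -1.2) > (True, -3.4) > (False, -0.1),
            # i.e. beams that meet the precursor come first, ties broken by higher log probability.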

            # Keep top N beams
            top_beams = beams[:beam_size]

            # Pad with default beam if fewer than N
            while len(top_beams) < beam_size:
                top_beams.append(default_beam.copy())

            top_beams_per_row.append(top_beams)

        return top_beams_per_row