Coverage for instanovo/inference/beam_search.py: 85%

220 statements  

coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

from __future__ import annotations

from typing import Any, Literal

import torch
from jaxtyping import Float

from instanovo.__init__ import console
from instanovo.constants import CARBON_MASS_DELTA, H2O_MASS, MASS_SCALE, PrecursorDimension
from instanovo.inference.interfaces import Decodable, Decoder
from instanovo.types import PrecursorFeatures, Spectrum
from instanovo.utils.colorlogging import ColorLog

logger = ColorLog(console, __name__).logger


class BeamSearchDecoder(Decoder):
    """A class for decoding from de novo sequence models using beam search.

    This class conforms to the `Decoder` interface and decodes from
    models that conform to the `Decodable` interface.
    """

    def __init__(
        self,
        model: Decodable,
        suppressed_residues: list[str] | None = None,
        mass_scale: int = MASS_SCALE,
        disable_terminal_residues_anywhere: bool = True,
        keep_invalid_mass_sequences: bool = True,
        float_dtype: torch.dtype = torch.float64,
    ):
        super().__init__(model=model)
        self.mass_scale = mass_scale
        self.disable_terminal_residues_anywhere = disable_terminal_residues_anywhere
        self.keep_invalid_mass_sequences = keep_invalid_mass_sequences
        self.float_dtype = float_dtype

        suppressed_residues = suppressed_residues or []

        # NOTE: Beam search requires the `residue_set` attribute on the model,
        # update all methods accordingly.
        if not hasattr(model, "residue_set"):
            raise AttributeError("The model is missing the required attribute: residue_set")

        # TODO: Check if this can be replaced with model.get_residue_masses(mass_scale=10000)/10000
        # We would need to divide the scaled masses as we use floating point masses.
        # These residue masses are per amino acid and include special tokens,
        # special tokens have a mass of 0.
        self.residue_masses = torch.zeros((len(self.model.residue_set),), dtype=self.float_dtype)
        terminal_residues_idx: list[int] = []
        suppressed_residues_idx: list[int] = []

        # residue_target_offsets supports negative masses (overshoot the remaining mass)
        # This fixes a bug where the residue prior to a negative mass residue is always invalid.
        residue_target_offsets: list[float] = [0.0]

        for i, residue in enumerate(model.residue_set.vocab):
            if residue in self.model.residue_set.special_tokens:
                continue
            self.residue_masses[i] = self.model.residue_set.get_mass(residue)
            # If no residue is attached, assume it is an N-terminal residue
            if not residue[0].isalpha():
                terminal_residues_idx.append(i)
            if self.residue_masses[i] < 0:
                residue_target_offsets.append(self.residue_masses[i])

            # Check if residue is suppressed
            if residue in suppressed_residues:
                suppressed_residues_idx.append(i)
                suppressed_residues.remove(residue)

        if len(suppressed_residues) > 0:
            logger.warning(f"Some suppressed residues not found in vocabulary: {suppressed_residues}")

        self.terminal_residue_indices = torch.tensor(terminal_residues_idx, dtype=torch.long)
        self.suppressed_residue_indices = torch.tensor(suppressed_residues_idx, dtype=torch.long)
        self.residue_target_offsets = torch.tensor(residue_target_offsets, dtype=self.float_dtype)

        self.vocab_size = len(self.model.residue_set)

    def decode(  # type:ignore
        self,
        spectra: Float[Spectrum, " batch"],
        precursors: Float[PrecursorFeatures, " batch"],
        beam_size: int,
        max_length: int,
        mass_tolerance: float = 5e-5,
        max_isotope: int = 1,
        min_log_prob: float = -float("inf"),
        return_encoder_output: bool = False,
        encoder_output_reduction: Literal["mean", "max", "sum", "full"] = "mean",
        return_beam: bool = False,
        **kwargs,
    ) -> dict[str, Any]:
        """Decode predicted residue sequence for a batch of spectra using beam search.

        Args:
            spectra (torch.FloatTensor):
                The spectra to be sequenced.

            precursors (torch.FloatTensor[batch size, 3]):
                The precursor mass, charge and mass-to-charge ratio.

            beam_size (int):
                The maximum size of the beam, i.e. the number of candidate
                sequences kept per spectrum.

            max_length (int):
                The maximum length of a residue sequence.

            mass_tolerance (float):
                The maximum relative error for which a predicted sequence
                is still considered to have matched the precursor mass.

            max_isotope (int):
                The maximum number of additional neutrons considered for
                isotopes when comparing a predicted sequence's mass
                to the precursor mass.

                All additional nucleon numbers from 1 to `max_isotope` inclusive
                are considered.

            min_log_prob (float):
                Minimum log probability for early stopping. If a sequence's
                log probability falls below this value, it is marked as complete.
                Defaults to -inf.

            return_encoder_output (bool):
                Optionally include the reduced spectrum encoder output in the results.

            encoder_output_reduction (str):
                Reduction applied to the encoder output along the sequence
                dimension: "mean", "max", "sum" or "full".

            return_beam (bool):
                Optionally return the top `beam_size` beams for each spectrum
                in addition to the best beam.

        Returns:
            dict[str, Any]:
                A dictionary with the predicted residue tokens, sequence log
                probability and per-token log probabilities for each spectrum,
                plus per-beam entries if `return_beam` is set and the encoder
                output if `return_encoder_output` is set. Spectra for which
                decoding fails, i.e. no sequence fits the precursor mass
                to within the tolerance, receive an empty prediction.
        """
        # Beam search with precursor mass termination condition
        batch_size = spectra.shape[0]
        effective_batch_size = batch_size * beam_size
        device = spectra.device

        # Masses of all residues in vocabulary, 0 for special tokens
        self.residue_masses = self.residue_masses.to(spectra.device)  # float_dtype (vocab_size,)

        # ppm equivalent of mass tolerance
        delta_ppm_tol = mass_tolerance * 10**6  # float (1,)

        # Residue masses expanded (repeated) across the effective batch size (batch_size * beam_size)
        # This is used to quickly compute all possible remaining masses per vocab entry
        residue_mass_delta = self.residue_masses.expand(effective_batch_size, self.residue_masses.shape[0])  # float_dtype (effective_batch_size, vocab_size)

        # completed_items: list[list[ScoredSequence]] = [[] for _ in range(batch_size)]
        completed_beams: list[list[dict[str, Any]]] = [[] for _ in range(batch_size)]

        with torch.no_grad():
            # 1. Compute spectrum encoding and masks
            # Encoder is only run once.
            (spectrum_encoding, spectrum_mask), _ = self.model.init(spectra, precursors)

            # EXPAND FOR BEAM SIZE
            spectrum_encoding_expanded = spectrum_encoding.repeat_interleave(beam_size, dim=0)
            spectrum_mask_expanded = spectrum_mask.repeat_interleave(beam_size, dim=0)

            # 2. Initialise beams and other variables
            # The sequences decoded so far, grows on index 1 for every decoding pass.
            # sequence_length is variable!
            sequences = torch.zeros((effective_batch_size, 0), device=device, dtype=torch.long)  # long (effective_batch_size, sequence_length)

            # Log probabilities of the sequences decoded so far,
            # token probabilities are added at each step.
            log_probabilities = torch.zeros((effective_batch_size, 1), device=device, dtype=torch.float32)  # float32 (effective_batch_size, 1)

            # Keeps track of which beams are completed, this allows the model to skip these
            complete_beams = torch.zeros((effective_batch_size), device=device, dtype=bool)  # bool (effective_batch_size,)
            is_first_complete = torch.zeros((effective_batch_size), device=device, dtype=bool)  # bool (effective_batch_size,)

            # Extract precursor mass from `precursors`
            precursors_expanded = precursors.repeat_interleave(beam_size, dim=0)

            precursor_mass = precursors_expanded[:, PrecursorDimension.PRECURSOR_MASS.value]  # float32 (effective_batch_size,)

            # Target mass delta, remaining mass x must be within `target > x > -target`.
            # This target can shift with isotopes.
            # Mass targets = error_ppm * m_prec * 1e-6
            mass_target_delta = delta_ppm_tol * precursor_mass.to(self.float_dtype) * 1e-6  # float_dtype (effective_batch_size,)

            # This keeps track of the remaining mass budget for the currently decoding sequence,
            # starts at the precursor - H2O
            remaining_mass = precursor_mass.to(self.float_dtype) - H2O_MASS  # float_dtype (effective_batch_size,)

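            # Worked example (illustrative numbers, not from the source): with the
            # default mass_tolerance of 5e-5, delta_ppm_tol is 50 ppm, so a 1000 Da
            # precursor gives mass_target_delta = 50 * 1000 * 1e-6 = 0.05 Da, i.e.
            # the remaining mass must end up within roughly +/-0.05 Da of zero
            # (the H2O mass has already been subtracted from the budget).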
            # TODO: only check when close to precursor mass? Might not be worth the overhead.
            # Idea is if remaining < check_zone, we do the valid mass and complete checks.
            # check_zone = self.residue_masses.max().expand(batch_size) + mass_target_delta

            # Constant beam indices for retaining beams on failed decoding
            constant_beam_indices = torch.arange(beam_size, device=device)[None, :].repeat_interleave(batch_size, dim=0)

            # Store token probabilities
            token_log_probabilities: dict[str, list[float]] = {}  # dict[str, list[float]]

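            # Keys have the form "<spectrum index>-<token id>.<token id>...", e.g. a
            # hypothetical "0-12.5.7" maps to the per-step log probabilities of that
            # partial sequence for spectrum 0; the keys are built in the update loop below.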
            # Start decoding
            for i in range(max_length):
                # If all beams are complete, we can stop early.
                if complete_beams.all():
                    break

                # Step 3: score the next tokens
                # NOTE: SOS token is appended automatically in `score_candidates`.
                # We do not have to add it.
                batch = (sequences, precursors_expanded, spectrum_encoding_expanded, spectrum_mask_expanded)
                next_token_probabilities = self.model.score_candidates(*batch)

                # Step 4: Filter probabilities
                # If remaining mass is within tolerance, we force an EOS token.
                # All tokens that would set the remaining mass below the minimum
                # cutoff `-mass_target_delta` (including isotopes) are set to -inf.

                # Step 4.1: Check if remaining mass is within tolerance:
                # To keep it efficient we compute some of the indexed variables first:
                remaining_meets_precursor = torch.zeros((effective_batch_size,), device=device, dtype=bool)  # bool (effective_batch_size,)
                # This loop checks if mass is within tolerance for 0 to `max_isotope` (inclusive)

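                # Illustrative example (values assumed, not from the source): with
                # max_isotope = 1 the loop tests j = 0 and j = 1, i.e. whether
                # |remaining_mass| < mass_target_delta or
                # |remaining_mass - CARBON_MASS_DELTA| < mass_target_delta, where
                # CARBON_MASS_DELTA is the (approximately 1.003 Da) 13C/12C mass
                # difference used as the isotope spacing.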
                for j in range(0, max_isotope + 1, 1):
                    # TODO: Use vectorized approach for this
                    isotope = CARBON_MASS_DELTA * j  # float
                    remaining_lesser_isotope = remaining_mass - isotope < mass_target_delta  # bool (effective_batch_size,)
                    remaining_greater_isotope = remaining_mass - isotope > -mass_target_delta  # bool (effective_batch_size,)

                    # remaining mass is within the target tolerance
                    remaining_within_range = remaining_lesser_isotope & remaining_greater_isotope  # bool (effective_batch_size,)
                    remaining_meets_precursor = remaining_meets_precursor | remaining_within_range  # bool (effective_batch_size,)
                    if remaining_within_range.any() and j > 0:
                        # If we did hit an isotope, correct the remaining mass accordingly
                        # TODO check this
                        remaining_mass[remaining_within_range] = remaining_mass[remaining_within_range] - isotope

                # Step 4.2: Check which residues are valid
                # Expand incomplete remaining mass across vocabulary size
                remaining_mass_expanded = remaining_mass[:, None].expand(
                    effective_batch_size, self.vocab_size
                )  # float64 (effective_batch_size, vocab_size)
                mass_target_expanded = mass_target_delta[:, None].expand(
                    effective_batch_size, self.vocab_size
                )  # float64 (effective_batch_size, vocab_size)

                valid_mass = remaining_mass_expanded - residue_mass_delta > -mass_target_expanded  # bool (effective_batch_size, vocab_size)
                # Check all isotopes for valid masses
                for mass_offset in self.residue_target_offsets:
                    for j in range(0, max_isotope + 1, 1):
                        isotope = CARBON_MASS_DELTA * j  # float
                        mass_lesser_isotope = (
                            remaining_mass_expanded - residue_mass_delta < mass_target_expanded + isotope + mass_offset
                        )  # bool (effective_batch_size, vocab_size)
                        mass_greater_isotope = (
                            remaining_mass_expanded - residue_mass_delta > -mass_target_expanded + isotope + mass_offset
                        )  # bool (effective_batch_size, vocab_size)
                        valid_mass = valid_mass | (mass_lesser_isotope & mass_greater_isotope)  # bool (effective_batch_size, vocab_size)

                # Filtered probabilities:
                next_token_probabilities_filtered = next_token_probabilities.clone()  # float32 (effective_batch_size, vocab_size)
                # If mass is invalid, set log_prob to -inf
                next_token_probabilities_filtered[~valid_mass] = -float("inf")

                next_token_probabilities_filtered[:, self.model.residue_set.EOS_INDEX] = -float("inf")
                # PAD is disabled here; beams with no valid predictions are forced to PAD further down
                next_token_probabilities_filtered[:, self.model.residue_set.PAD_INDEX] = -float("inf")
                next_token_probabilities_filtered[:, self.model.residue_set.SOS_INDEX] = -float("inf")
                next_token_probabilities_filtered[:, self.suppressed_residue_indices] = -float("inf")
                # Set probability of n-terminal modifications to -inf when i > 0
                if self.disable_terminal_residues_anywhere:
                    # Check if adding terminal residues would result in a complete sequence
                    # First generate remaining mass matrix with isotopes
                    remaining_mass_isotope = remaining_mass[:, None].expand(effective_batch_size, max_isotope + 1) - CARBON_MASS_DELTA * (
                        torch.arange(max_isotope + 1, device=device)
                    )
                    # Expand with terminal residues and subtract
                    remaining_mass_isotope_delta = (
                        remaining_mass_isotope[:, :, None].expand(
                            effective_batch_size,
                            max_isotope + 1,
                            self.terminal_residue_indices.shape[0],
                        )
                        - self.residue_masses[self.terminal_residue_indices]
                    )

                    # If within target delta, allow these residues to be predicted,
                    # otherwise set probability to -inf
                    allow_terminal = (remaining_mass_isotope_delta.abs() < mass_target_delta[:, None, None]).any(dim=1)
                    allow_terminal_full = torch.ones(
                        (effective_batch_size, self.vocab_size),
                        device=spectra.device,
                        dtype=bool,
                    )
                    allow_terminal_full[:, self.terminal_residue_indices] = allow_terminal

                    # Set to -inf
                    next_token_probabilities_filtered[~allow_terminal_full] = -float("inf")

                # Set to -inf for newly completed beams, only allow EOS
                # NEW WAY TO FORCE EOS
                # for beam_idx in remaining_meets_precursor:
                next_beam_no_predictions = next_token_probabilities_filtered.isinf().all(-1)

                if is_first_complete.any():
                    completed_idxs = is_first_complete.nonzero().squeeze(-1)
                    for beam_idx in completed_idxs:
                        sequence_probability = (
                            log_probabilities[beam_idx]  # + next_token_probabilities[beam_idx,
                        )
                        sequence_str = str((beam_idx // beam_size).item()) + "-" + ".".join([str(x) for x in sequences[beam_idx].cpu().tolist()])
                        sequence = self.model.decode(sequences[beam_idx])
                        seen_completed_sequences = {"".join(x["predictions"]) for x in completed_beams[beam_idx // beam_size]}
                        if "".join(sequence) in seen_completed_sequences:
                            continue
                        completed_beams[beam_idx // beam_size].append(
                            {
                                "predictions": sequence,
                                "mass_error": remaining_mass[beam_idx].item(),
                                "meets_precursor": remaining_meets_precursor[beam_idx].item(),
                                "prediction_log_probability": sequence_probability.item(),
                                "prediction_token_log_probabilities": token_log_probabilities[sequence_str][: len(sequence)][::-1],
                            }
                        )

                # print(sequences[:5])

                # For beams that already meet precursor, -inf them and force an EOS
                next_token_probabilities_filtered[remaining_meets_precursor, :] = -float("inf")
                if self.keep_invalid_mass_sequences:
                    # Allow EOS on beams that don't fit the precursor
                    allow_eos = (remaining_meets_precursor | next_beam_no_predictions) & ~complete_beams
                else:
                    allow_eos = (remaining_meets_precursor) & ~complete_beams
                next_eos_probs = next_token_probabilities[allow_eos, self.model.residue_set.EOS_INDEX]
                next_token_probabilities_filtered[allow_eos, self.model.residue_set.EOS_INDEX] = next_eos_probs

                # Step 5: Select next token:
                log_probabilities_expanded = log_probabilities.repeat_interleave(self.vocab_size, dim=1)
                log_probabilities_expanded = log_probabilities_expanded + next_token_probabilities_filtered

                log_probabilities_beams = log_probabilities_expanded.view(-1, beam_size, self.vocab_size)
                if i == 0 and beam_size > 1:
                    # Nullify all beams except the first one
                    log_probabilities_beams[:, 1:] = -float("inf")

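                # (On the first step every beam holds the same empty sequence, so only
                # the first beam is kept above; otherwise the top-k would pick
                # beam_size duplicate continuations. Rationale inferred from the code.)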
                log_probabilities_beams = log_probabilities_beams.view(-1, beam_size * self.vocab_size)

                topk_values, topk_indices = log_probabilities_beams.topk(beam_size, dim=-1)

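                # Flat-index arithmetic example (illustrative numbers): with
                # beam_size = 2 and vocab_size = 10, a flat top-k index of 13
                # selects beam 13 // 10 = 1 and token 13 % 10 = 3 for that spectrum.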
                topk_is_inf = topk_values.isinf()

                beam_indices = topk_indices // self.vocab_size
                # Retain beams on failed decoding (when all beams are -inf)
                beam_indices[topk_is_inf] = constant_beam_indices[topk_is_inf]
                beam_indices_full = (beam_indices + torch.arange(batch_size, device=beam_indices.device)[:, None] * beam_size).view(-1)

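                # Illustrative example: with batch_size = 2 and beam_size = 3, beam 1
                # of spectrum 1 maps to flat row 1 * 3 + 1 = 4 of the
                # (batch_size * beam_size)-long layout used by `sequences` and friends.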
                next_token = topk_indices % self.vocab_size
                next_token[topk_is_inf] = self.model.residue_set.PAD_INDEX
                next_token = next_token.view(-1, 1)  # long (effective_batch_size, 1)

                # Update beams by indices
                sequences = sequences[beam_indices_full]
                log_probabilities = log_probabilities[beam_indices_full]
                next_token_probabilities = next_token_probabilities[beam_indices_full]
                remaining_mass = remaining_mass[beam_indices_full]
                complete_beams = complete_beams[beam_indices_full]

                sequences = torch.concat([sequences, next_token], axis=1)  # long (effective_batch_size, sequence_length)

                # Expand and update masses
                next_masses = self.residue_masses[next_token].squeeze()  # float64 (effective_batch_size,)
                remaining_mass = remaining_mass - next_masses  # float64 (effective_batch_size,)

                # Expand and update probabilities
                next_token_probabilities[:, self.model.residue_set.PAD_INDEX] = 0
                next_probabilities = torch.gather(next_token_probabilities, 1, next_token)
                next_probabilities[complete_beams] = 0
                log_probabilities = log_probabilities + next_probabilities

                for batch_index in range(effective_batch_size):
                    # Create unique ID for the sequence
                    # Store beam token probabilities in a hash table
                    spectrum_index = batch_index // beam_size
                    sequence = [str(x) for x in sequences[batch_index].cpu().tolist()]
                    sequence_str = str(spectrum_index) + "-" + ".".join(sequence)
                    sequence_prev_str = str(spectrum_index) + "-" + ".".join(sequence[:-1])

                    if sequence_prev_str in token_log_probabilities:
                        previous_probabilities = list(token_log_probabilities[sequence_prev_str])
                    else:
                        previous_probabilities = []

                    previous_probabilities.append(next_probabilities[batch_index, 0].float().item())

                    token_log_probabilities[sequence_str] = previous_probabilities

                # Step 6: Terminate complete beams

                # Check if complete:
                # Early stopping if beam log probability is below threshold
                beam_confidence_filter = log_probabilities[:, 0] < min_log_prob
                # Stop if beam is forced to output an EOS
                next_token_is_eos = next_token[:, 0] == self.model.residue_set.EOS_INDEX
                next_token_is_pad = next_token[:, 0] == self.model.residue_set.PAD_INDEX
                next_is_complete = next_token_is_eos | beam_confidence_filter  # | next_token_is_pad

                complete_beams = complete_beams | next_is_complete
                is_first_complete = next_is_complete

                if next_token_is_pad.all():
                    break

                # Repeat from step 3.

            # Check if any beams are complete at the end of the loop
            if is_first_complete.any():
                completed_idxs = is_first_complete.nonzero().squeeze(-1)
                for beam_idx in completed_idxs:
                    sequence_probability = (
                        log_probabilities[beam_idx]  # + next_token_probabilities[beam_idx,
                        # self.model.residue_set.EOS_INDEX]
                    )
                    sequence_str = str((beam_idx // beam_size).item()) + "-" + ".".join([str(x) for x in sequences[beam_idx].cpu().tolist()])
                    sequence = self.model.decode(sequences[beam_idx])
                    seen_completed_sequences = {"".join(x["predictions"]) for x in completed_beams[beam_idx // beam_size]}
                    if "".join(sequence) in seen_completed_sequences:
                        continue
                    completed_beams[beam_idx // beam_size].append(
                        {
                            "predictions": sequence,
                            "mass_error": remaining_mass[beam_idx].item(),
                            "meets_precursor": remaining_meets_precursor[beam_idx].item(),
                            "prediction_log_probability": sequence_probability.item(),
                            "prediction_token_log_probabilities": token_log_probabilities[sequence_str][: len(sequence)][::-1],
                        }
                    )

            # This loop forcefully adds all beams at the end, whether they are complete or not
            if self.keep_invalid_mass_sequences:
                for batch_idx in range(effective_batch_size):
                    sequence_str = str(batch_idx // beam_size) + "-" + ".".join([str(x) for x in sequences[batch_idx].cpu().tolist()])
                    sequence = self.model.decode(sequences[batch_idx])
                    seen_completed_sequences = {"".join(x["predictions"]) for x in completed_beams[batch_idx // beam_size]}
                    if "".join(sequence) in seen_completed_sequences:
                        # print(f"Skipping {sequence_str} because it is added")
                        continue
                    completed_beams[batch_idx // beam_size].append(
                        {
                            "predictions": sequence,
                            "mass_error": remaining_mass[batch_idx].item(),
                            "meets_precursor": remaining_meets_precursor[batch_idx].item(),
                            "prediction_log_probability": log_probabilities[batch_idx, 0].item(),
                            "prediction_token_log_probabilities": token_log_probabilities[sequence_str][: len(sequence)][::-1],
                        }
                    )

        # Get top n beams per batch
        # Filtered on meets_precursor and log_probability
        top_completed_beams = self._get_top_n_beams(completed_beams, beam_size)

        # Prepare result dictionary
        result: dict[str, Any] = {
            "predictions": [],
            # "mass_error": [],
            "prediction_log_probability": [],
            "prediction_token_log_probabilities": [],
        }
        if return_beam:
            for i in range(beam_size):
                result[f"predictions_beam_{i}"] = []
                # result[f"mass_error_beam_{i}"] = []
                result[f"predictions_log_probability_beam_{i}"] = []
                result[f"predictions_token_log_probabilities_beam_{i}"] = []

        for batch_idx in range(batch_size):
            if return_beam:
                for beam_idx in range(beam_size):
                    result[f"predictions_beam_{beam_idx}"].append("".join(top_completed_beams[batch_idx][beam_idx]["predictions"]))
                    # result[f"mass_error_beam_{beam_idx}"].append(top_completed_beams[batch_idx][beam_idx]["mass_error"])
                    result[f"predictions_log_probability_beam_{beam_idx}"].append(
                        top_completed_beams[batch_idx][beam_idx]["prediction_log_probability"]
                    )
                    result[f"predictions_token_log_probabilities_beam_{beam_idx}"].append(
                        top_completed_beams[batch_idx][beam_idx]["prediction_token_log_probabilities"]
                    )

            # Save best beam as main result
            result["predictions"].append(top_completed_beams[batch_idx][0]["predictions"])
            # result[f"mass_error"].append(top_completed_beams[batch_idx][0]["mass_error"])
            result["prediction_log_probability"].append(top_completed_beams[batch_idx][0]["prediction_log_probability"])
            result["prediction_token_log_probabilities"].append(top_completed_beams[batch_idx][0]["prediction_token_log_probabilities"])

        # Optionally include encoder output
        if return_encoder_output:
            # Reduce along sequence length dimension
            encoder_output = spectrum_encoding.float().cpu()
            encoder_mask = (1 - spectrum_mask.float()).cpu()
            encoder_output = encoder_output * encoder_mask.unsqueeze(-1)
            if encoder_output_reduction == "mean":
                count = encoder_mask.sum(dim=1).unsqueeze(-1).clamp(min=1)
                encoder_output = encoder_output.sum(dim=1) / count
            elif encoder_output_reduction == "max":
                encoder_output[encoder_output == 0] = -float("inf")
                encoder_output = encoder_output.max(dim=1)[0]
            elif encoder_output_reduction == "sum":
                encoder_output = encoder_output.sum(dim=1)
            elif encoder_output_reduction == "full":
                raise NotImplementedError("Full encoder output reduction is not yet implemented")
            else:
                raise ValueError(f"Invalid encoder output reduction: {encoder_output_reduction}")
            result["encoder_output"] = list(encoder_output.numpy())

        return result

    def _get_top_n_beams(self, completed_beams: list[list[dict[str, Any]]], beam_size: int) -> list[list[dict[str, Any]]]:
        """Get the top n beams from the completed beams.

        Args:
            completed_beams: The completed beams to get the top n beams from.
                Each beam is a dictionary with the following keys:
                - predictions: The predictions of the beam.
                - mass_error: The mass error of the beam.
                - meets_precursor: Whether the beam matches the precursor mass.
                - prediction_log_probability: The log probability of the beam.
                - prediction_token_log_probabilities: The log probabilities of the tokens in the beam.
            beam_size: The number of beams to keep per spectrum.

        Returns:
            A list of lists, each containing the top n beams for a spectrum.
        """
        default_beam = {
            "predictions": [],
            "mass_error": -float("inf"),
            "prediction_log_probability": -float("inf"),
            "prediction_token_log_probabilities": [],
        }

        top_beams_per_row = []
        for beams in completed_beams:
            # Sort first by whether the precursor mass is matched, then by log_prob descending
            beams.sort(key=lambda x: (x["meets_precursor"], x["prediction_log_probability"]), reverse=True)

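            # Tuple-sort example (illustrative): with reverse=True, (True, -1.2)
            # sorts ahead of (False, -0.3), so beams that match the precursor mass
            # always outrank non-matching beams regardless of log probability.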
            # Keep top N beams
            top_beams = beams[:beam_size]

            # Pad with default beam if fewer than N
            while len(top_beams) < beam_size:
                top_beams.append(default_beam.copy())

            top_beams_per_row.append(top_beams)

        return top_beams_per_row
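

if __name__ == "__main__":
    # Minimal standalone sketch (illustrative only, not part of the decoder): it
    # demonstrates the flat top-k index arithmetic and the ppm mass-window
    # computation used above with made-up numbers, and needs no model or spectra.
    beam_size, vocab_size = 2, 10
    scores = torch.full((1, beam_size * vocab_size), -float("inf"))
    scores[0, 13] = -0.1  # pretend beam 1, token 3 is the best continuation
    scores[0, 4] = -0.5  # pretend beam 0, token 4 is the runner-up
    topk_values, topk_indices = scores.topk(beam_size, dim=-1)
    print(topk_indices // vocab_size, topk_indices % vocab_size)  # beam and token ids

    mass_tolerance, precursor_mass = 5e-5, 1000.0  # assumed example values
    mass_target_delta = mass_tolerance * 10**6 * precursor_mass * 1e-6
    print(f"+/-{mass_target_delta:.4f} Da window for a {precursor_mass} Da precursor")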