Coverage for instanovo/inference/knapsack_beam_search.py: 86%

247 statements  


from __future__ import annotations

from typing import Any, Literal

import torch
from jaxtyping import Float

from instanovo.__init__ import console
from instanovo.constants import CARBON_MASS_DELTA, H2O_MASS, PrecursorDimension
from instanovo.inference.interfaces import Decodable, Decoder
from instanovo.inference.knapsack import Knapsack
from instanovo.types import PrecursorFeatures, Spectrum
from instanovo.utils.colorlogging import ColorLog

logger = ColorLog(console, __name__).logger


class KnapsackBeamSearchDecoder(Decoder):
    """A class for decoding from de novo sequence models using knapsack-filtered beam search.

    This class conforms to the `Decoder` interface and decodes from
    models that conform to the `Decodable` interface.
    """

    def __init__(
        self,
        model: Decodable,
        knapsack: Knapsack,
        suppressed_residues: list[str] | None = None,
        disable_terminal_residues_anywhere: bool = True,
        keep_invalid_mass_sequences: bool = True,
        float_dtype: torch.dtype = torch.float64,
    ):
        super().__init__(model=model)
        self.knapsack = knapsack
        self.chart = torch.tensor(self.knapsack.chart)
        self.mass_scale = knapsack.mass_scale
        self.disable_terminal_residues_anywhere = disable_terminal_residues_anywhere
        self.keep_invalid_mass_sequences = keep_invalid_mass_sequences
        self.float_dtype = float_dtype

        suppressed_residues = suppressed_residues or []

        # NOTE: This decoder requires a `residue_set` attribute on the model;
        # update all methods accordingly if this changes.
        if not hasattr(model, "residue_set"):
            raise AttributeError("The model is missing the required attribute: residue_set")

        # TODO: Check if this can be replaced with model.get_residue_masses(mass_scale=10000)/10000
        # We would need to divide the scaled masses as we use floating point masses.
        # These residue masses are per amino acid and include special tokens;
        # special tokens have a mass of 0.
        self.residue_masses = torch.zeros((len(self.model.residue_set),), dtype=self.float_dtype)
        terminal_residues_idx: list[int] = []
        suppressed_residues_idx: list[int] = []

        # residue_target_offsets supports negative masses (which overshoot the remaining mass).
        # This fixes a bug where the residue preceding a negative-mass residue was always invalid.
        residue_target_offsets: list[float] = [0.0]
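
        # A worked sketch of why these offsets matter, assuming a hypothetical
        # modification with mass -17.03 Da in the vocabulary: a beam whose
        # remaining mass budget sits at ~-17.03 Da (it has overshot zero) can
        # still finish exactly by appending that modification, so -17.03 is
        # recorded as an additional valid target for the remaining mass.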

        for i, residue in enumerate(model.residue_set.vocab):
            if residue in self.model.residue_set.special_tokens:
                continue
            self.residue_masses[i] = self.model.residue_set.get_mass(residue)
            # If the token does not start with a letter, no residue is attached:
            # assume it is an n-terminal modification.
            if not residue[0].isalpha():
                terminal_residues_idx.append(i)
            if self.residue_masses[i] < 0:
                residue_target_offsets.append(float(self.residue_masses[i]))

            # Check if the residue is suppressed
            if residue in suppressed_residues:
                suppressed_residues_idx.append(i)
                suppressed_residues.remove(residue)

        if len(suppressed_residues) > 0:
            logger.warning(f"Some suppressed residues were not found in the vocabulary: {suppressed_residues}")

        self.terminal_residue_indices = torch.tensor(terminal_residues_idx, dtype=torch.long)
        self.suppressed_residue_indices = torch.tensor(suppressed_residues_idx, dtype=torch.long)
        self.residue_target_offsets = torch.tensor(residue_target_offsets, dtype=self.float_dtype)

        self.vocab_size = len(self.model.residue_set)

    @classmethod
    def from_file(cls, model: Decodable, path: str, float_dtype: torch.dtype = torch.float64) -> KnapsackBeamSearchDecoder:
        """Initialize a decoder by loading a saved knapsack.

        Args:
            model (Decodable): The model to be decoded from.
            path (str): The path to the directory where the knapsack
                was saved to.
            float_dtype (torch.dtype): The floating point dtype to use.

        Returns:
            KnapsackBeamSearchDecoder: The decoder.
        """
        knapsack = Knapsack.from_file(path=path)
        return cls(model=model, knapsack=knapsack, float_dtype=float_dtype)
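
    # A minimal usage sketch (hypothetical `model` object and knapsack path;
    # the exact setup depends on the surrounding codebase):
    #
    #     decoder = KnapsackBeamSearchDecoder.from_file(model=model, path="knapsack_dir")
    #     results = decoder.decode(
    #         spectra, precursors, beam_size=5, max_length=40, mass_tolerance=5e-5
    #     )
    #     best = results["predictions"]  # best-beam residue tokens per spectrum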

    def _filter_knapsack(
        self,
        remaining_mass: torch.Tensor,
        mass_target_delta: torch.Tensor,
        mass_target_expanded: torch.Tensor,
    ) -> torch.Tensor:
        self.chart = self.chart.to(remaining_mass.device)
        # batch_size = remaining_mass.shape[0]
        vocab_size = self.chart.shape[1]

        # Step 1: Compute bounds
        lower_bound = (remaining_mass - mass_target_delta).unsqueeze(1)  # [batch_size, 1]
        upper_bound = (remaining_mass + mass_target_delta).unsqueeze(1)  # [batch_size, 1]

        lower_bound = (lower_bound * self.mass_scale).round().long()
        upper_bound = (upper_bound * self.mass_scale).round().long()

        # Step 2: Clamp to the valid chart index range
        lower_bound = lower_bound.clamp(min=0)
        upper_bound = upper_bound.clamp(min=0, max=self.chart.shape[0] - 1)

        # Step 3: Compute the maximum interval width
        max_span = (upper_bound - lower_bound + 1).max().item()
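
        # Worked example with illustrative numbers: for mass_scale = 10000,
        # remaining_mass = 1.0005 and mass_target_delta = 0.0005, the window is
        # lower = round(1.0000 * 10000) = 10000 and
        # upper = round(1.0010 * 10000) = 10010,
        # i.e. a span of 11 chart rows to test per batch row.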

        # Step 4: Build mass indices
        span_indices = torch.arange(max_span, device=remaining_mass.device).view(1, 1, -1)  # [1, 1, max_span]
        mass_indices = lower_bound.unsqueeze(-1) + span_indices  # [batch_size, 1, max_span]

        # Step 5: Clamp again to ensure indexing is valid (some upper < lower may happen)
        mass_indices = mass_indices.clamp(0, self.chart.shape[0] - 1)

        # Step 6: Create the span mask, keeping only offsets within each row's span
        span_lengths = (upper_bound - lower_bound + 1).unsqueeze(-1)  # [batch_size, 1, 1]
        span_mask = span_indices < span_lengths  # mask out padded indices

        # Step 7: Gather and mask
        # chart[mass_indices, residue] = chart[mass_indices[b, r, :], r]
        # => gather per (mass, vocab) with advanced indexing
        chart_vals = self.chart[
            mass_indices, torch.arange(vocab_size, device=remaining_mass.device).view(1, -1, 1)
        ]  # [batch_size, vocab_size, max_span]

        chart_vals = chart_vals & span_mask  # apply mask

        # Step 8: Reduce
        return chart_vals.any(dim=-1)  # [batch_size, vocab_size]

    def decode(  # type: ignore
        self,
        spectra: Float[Spectrum, " batch"],
        precursors: Float[PrecursorFeatures, " batch"],
        beam_size: int,
        max_length: int,
        mass_tolerance: float = 5e-5,
        max_isotope: int = 1,
        min_log_prob: float = -float("inf"),
        return_encoder_output: bool = False,
        encoder_output_reduction: Literal["mean", "max", "sum", "full"] = "mean",
        return_beam: bool = False,
        **kwargs,
    ) -> dict[str, Any]:
        """Decode predicted residue sequences for a batch of spectra using beam search.

        Args:
            spectra (torch.FloatTensor):
                The spectra to be sequenced.

            precursors (torch.FloatTensor[batch size, 3]):
                The precursor mass, charge and mass-to-charge ratio.

            beam_size (int):
                The maximum size of the beam.

            max_length (int):
                The maximum length of a residue sequence.

            mass_tolerance (float):
                The maximum relative error for which a predicted sequence
                is still considered to have matched the precursor mass.

            max_isotope (int):
                The maximum number of additional neutrons to consider for
                isotopes when comparing a predicted sequence's mass to the
                precursor mass.

                All additional nucleon numbers from 1 to `max_isotope` inclusive
                are considered.

            min_log_prob (float):
                Minimum log probability for early stopping. If a sequence's
                log probability falls below this value, it is marked as
                complete. Defaults to -inf.

            return_encoder_output (bool):
                Optionally include the (reduced) spectrum encoder output
                in the result.

            encoder_output_reduction (Literal["mean", "max", "sum", "full"]):
                How to reduce the encoder output along the sequence-length
                dimension.

            return_beam (bool):
                Optionally include per-beam predictions and probabilities in
                the result, in addition to the best beam.

        Returns:
            dict[str, Any]:
                A dictionary containing the predicted residue sequences under
                "predictions", their sequence log probabilities under
                "prediction_log_probability" and their per-token log
                probabilities under "prediction_token_log_probabilities",
                plus per-beam entries when `return_beam` is set. A spectrum
                for which decoding fails, i.e. no sequence that fits the
                precursor mass to within the tolerance is found, may receive
                an empty prediction.
        """

        # Beam search with a precursor mass termination condition
        batch_size = spectra.shape[0]
        effective_batch_size = batch_size * beam_size
        device = spectra.device

        # Masses of all residues in the vocabulary, 0 for special tokens
        self.residue_masses = self.residue_masses.to(spectra.device)  # float_dtype (vocab_size, )

        # ppm equivalent of the mass tolerance
        delta_ppm_tol = mass_tolerance * 10**6  # float (1, )

        # Residue masses expanded (repeated) across the effective batch size.
        # This is used to quickly compute all possible remaining masses per vocab entry.
        residue_mass_delta = self.residue_masses.expand(effective_batch_size, self.residue_masses.shape[0])  # float_dtype (effective_batch_size, vocab_size)

        # completed_items: list[list[ScoredSequence]] = [[] for _ in range(batch_size)]
        completed_beams: list[list[dict[str, Any]]] = [[] for _ in range(batch_size)]

        with torch.no_grad():
            # 1. Compute the spectrum encoding and masks.
            # The encoder is only run once.
            (spectrum_encoding, spectrum_mask), _ = self.model.init(spectra, precursors)

            # Expand for beam size
            spectrum_encoding_expanded = spectrum_encoding.repeat_interleave(beam_size, dim=0)
            spectrum_mask_expanded = spectrum_mask.repeat_interleave(beam_size, dim=0)

            # 2. Initialise beams and other variables.
            # The sequences decoded so far; grows on index 1 with every decoding pass.
            # sequence_length is variable!
            sequences = torch.zeros((effective_batch_size, 0), device=device, dtype=torch.long)  # long (effective_batch_size, sequence_length)

            # Log probabilities of the sequences decoded so far;
            # token probabilities are added at each step.
            log_probabilities = torch.zeros((effective_batch_size, 1), device=device, dtype=torch.float32)  # float32 (effective_batch_size, 1)

            # Keeps track of which beams are completed, allowing the model to skip them
            complete_beams = torch.zeros((effective_batch_size), device=device, dtype=bool)  # bool (effective_batch_size, )
            is_first_complete = torch.zeros((effective_batch_size), device=device, dtype=bool)  # bool (effective_batch_size, )

            # Extract the precursor mass from `precursors`
            precursors_expanded = precursors.repeat_interleave(beam_size, dim=0)

            precursor_mass = precursors_expanded[:, PrecursorDimension.PRECURSOR_MASS.value]  # float32 (effective_batch_size, )

            # Target mass delta: the remaining mass x must be within `target > x > -target`.
            # This target can shift with isotopes.
            # Mass target = error_ppm * m_prec * 1e-6
            mass_target_delta = delta_ppm_tol * precursor_mass.to(self.float_dtype) * 1e-6  # float_dtype (effective_batch_size, )
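
            # Worked example: mass_tolerance = 5e-5 is 50 ppm, so for a 1000 Da
            # precursor the absolute window is 50 * 1000 * 1e-6 = 0.05 Da, i.e.
            # the same 5e-5 relative tolerance expressed as an absolute delta.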

            # This keeps track of the remaining mass budget for the currently decoding sequence;
            # starts at the precursor mass minus H2O.
            remaining_mass = precursor_mass.to(self.float_dtype) - H2O_MASS  # float_dtype (effective_batch_size, )

            # TODO: only check when close to the precursor mass? Might not be worth the overhead.
            # The idea is that if remaining < check_zone, we do the valid-mass and complete checks.
            # check_zone = self.residue_masses.max().expand(batch_size) + mass_target_delta

            # Constant beam indices for retaining beams on failed decoding
            constant_beam_indices = torch.arange(beam_size, device=device)[None, :].repeat_interleave(batch_size, dim=0)

            # Store token probabilities, mapping sequence keys to per-token log probabilities
            token_log_probabilities: dict[str, list[float]] = {}

            # Start decoding
            for i in range(max_length):
                # If all beams are complete, we can stop early.
                if complete_beams.all():
                    break

                # Step 3: Score the next tokens.
                # NOTE: The SOS token is appended automatically in `score_candidates`;
                # we do not have to add it.
                batch = (sequences, precursors_expanded, spectrum_encoding_expanded, spectrum_mask_expanded)
                next_token_probabilities = self.model.score_candidates(*batch)

                # Step 4: Filter probabilities.
                # If the remaining mass is within tolerance, we force an EOS token.
                # All tokens that would set the remaining mass below the minimum
                # cutoff `-mass_target_delta`, including isotopes, are set to -inf.

                # Step 4.1: Check if the remaining mass is within tolerance.
                # To keep it efficient, we compute some of the indexed variables first:
                remaining_meets_precursor = torch.zeros((effective_batch_size,), device=device, dtype=bool)  # bool (effective_batch_size, )
                # This loop checks if the mass is within tolerance for 0 to max_isotope (inclusive)
                for j in range(max_isotope + 1):
                    # TODO: Use a vectorized approach for this
                    isotope = CARBON_MASS_DELTA * j  # float
                    remaining_lesser_isotope = remaining_mass - isotope < mass_target_delta  # bool (effective_batch_size, )
                    remaining_greater_isotope = remaining_mass - isotope > -mass_target_delta  # bool (effective_batch_size, )

                    # The remaining mass is within the target tolerance
                    remaining_within_range = remaining_lesser_isotope & remaining_greater_isotope  # bool (effective_batch_size, )
                    remaining_meets_precursor = remaining_meets_precursor | remaining_within_range  # bool (effective_batch_size, )
                    if remaining_within_range.any() and j > 0:
                        # If we did hit an isotope, correct the remaining mass accordingly
                        # TODO: check this
                        remaining_mass[remaining_within_range] = remaining_mass[remaining_within_range] - isotope
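
                # Worked example: CARBON_MASS_DELTA is ~1.00336 Da, so with
                # max_isotope = 1 a beam whose residual mass is ~1.0034 Da
                # (precursor carrying one 13C) also counts as matched, and its
                # residual is shifted down by one neutron mass above.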

                # Step 4.2: Check which residues are valid.
                # Expand the incomplete remaining mass across the vocabulary size
                remaining_mass_expanded = remaining_mass[:, None].expand(
                    effective_batch_size, self.vocab_size
                )  # float_dtype (effective_batch_size, vocab_size)
                mass_target_expanded = mass_target_delta[:, None].expand(
                    effective_batch_size, self.vocab_size
                )  # float_dtype (effective_batch_size, vocab_size)

                valid_mass = remaining_mass_expanded - residue_mass_delta > -mass_target_expanded  # bool (effective_batch_size, vocab_size)
                # Check all isotopes for valid masses
                # TODO: Use a vectorized approach for this
                for mass_offset in self.residue_target_offsets:
                    for j in range(max_isotope + 1):
                        isotope = CARBON_MASS_DELTA * j  # float
                        mass_lesser_isotope = (
                            remaining_mass_expanded - residue_mass_delta < mass_target_expanded + isotope + mass_offset
                        )  # bool (effective_batch_size, vocab_size)
                        mass_greater_isotope = (
                            remaining_mass_expanded - residue_mass_delta > -mass_target_expanded + isotope + mass_offset
                        )  # bool (effective_batch_size, vocab_size)
                        valid_mass = valid_mass | (mass_lesser_isotope & mass_greater_isotope)  # bool (effective_batch_size, vocab_size)

                # Filter using the knapsack
                knapsack_valid_mass = self._filter_knapsack(
                    remaining_mass=remaining_mass,
                    mass_target_delta=mass_target_delta,
                    mass_target_expanded=mass_target_expanded,
                )
                # knapsack_valid_mass = torch.zeros_like(valid_mass)
                # for mass_offset in self.residue_target_offsets:
                #     for j in range(0, max_isotope + 1, 1):
                #         isotope = CARBON_MASS_DELTA * j  # float
                #         knapsack_valid_mass = knapsack_valid_mass | self._filter_knapsack(
                #             remaining_mass=remaining_mass - isotope - mass_offset,
                #             mass_target_delta=mass_target_delta,
                #             mass_target_expanded=mass_target_expanded,
                #         )

                valid_mass = valid_mass & knapsack_valid_mass

                # Filtered probabilities:
                next_token_probabilities_filtered = next_token_probabilities.clone()  # float32 (effective_batch_size, vocab_size)
                # If the mass is invalid, set the log probability to -inf
                next_token_probabilities_filtered[~valid_mass] = -float("inf")

                # Disallow special tokens by default. EOS is re-enabled below for
                # beams that meet the precursor mass; PAD is assigned later, when
                # every candidate for a beam is -inf.
                next_token_probabilities_filtered[:, self.model.residue_set.EOS_INDEX] = -float("inf")
                next_token_probabilities_filtered[:, self.model.residue_set.PAD_INDEX] = -float("inf")
                next_token_probabilities_filtered[:, self.model.residue_set.SOS_INDEX] = -float("inf")
                next_token_probabilities_filtered[:, self.suppressed_residue_indices] = -float("inf")
                # Set the probability of n-terminal modifications to -inf unless
                # they would complete the sequence
                if self.disable_terminal_residues_anywhere:
                    # Check if adding terminal residues would result in a complete sequence.
                    # First generate the remaining mass matrix with isotopes
                    remaining_mass_isotope = remaining_mass[:, None].expand(effective_batch_size, max_isotope + 1) - CARBON_MASS_DELTA * (
                        torch.arange(max_isotope + 1, device=device)
                    )
                    # Expand with terminal residues and subtract
                    remaining_mass_isotope_delta = (
                        remaining_mass_isotope[:, :, None].expand(
                            effective_batch_size,
                            max_isotope + 1,
                            self.terminal_residue_indices.shape[0],
                        )
                        - self.residue_masses[self.terminal_residue_indices]
                    )

                    # If within the target delta, allow these residues to be predicted;
                    # otherwise set their probability to -inf
                    allow_terminal = (remaining_mass_isotope_delta.abs() < mass_target_delta[:, None, None]).any(dim=1)
                    allow_terminal_full = torch.ones(
                        (effective_batch_size, self.vocab_size),
                        device=spectra.device,
                        dtype=bool,
                    )
                    allow_terminal_full[:, self.terminal_residue_indices] = allow_terminal

                    # Set to -inf
                    next_token_probabilities_filtered[~allow_terminal_full] = -float("inf")

                # Record beams that completed on the previous step.
                # (This does not have to happen here; it could equally be done at the end.)
                if is_first_complete.any():
                    completed_idxs = is_first_complete.nonzero().squeeze(-1)
                    for beam_idx in completed_idxs:
                        sequence_probability = (
                            log_probabilities[beam_idx]  # + next_token_probabilities[beam_idx, self.model.residue_set.EOS_INDEX]
                        )
                        sequence_str = str((beam_idx // beam_size).item()) + "-" + ".".join([str(x) for x in sequences[beam_idx].cpu().tolist()])
                        sequence = self.model.decode(sequences[beam_idx])
                        seen_completed_sequences = {"".join(x["predictions"]) for x in completed_beams[beam_idx // beam_size]}
                        if "".join(sequence) in seen_completed_sequences:
                            continue
                        completed_beams[beam_idx // beam_size].append(
                            {
                                "predictions": sequence,
                                "mass_error": remaining_mass[beam_idx].item(),
                                "meets_precursor": remaining_meets_precursor[beam_idx].item(),
                                "prediction_log_probability": sequence_probability.item(),
                                "prediction_token_log_probabilities": token_log_probabilities[sequence_str][: len(sequence)][::-1],
                            }
                        )

                # For beams that already meet the precursor mass, -inf everything and force an EOS
                next_token_probabilities_filtered[remaining_meets_precursor, :] = -float("inf")
                if self.keep_invalid_mass_sequences:
                    # Also allow EOS on beams that don't fit the precursor mass
                    next_beam_no_predictions = next_token_probabilities_filtered.isinf().all(-1)
                    allow_eos = (remaining_meets_precursor | next_beam_no_predictions) & ~complete_beams
                else:
                    allow_eos = remaining_meets_precursor & ~complete_beams
                next_eos_probs = next_token_probabilities[allow_eos, self.model.residue_set.EOS_INDEX]
                next_token_probabilities_filtered[allow_eos, self.model.residue_set.EOS_INDEX] = next_eos_probs

                # Step 5: Select the next token
                log_probabilities_expanded = log_probabilities.repeat_interleave(self.vocab_size, dim=1)
                log_probabilities_expanded = log_probabilities_expanded + next_token_probabilities_filtered

                log_probabilities_beams = log_probabilities_expanded.view(-1, beam_size, self.vocab_size)
                if i == 0 and beam_size > 1:
                    # Nullify all beams except the first one;
                    # this forces divergence of the beams at the start
                    log_probabilities_beams[:, 1:] = -float("inf")

                log_probabilities_beams = log_probabilities_beams.view(-1, beam_size * self.vocab_size)

                topk_values, topk_indices = log_probabilities_beams.topk(beam_size, dim=-1)
                topk_is_inf = topk_values.isinf()
                beam_indices = topk_indices // self.vocab_size
                # Retain beams on failed decoding (when all beams are -inf)
                beam_indices[topk_is_inf] = constant_beam_indices[topk_is_inf]
                beam_indices_full = (beam_indices + torch.arange(batch_size, device=beam_indices.device)[:, None] * beam_size).view(-1)

                next_token = topk_indices % self.vocab_size
                next_token[topk_is_inf] = self.model.residue_set.PAD_INDEX
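
                # Worked example: with beam_size = 2 and vocab_size = 4, the
                # flattened scores have 8 columns per spectrum; a top-k index of
                # 6 decomposes into beam 6 // 4 = 1 and token 6 % 4 = 2. Beams
                # whose best score is -inf keep their previous index and emit
                # PAD instead.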

                next_token = next_token.view(-1, 1)  # long (effective_batch_size, 1)

                # Update the beams by indices
                sequences = sequences[beam_indices_full]
                log_probabilities = log_probabilities[beam_indices_full]
                next_token_probabilities = next_token_probabilities[beam_indices_full]
                remaining_mass = remaining_mass[beam_indices_full]
                complete_beams = complete_beams[beam_indices_full]

                sequences = torch.concat([sequences, next_token], dim=1)  # long (effective_batch_size, sequence_length)

                # Expand and update the masses
                next_masses = self.residue_masses[next_token].squeeze()  # float_dtype (effective_batch_size, )
                remaining_mass = remaining_mass - next_masses  # float_dtype (effective_batch_size, )

                # Expand and update the probabilities
                next_token_probabilities[:, self.model.residue_set.PAD_INDEX] = 0
                next_probabilities = torch.gather(next_token_probabilities, 1, next_token)
                next_probabilities[complete_beams] = 0

                log_probabilities = log_probabilities + next_probabilities

                for batch_index in range(effective_batch_size):
                    # Create a unique ID for the sequence and
                    # store the beam's token probabilities in a hash table
                    spectrum_index = batch_index // beam_size
                    sequence = [str(x) for x in sequences[batch_index].cpu().tolist()]
                    sequence_str = str(spectrum_index) + "-" + ".".join(sequence)
                    sequence_prev_str = str(spectrum_index) + "-" + ".".join(sequence[:-1])

                    if sequence_prev_str in token_log_probabilities:
                        previous_probabilities = list(token_log_probabilities[sequence_prev_str])
                    else:
                        previous_probabilities = []

                    previous_probabilities.append(next_probabilities[batch_index, 0].float().item())

                    token_log_probabilities[sequence_str] = previous_probabilities
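
                # Worked example: keys look like "0-3.7.12" (spectrum index,
                # then decoded token ids). A beam's history is recovered by
                # extending its parent key "0-3.7" with the new token, which
                # keeps per-token log probabilities correct even after beams
                # are reordered by top-k selection.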

                # Step 6: Terminate complete beams

                # Check if complete:
                # early stopping if the beam log probability is below the threshold
                beam_confidence_filter = log_probabilities[:, 0] < min_log_prob
                # Stop if a beam is forced to output an EOS
                next_token_is_eos = next_token[:, 0] == self.model.residue_set.EOS_INDEX
                next_token_is_pad = next_token[:, 0] == self.model.residue_set.PAD_INDEX
                next_is_complete = next_token_is_eos | beam_confidence_filter  # | next_token_is_pad

                # Keep track of which beams have completed.
                # Beams that complete for the first time are added to completed_beams.
                complete_beams = complete_beams | next_is_complete
                is_first_complete = next_is_complete

                if next_token_is_pad.all():
                    break

                # Repeat from step 3.

            # Check if any beams are complete at the end of the loop
            if is_first_complete.any():
                completed_idxs = is_first_complete.nonzero().squeeze(-1)
                for beam_idx in completed_idxs:
                    sequence_probability = (
                        log_probabilities[beam_idx]  # + next_token_probabilities[beam_idx,
                        # self.model.residue_set.EOS_INDEX]
                    )
                    sequence_str = str((beam_idx // beam_size).item()) + "-" + ".".join([str(x) for x in sequences[beam_idx].cpu().tolist()])
                    sequence = self.model.decode(sequences[beam_idx])
                    seen_completed_sequences = {"".join(x["predictions"]) for x in completed_beams[beam_idx // beam_size]}
                    if "".join(sequence) in seen_completed_sequences:
                        continue
                    completed_beams[beam_idx // beam_size].append(
                        {
                            "predictions": sequence,
                            "mass_error": remaining_mass[beam_idx].item(),
                            "meets_precursor": remaining_meets_precursor[beam_idx].item(),
                            "prediction_log_probability": sequence_probability.item(),
                            "prediction_token_log_probabilities": token_log_probabilities[sequence_str][: len(sequence)][::-1],
                        }
                    )

            # This loop forcefully adds all beams at the end, whether they are complete or not
            if self.keep_invalid_mass_sequences:
                for batch_idx in range(effective_batch_size):
                    sequence_str = str(batch_idx // beam_size) + "-" + ".".join([str(x) for x in sequences[batch_idx].cpu().tolist()])
                    sequence = self.model.decode(sequences[batch_idx])
                    seen_completed_sequences = {"".join(x["predictions"]) for x in completed_beams[batch_idx // beam_size]}
                    if "".join(sequence) in seen_completed_sequences:
                        # print(f"Skipping {sequence_str} because it is added")
                        continue
                    completed_beams[batch_idx // beam_size].append(
                        {
                            "predictions": sequence,
                            "mass_error": remaining_mass[batch_idx].item(),
                            "meets_precursor": remaining_meets_precursor[batch_idx].item(),
                            "prediction_log_probability": log_probabilities[batch_idx, 0].item(),
                            "prediction_token_log_probabilities": token_log_probabilities[sequence_str][: len(sequence)][::-1],
                        }
                    )

        # Get the top n beams per batch,
        # filtered on meets_precursor and log probability
        top_completed_beams = self._get_top_n_beams(completed_beams, beam_size)

        # Prepare the result dictionary
        result: dict[str, Any] = {
            "predictions": [],
            # "mass_error": [],
            "prediction_log_probability": [],
            "prediction_token_log_probabilities": [],
        }
        if return_beam:
            for i in range(beam_size):
                result[f"predictions_beam_{i}"] = []
                # result[f"mass_error_beam_{i}"] = []
                result[f"predictions_log_probability_beam_{i}"] = []
                result[f"predictions_token_log_probabilities_beam_{i}"] = []

        for batch_idx in range(batch_size):
            if return_beam:
                for beam_idx in range(beam_size):
                    result[f"predictions_beam_{beam_idx}"].append("".join(top_completed_beams[batch_idx][beam_idx]["predictions"]))
                    # result[f"mass_error_beam_{beam_idx}"].append(top_completed_beams[batch_idx][beam_idx]["mass_error"])
                    result[f"predictions_log_probability_beam_{beam_idx}"].append(
                        top_completed_beams[batch_idx][beam_idx]["prediction_log_probability"]
                    )
                    result[f"predictions_token_log_probabilities_beam_{beam_idx}"].append(
                        top_completed_beams[batch_idx][beam_idx]["prediction_token_log_probabilities"]
                    )

            # Save the best beam as the main result
            result["predictions"].append(top_completed_beams[batch_idx][0]["predictions"])
            # result["mass_error"].append(top_completed_beams[batch_idx][0]["mass_error"])
            result["prediction_log_probability"].append(top_completed_beams[batch_idx][0]["prediction_log_probability"])
            result["prediction_token_log_probabilities"].append(top_completed_beams[batch_idx][0]["prediction_token_log_probabilities"])

        # Optionally include the encoder output
        if return_encoder_output:
            # Reduce along the sequence-length dimension
            encoder_output = spectrum_encoding.float().cpu()
            encoder_mask = (1 - spectrum_mask.float()).cpu()
            encoder_output = encoder_output * encoder_mask.unsqueeze(-1)
            if encoder_output_reduction == "mean":
                count = encoder_mask.sum(dim=1).unsqueeze(-1).clamp(min=1)
                encoder_output = encoder_output.sum(dim=1) / count
            elif encoder_output_reduction == "max":
                encoder_output[encoder_output == 0] = -float("inf")
                encoder_output = encoder_output.max(dim=1)[0]
            elif encoder_output_reduction == "sum":
                encoder_output = encoder_output.sum(dim=1)
            elif encoder_output_reduction == "full":
                raise NotImplementedError("Full encoder output reduction is not yet implemented")
            else:
                raise ValueError(f"Invalid encoder output reduction: {encoder_output_reduction}")
            result["encoder_output"] = list(encoder_output.numpy())

        return result

    def _get_top_n_beams(self, completed_beams: list[list[dict[str, Any]]], beam_size: int) -> list[list[dict[str, Any]]]:
        """Get the top n beams from the completed beams.

        Args:
            completed_beams: The completed beams to select the top n beams from.
                Each beam is a dictionary with the following keys:
                - predictions: The predicted residue tokens of the beam.
                - mass_error: The mass error of the beam.
                - meets_precursor: Whether the beam meets the precursor mass.
                - prediction_log_probability: The log probability of the beam.
                - prediction_token_log_probabilities: The log probabilities of the tokens in the beam.
            beam_size: The number of beams to keep per batch.

        Returns:
            A list of lists, each containing the top n beams for a batch.
        """
        default_beam = {
            "predictions": [],
            "mass_error": -float("inf"),
            "meets_precursor": False,
            "prediction_log_probability": -float("inf"),
            "prediction_token_log_probabilities": [],
        }

        top_beams_per_row = []
        for beams in completed_beams:
            # Sort first by whether the precursor mass was matched,
            # then by log probability, both descending
            beams.sort(key=lambda x: (x["meets_precursor"], x["prediction_log_probability"]), reverse=True)
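
            # Worked example: with reverse=True the tuple keys sort as
            # (True, -1.0) > (True, -2.5) > (False, -0.1), so beams matching
            # the precursor mass always outrank those that do not, with ties
            # broken by higher log probability.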

            # Keep the top N beams
            top_beams = beams[:beam_size]

            # Pad with the default beam if there are fewer than N
            while len(top_beams) < beam_size:
                top_beams.append(default_beam.copy())

            top_beams_per_row.append(top_beams)

        return top_beams_per_row