Coverage for instanovo/inference/greedy_search.py: 85%

137 statements  

coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

from __future__ import annotations

from typing import Any, Literal

import torch
from jaxtyping import Float

from instanovo.__init__ import console
from instanovo.constants import CARBON_MASS_DELTA, H2O_MASS, MASS_SCALE, PrecursorDimension
from instanovo.inference.interfaces import Decodable, Decoder
from instanovo.types import PrecursorFeatures, Spectrum
from instanovo.utils.colorlogging import ColorLog

logger = ColorLog(console, __name__).logger


class GreedyDecoder(Decoder):
    """A class for decoding from de novo sequence models using greedy search.

    This class conforms to the `Decoder` interface and decodes from
    models that conform to the `Decodable` interface.
    """

    def __init__(
        self,
        model: Decodable,
        suppressed_residues: list[str] | None = None,
        mass_scale: int = MASS_SCALE,
        disable_terminal_residues_anywhere: bool = True,
        float_dtype: torch.dtype = torch.float64,
    ):
        super().__init__(model=model)
        self.mass_scale = mass_scale
        self.disable_terminal_residues_anywhere = disable_terminal_residues_anywhere
        self.float_dtype = float_dtype

        suppressed_residues = suppressed_residues or []

        # NOTE: Greedy search requires the model to provide a `residue_set` attribute;
        # update all methods accordingly.
        if not hasattr(model, "residue_set"):
            raise AttributeError("The model is missing the required attribute: residue_set")

        # TODO: Check if this can be replaced with model.get_residue_masses(mass_scale=10000)/10000
        # We would need to divide the scaled masses as we use floating point masses.
        # These residue masses are per amino acid and include special tokens;
        # special tokens have a mass of 0.
        self.residue_masses = torch.zeros((len(self.model.residue_set),), dtype=self.float_dtype)
        terminal_residues_idx: list[int] = []
        suppressed_residues_idx: list[int] = []

        # residue_target_offsets supports negative masses (overshoot the remaining mass).
        # This fixes a bug where the residue prior to a negative-mass residue is always invalid.
        residue_target_offsets: list[float] = [0.0]

        for i, residue in enumerate(model.residue_set.vocab):
            if residue in self.model.residue_set.special_tokens:
                continue
            self.residue_masses[i] = self.model.residue_set.get_mass(residue)
            # If the token does not start with a letter (no residue attached),
            # assume it is an N-terminal modification.
            if not residue[0].isalpha():
                terminal_residues_idx.append(i)
            if self.residue_masses[i] < 0:
                residue_target_offsets.append(self.residue_masses[i])

            # Check if residue is suppressed
            if residue in suppressed_residues:
                suppressed_residues_idx.append(i)
                suppressed_residues.remove(residue)

        if len(suppressed_residues) > 0:
            logger.warning(f"Some suppressed residues not found in vocabulary: {suppressed_residues}")

        self.terminal_residue_indices = torch.tensor(terminal_residues_idx, dtype=torch.long)
        self.suppressed_residue_indices = torch.tensor(suppressed_residues_idx, dtype=torch.long)
        self.residue_target_offsets = torch.tensor(residue_target_offsets, dtype=self.float_dtype)

        self.vocab_size = len(self.model.residue_set)

    def decode(  # type: ignore
        self,
        spectra: Float[Spectrum, " batch"],
        precursors: Float[PrecursorFeatures, " batch"],
        max_length: int,
        mass_tolerance: float = 5e-5,
        max_isotope: int = 1,
        min_log_prob: float = -float("inf"),
        return_encoder_output: bool = False,
        encoder_output_reduction: Literal["mean", "max", "sum", "full"] = "mean",
        **kwargs,
    ) -> dict[str, Any]:
        """Decode predicted residue sequences for a batch of spectra using greedy search.

        Args:
            spectra (torch.FloatTensor):
                The spectra to be sequenced.

            precursors (torch.FloatTensor[batch size, 3]):
                The precursor mass, charge and mass-to-charge ratio.

            max_length (int):
                The maximum length of a residue sequence.

            mass_tolerance (float):
                The maximum relative error for which a predicted sequence
                is still considered to have matched the precursor mass.

            max_isotope (int):
                The maximum number of additional neutrons to consider for
                isotope peaks when comparing a predicted sequence's mass
                to the precursor mass.

                All additional neutron counts from 1 to `max_isotope`
                inclusive are considered.

            min_log_prob (float):
                Minimum log probability to stop decoding early. If a sequence's
                log probability is less than this value it is marked as complete.
                Defaults to -inf.

            return_encoder_output (bool):
                Whether to return the encoder output.

            encoder_output_reduction (str):
                The reduction to apply to the encoder output.
                Valid values are "mean", "max", "sum", "full".
                Defaults to "mean".

        Returns:
            dict[str, Any]:
                Required keys:
                - "predictions": list[list[str]]
                - "mass_error": list[float]
                - "prediction_log_probability": list[float]
                - "prediction_token_log_probabilities": list[list[float]]
                - "encoder_output": list[float] (optional)
                Example additional keys:
                - "prediction_beam_0": list[str]
        """

        # Greedy search with precursor mass termination condition
        batch_size = spectra.shape[0]
        device = spectra.device

        # Masses of all residues in vocabulary, 0 for special tokens
        self.residue_masses = self.residue_masses.to(spectra.device)  # float_dtype (vocab_size,)

        # ppm equivalent of mass tolerance
        delta_ppm_tol = mass_tolerance * 10**6  # float scalar

        # Residue masses expanded (repeated) across batch_size.
        # This is used to quickly compute all possible remaining masses per vocab entry.
        residue_mass_delta = self.residue_masses.expand(batch_size, self.residue_masses.shape[0])  # float_dtype (batch_size, vocab_size)

        with torch.no_grad():
            # 1. Compute spectrum encoding and masks.
            # The encoder is only run once.
            (spectrum_encoding, spectrum_mask), _ = self.model.init(spectra, precursors)

            # 2. Initialise beams and other variables.
            # The sequences decoded so far; grows on index 1 with every decoding pass.
            # sequence_length is variable!
            sequences = torch.zeros((batch_size, 0), device=device, dtype=torch.long)  # long (batch_size, sequence_length)

            # Log probabilities of the sequences decoded so far;
            # token log probabilities are added at each step.
            log_probabilities = torch.zeros((batch_size, 1), device=device, dtype=torch.float32)  # float32 (batch_size, 1)

            # Keeps track of which beams are completed; this allows the model to skip these.
            complete_beams = torch.zeros((batch_size), device=device, dtype=bool)  # bool (batch_size,)

            # Keeps track of beams that stopped early or terminated with a bad stop condition.
            # These predictions will be deleted.
            # bad_stop_condition =
            #     torch.zeros((batch_size), device=device, dtype=bool)  # bool (batch_size,)

            # Extract precursor mass from `precursors`
            precursor_mass = precursors[:, PrecursorDimension.PRECURSOR_MASS.value]  # float32 (batch_size,)

            # Target mass delta; the remaining mass x must be within `target > x > -target`.
            # This target can shift with isotopes.
            # Mass target = error_ppm * m_prec * 1e-6
            mass_target_delta = delta_ppm_tol * precursor_mass.to(self.float_dtype) * 1e-6  # float_dtype (batch_size,)

            # This keeps track of the remaining mass budget for the currently decoding sequence;
            # it starts at the precursor mass minus H2O.
            remaining_mass = precursor_mass.to(self.float_dtype) - H2O_MASS  # float_dtype (batch_size,)

            # TODO: only check when close to precursor mass? Might not be worth the overhead.
            # Idea is if remaining < check_zone, we do the valid mass and complete checks.
            # check_zone = self.residue_masses.max().expand(batch_size) + mass_target_delta

            # Store token log probabilities
            token_log_probabilities = []  # list of float32 (batch_size,) tensors, one per step

            # Start decoding
            for _ in range(max_length):
                # If all beams are complete, we can stop early.
                if complete_beams.all():
                    break

                # We only run the model on incomplete beams. Note: we have to expand
                # to the full batch size afterwards.
                minibatch = (x[~complete_beams] for x in (sequences, precursors, spectrum_encoding, spectrum_mask))
                # Keep track of how large the minibatch is
                sub_batch_size = (~complete_beams).sum()

                # Step 3: score the next tokens.
                # NOTE: The SOS token is appended automatically in `score_candidates`;
                # we do not have to add it.
                next_token_probabilities = self.model.score_candidates(*minibatch)

                # Step 4: Filter probabilities.
                # If the remaining mass is within tolerance, we force an EOS token.
                # All tokens that would set the remaining mass below the minimum
                # cutoff `-mass_target_delta` (including isotopes) are set to -inf.

                # Step 4.1: Check if the remaining mass is within tolerance.
                # To keep it efficient we compute some of the indexed variables first:
                remaining_mass_incomplete = remaining_mass[~complete_beams]  # float64 (sub_batch_size,)
                mass_target_incomplete = mass_target_delta[~complete_beams]  # float64 (sub_batch_size,)

                # remaining_meets_precursor =
                #     (remaining_mass[~complete_beams] < mass_target_delta[~complete_beams])
                remaining_meets_precursor = torch.zeros((sub_batch_size,), device=device, dtype=bool)  # bool (sub_batch_size,)
                # This loop checks if the mass is within tolerance for 0 to max_isotope (inclusive).
                for j in range(0, max_isotope + 1, 1):
                    # TODO: Use vectorized approach for this (see the sketches after this listing)
                    isotope = CARBON_MASS_DELTA * j  # float
                    remaining_lesser_isotope = remaining_mass_incomplete - isotope < mass_target_incomplete  # bool (sub_batch_size,)
                    remaining_greater_isotope = remaining_mass_incomplete - isotope > -mass_target_incomplete  # bool (sub_batch_size,)

                    # remaining mass is within the target tolerance
                    remaining_within_range = remaining_lesser_isotope & remaining_greater_isotope  # bool (sub_batch_size,)
                    remaining_meets_precursor = remaining_meets_precursor | remaining_within_range  # bool (sub_batch_size,)
                    if remaining_within_range.any() and j > 0:
                        # If we did hit an isotope, correct the remaining mass accordingly.
                        # TODO check this
                        remaining_mass_incomplete[remaining_within_range] = remaining_mass_incomplete[remaining_within_range] - isotope

                # Step 4.2: Check which residues are valid.
                # Expand incomplete remaining mass across vocabulary size.
                remaining_mass_incomplete_expanded = remaining_mass_incomplete[:, None].expand(
                    sub_batch_size, self.vocab_size
                )  # float64 (sub_batch_size, vocab_size)
                mass_target_incomplete_expanded = mass_target_incomplete[:, None].expand(
                    sub_batch_size, self.vocab_size
                )  # float64 (sub_batch_size, vocab_size)
                residue_mass_delta_incomplete = residue_mass_delta[~complete_beams]  # float64 (sub_batch_size, vocab_size)

                valid_mass = (
                    remaining_mass_incomplete_expanded - residue_mass_delta_incomplete > -mass_target_incomplete_expanded
                )  # bool (sub_batch_size, vocab_size)
                # Check all isotopes for valid masses.
                # TODO: Use vectorized approach for this
                for mass_offset in self.residue_target_offsets:
                    for j in range(0, max_isotope + 1, 1):
                        isotope = CARBON_MASS_DELTA * j  # float
                        mass_lesser_isotope = (
                            remaining_mass_incomplete_expanded - residue_mass_delta_incomplete
                            < mass_target_incomplete_expanded + isotope + mass_offset
                        )  # bool (sub_batch_size, vocab_size)
                        mass_greater_isotope = (
                            remaining_mass_incomplete_expanded - residue_mass_delta_incomplete
                            > -mass_target_incomplete_expanded + isotope + mass_offset
                        )  # bool (sub_batch_size, vocab_size)
                        valid_mass = valid_mass | (mass_lesser_isotope & mass_greater_isotope)  # bool (sub_batch_size, vocab_size)

                # Filtered probabilities:
                next_token_probabilities_filtered = next_token_probabilities.clone()  # float32 (sub_batch_size, vocab_size)
                # If the mass is invalid, set the log probability to -inf.
                next_token_probabilities_filtered[~valid_mass] = -float("inf")
                next_token_probabilities_filtered[:, self.model.residue_set.EOS_INDEX] = -float("inf")
                # Allow the model to predict PAD when all residues are -inf
                # next_token_probabilities_filtered[
                #     :, self.model.residue_set.PAD_INDEX
                # ] = -float("inf")
                next_token_probabilities_filtered[:, self.model.residue_set.SOS_INDEX] = -float("inf")
                next_token_probabilities_filtered[:, self.suppressed_residue_indices] = -float("inf")
                # Disable N-terminal modifications anywhere in the sequence, unless
                # adding one would complete the sequence (match the precursor mass).
                if self.disable_terminal_residues_anywhere:
                    # Check if adding terminal residues would result in a complete sequence.
                    # First generate the remaining mass matrix with isotopes.
                    remaining_mass_incomplete_isotope = remaining_mass_incomplete[:, None].expand(
                        sub_batch_size, max_isotope + 1
                    ) - CARBON_MASS_DELTA * (torch.arange(max_isotope + 1, device=device))
                    # Expand with terminal residues and subtract.
                    remaining_mass_incomplete_isotope_delta = (
                        remaining_mass_incomplete_isotope[:, :, None].expand(
                            sub_batch_size,
                            max_isotope + 1,
                            self.terminal_residue_indices.shape[0],
                        )
                        - self.residue_masses[self.terminal_residue_indices]
                    )

                    # If within the target delta, allow these residues to be predicted;
                    # otherwise set their probability to -inf.
                    allow_terminal = (remaining_mass_incomplete_isotope_delta.abs() < mass_target_incomplete[:, None, None]).any(dim=1)
                    allow_terminal_full = torch.ones(
                        (sub_batch_size, self.vocab_size),
                        device=spectra.device,
                        dtype=bool,
                    )
                    allow_terminal_full[:, self.terminal_residue_indices] = allow_terminal

                    # Set to -inf
                    next_token_probabilities_filtered[~allow_terminal_full] = -float("inf")

                # Step 5: Select the next token.
                next_token = next_token_probabilities_filtered.argmax(-1).unsqueeze(1)  # long (sub_batch_size, 1)
                next_token[remaining_meets_precursor] = self.model.residue_set.EOS_INDEX

                # Update sequences
                next_token_full = torch.full(
                    (batch_size, 1),
                    fill_value=self.model.residue_set.PAD_INDEX,
                    device=spectra.device,
                    dtype=sequences.dtype,
                )  # long (batch_size, 1)
                next_token_full[~complete_beams] = next_token
                sequences = torch.concat([sequences, next_token_full], dim=1)  # long (batch_size, sequence_length)

                # Expand and update masses
                next_masses = self.residue_masses[next_token].squeeze(-1)  # float64 (sub_batch_size,)
                next_masses_full = torch.zeros((batch_size), device=spectra.device, dtype=remaining_mass.dtype)  # float64 (batch_size,)
                next_masses_full[~complete_beams] = next_masses
                remaining_mass = remaining_mass - next_masses_full  # float64 (batch_size,)

                # Expand and update probabilities
                next_probabilities = torch.gather(next_token_probabilities, 1, next_token)
                next_probabilities_full = torch.zeros(
                    (batch_size, 1),
                    device=spectra.device,
                    dtype=log_probabilities.dtype,
                )
                next_probabilities_full[~complete_beams] = next_probabilities
                log_probabilities = log_probabilities + next_probabilities_full
                token_log_probabilities.append(next_probabilities_full[:, 0])

                # Step 6: Terminate complete beams.

                # Check if complete:
                # Early stopping if the beam log probability is below the threshold.
                beam_confidence_filter = log_probabilities[~complete_beams, 0] < min_log_prob
                # Stop if the beam is forced to output an EOS.
                next_token_is_eos = next_token[:, 0] == self.model.residue_set.EOS_INDEX
                next_is_complete = next_token_is_eos | beam_confidence_filter

                # Check for a bad stop
                # bad_stop_condition = beam_confidence_filter
                # bad_stop_condition_full = torch.zeros((batch_size,),
                #     device=spectra.device, dtype=bad_stop_condition.dtype)
                # bad_stop_condition_full[~complete_beams] = bad_stop_condition
                # bad_stop_condition = bad_stop_condition | bad_stop_condition_full

                # Expand and update complete beams
                next_is_complete_full = torch.zeros((batch_size,), device=spectra.device, dtype=complete_beams.dtype)
                next_is_complete_full[~complete_beams] = next_is_complete
                complete_beams = complete_beams | next_is_complete_full

                # Repeat from step 3.

            all_log_probabilities = torch.stack(token_log_probabilities, dim=1)

            # Example of the new output format
            result: dict[str, Any] = {
                "predictions": [],
                # "mass_error": [],
                "prediction_log_probability": [],
                "prediction_token_log_probabilities": [],
            }

            for i in range(batch_size):
                sequence = self.model.decode(sequences[i])
                result["predictions"].append(sequence)
                # result["mass_error"].append(remaining_mass[i].item())
                result["prediction_log_probability"].append(log_probabilities[i, 0].item())
                # Token order is reversed to match the order of the decoded sequence.
                result["prediction_token_log_probabilities"].append(
                    [x.cpu().item() for x in all_log_probabilities[i, : len(sequence)]][::-1]
                )

            if return_encoder_output:
                # TODO: Refactor to use the same logic across all decoders
                # Reduce along the sequence length dimension
                # (a standalone sketch of these masked reductions appears after this listing)
                encoder_output = spectrum_encoding.float().cpu()
                encoder_mask = (1 - spectrum_mask.float()).cpu()
                encoder_output = encoder_output * encoder_mask.unsqueeze(-1)
                if encoder_output_reduction == "mean":
                    count = encoder_mask.sum(dim=1).unsqueeze(-1).clamp(min=1)
                    encoder_output = encoder_output.sum(dim=1) / count
                elif encoder_output_reduction == "max":
                    encoder_output[encoder_output == 0] = -float("inf")
                    encoder_output = encoder_output.max(dim=1)[0]
                elif encoder_output_reduction == "sum":
                    encoder_output = encoder_output.sum(dim=1)
                elif encoder_output_reduction == "full":
                    raise NotImplementedError("Full encoder output reduction is not yet implemented")
                else:
                    raise ValueError(f"Invalid encoder output reduction: {encoder_output_reduction}")
                result["encoder_output"] = list(encoder_output.numpy())

            return result
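
A minimal usage sketch, not part of the module. The `model`, `spectra`, and `precursors` objects and the suppressed-residue string are assumptions for illustration; the model must conform to the `Decodable` interface and expose a `residue_set` attribute, as enforced in `__init__` above.

# Hypothetical usage -- `model`, `spectra`, and `precursors` are placeholders.
decoder = GreedyDecoder(
    model=model,
    suppressed_residues=["C(+57.02)"],  # hypothetical residue token to forbid
)
results = decoder.decode(
    spectra=spectra,          # batched spectra tensor
    precursors=precursors,    # (batch_size, 3): mass, charge, m/z
    max_length=40,            # cap on the residue sequence length
    mass_tolerance=5e-5,      # 5e-5 relative error = 50 ppm
    max_isotope=1,            # also accept a +1 neutron isotope match
)
for peptide, log_prob in zip(results["predictions"], results["prediction_log_probability"]):
    print(peptide, log_prob)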
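
The termination test in step 4.1 reduces to a window check around zero remaining mass, optionally shifted by whole neutron masses. A self-contained sketch of that arithmetic, mirroring the loop above (the CARBON_MASS_DELTA value here is an assumed stand-in for the module's constant):

import torch

CARBON_MASS_DELTA = 1.00335  # approx. 13C - 12C mass difference in Da (assumed value)

def meets_precursor(remaining_mass: torch.Tensor, precursor_mass: torch.Tensor,
                    mass_tolerance: float = 5e-5, max_isotope: int = 1) -> torch.Tensor:
    # mass_target_delta = ppm_tol * m_prec * 1e-6, as computed in decode()
    target = (mass_tolerance * 1e6) * precursor_mass * 1e-6
    hit = torch.zeros_like(remaining_mass, dtype=torch.bool)
    for j in range(max_isotope + 1):
        shifted = remaining_mass - CARBON_MASS_DELTA * j
        hit = hit | ((shifted < target) & (shifted > -target))
    return hit

# e.g. a 1000 Da precursor at the default 5e-5 tolerance accepts any remaining
# mass within +/-0.05 Da of zero or of one neutron-mass shift.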
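
Both isotope loops carry a TODO to vectorize. One possible broadcast formulation of the step 4.1 check, as a sketch rather than the project's implementation (it collapses the two comparisons into an absolute-value test, which is equivalent for a symmetric open interval):

import torch

CARBON_MASS_DELTA = 1.00335  # assumed value, as above

def meets_precursor_vectorized(remaining_mass: torch.Tensor, precursor_mass: torch.Tensor,
                               mass_tolerance: float = 5e-5, max_isotope: int = 1) -> torch.Tensor:
    target = mass_tolerance * precursor_mass  # ppm_tol * m_prec * 1e-6 simplifies to this
    isotopes = CARBON_MASS_DELTA * torch.arange(max_isotope + 1, dtype=remaining_mass.dtype)
    # (batch_size, 1) - (1, max_isotope + 1) -> (batch_size, max_isotope + 1)
    shifted = remaining_mass[:, None] - isotopes[None, :]
    # Any isotope shift landing inside the +/-target window counts as a match.
    return (shifted.abs() < target[:, None]).any(dim=1)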
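
The encoder-output branch performs masked reductions over the sequence dimension; a standalone sketch of the same pattern, with illustrative tensor names:

import torch

def reduce_encoder_output(encoding: torch.Tensor, padding_mask: torch.Tensor,
                          reduction: str = "mean") -> torch.Tensor:
    """encoding: (batch, seq_len, dim); padding_mask: (batch, seq_len), 1 where padded."""
    valid = 1.0 - padding_mask.float()            # 1 for real positions, 0 for padding
    masked = encoding * valid.unsqueeze(-1)       # zero out padded positions
    if reduction == "mean":
        count = valid.sum(dim=1).unsqueeze(-1).clamp(min=1)
        return masked.sum(dim=1) / count          # average over valid positions only
    if reduction == "max":
        masked = masked.masked_fill(valid.unsqueeze(-1) == 0, -float("inf"))
        return masked.max(dim=1)[0]               # max over valid positions
    if reduction == "sum":
        return masked.sum(dim=1)
    raise ValueError(f"Invalid reduction: {reduction}")

Note the sketch masks via `masked_fill` on the padding mask, whereas the module's "max" branch tests for exact zeros; the mask-based variant also preserves genuine zero activations.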