Coverage for instanovo/transformer/model.py: 55% (264 statements)

coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

from __future__ import annotations

import json
import os
from importlib import resources
from pathlib import Path
from typing import Optional, Tuple
from urllib.parse import urlsplit

import torch
from jaxtyping import Bool, Float, Integer
from omegaconf import DictConfig, OmegaConf, open_dict
from torch import Tensor, nn

from instanovo.__init__ import console
from instanovo.constants import LEGACY_PTM_TO_UNIMOD, MAX_SEQUENCE_LENGTH
from instanovo.inference import Decodable
from instanovo.transformer.layers import (
    ConvPeakEmbedding,
    MultiScalePeakEmbedding,
    PositionalEncoding,
)
from instanovo.types import (
    DiscretizedMass,
    Peptide,
    PeptideMask,
    PrecursorFeatures,
    ResidueLogits,
    ResidueLogProbabilities,
    Spectrum,
    SpectrumEmbedding,
    SpectrumMask,
)
from instanovo.utils.colorlogging import ColorLog
from instanovo.utils.file_downloader import download_file
from instanovo.utils.residues import ResidueSet

MODEL_TYPE = "transformer"


logger = ColorLog(console, __name__).logger

class InstaNovo(nn.Module, Decodable):
    """The InstaNovo model."""

    def __init__(
        self,
        residue_set: ResidueSet,
        dim_model: int = 768,
        n_head: int = 16,
        dim_feedforward: int = 2048,
        encoder_layers: int = 9,
        decoder_layers: int = 9,
        dropout: float = 0.1,
        max_charge: int = 5,
        use_flash_attention: bool = False,
        conv_peak_encoder: bool = False,
        peak_embedding_dtype: torch.dtype | str = torch.float64,
    ) -> None:
        super().__init__()
        self._residue_set = residue_set
        self.vocab_size = len(residue_set)
        self.use_flash_attention = use_flash_attention
        self.conv_peak_encoder = conv_peak_encoder

        self.latent_spectrum = nn.Parameter(torch.randn(1, 1, dim_model))

        if self.use_flash_attention:
            # All input spectra are padded to some max length
            # Pad spectrum replaces zeros in input spectra
            # This is for flash attention (no masks allowed)
            self.pad_spectrum = nn.Parameter(torch.randn(1, 1, dim_model))

        # Encoder
        self.peak_encoder = MultiScalePeakEmbedding(dim_model, dropout=dropout, float_dtype=peak_embedding_dtype)
        if self.conv_peak_encoder:
            self.conv_encoder = ConvPeakEmbedding(dim_model, dropout=dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim_model,
            nhead=n_head,
            dim_feedforward=dim_feedforward,
            batch_first=True,
            dropout=0 if self.use_flash_attention else dropout,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=encoder_layers,
            # enable_nested_tensor=False, TODO: Figure out the correct way to handle this
        )

        # Decoder
        self.aa_embed = nn.Embedding(self.vocab_size, dim_model, padding_idx=0)

        self.aa_pos_embed = PositionalEncoding(dim_model, dropout, max_len=MAX_SEQUENCE_LENGTH)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=dim_model,
            nhead=n_head,
            dim_feedforward=dim_feedforward,
            batch_first=True,
            dropout=0 if self.use_flash_attention else dropout,
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer,
            num_layers=decoder_layers,
        )

        self.head = nn.Linear(dim_model, self.vocab_size)
        self.charge_encoder = nn.Embedding(max_charge, dim_model)

    @property
    def residue_set(self) -> ResidueSet:
        """Every model must have a `residue_set` attribute."""
        return self._residue_set

    @staticmethod
    def _get_causal_mask(seq_len: int, return_float: bool = False) -> PeptideMask:
        mask = (torch.triu(torch.ones(seq_len, seq_len)) == 1).transpose(0, 1)
        if return_float:
            return mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
        return ~mask.bool()
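    # Worked example: for seq_len=3, `_get_causal_mask(3)` returns a boolean
    # mask in which True marks positions the decoder may NOT attend to
    # (strictly future tokens):
    #     [[False,  True,  True],
    #      [False, False,  True],
    #      [False, False, False]]
    # With `return_float=True` the same pattern is returned as additive
    # 0.0 / -inf scores, the alternative mask convention accepted by
    # `nn.TransformerDecoder` for `tgt_mask`.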

    @staticmethod
    def get_pretrained() -> list[str]:
        """Get a list of pretrained model ids."""
        # Load the models.json file
        with resources.files("instanovo").joinpath("models.json").open("r", encoding="utf-8") as f:
            models_config = json.load(f)

        if MODEL_TYPE not in models_config:
            return []

        return list(models_config[MODEL_TYPE].keys())

    @classmethod
    def load(
        cls, path: str, update_residues_to_unimod: bool = True, override_config: DictConfig | dict | None = None
    ) -> tuple["InstaNovo", "DictConfig"]:
        """Load model from checkpoint path.

        Args:
            path (str): Path to checkpoint file.
            update_residues_to_unimod (bool): Update residues to unimod, defaults to True.
            override_config (DictConfig | dict | None): Optional override config values with a DictConfig or dict, defaults to None.

        Returns:
            tuple[InstaNovo, DictConfig]: Tuple of model and config.
        """
        # Add to allow list
        _whitelist_torch_omegaconf()
        ckpt = torch.load(path, map_location="cpu", weights_only=True)

        config = ckpt["config"]

        if override_config is not None:
            if not isinstance(config, DictConfig):
                config = OmegaConf.create(config)
            with open_dict(config):
                config.update(override_config)

        # TODO: Remove
        if "state_dict" not in ckpt:
            ckpt["state_dict"] = ckpt["model"]

        # check if PTL checkpoint
        if all(x.startswith("model") for x in ckpt["state_dict"].keys()):
            ckpt["state_dict"] = {k.replace("model.", ""): v for k, v in ckpt["state_dict"].items()}

        if "residues" not in ckpt:
            # Legacy format
            residues = dict(config["residues"])
        else:
            # TODO: Remove
            # residues = dict(ckpt["residues"].get("residues", {}))
            residues = ckpt["residues"]

        if update_residues_to_unimod:
            residues = {LEGACY_PTM_TO_UNIMOD[k] if k in LEGACY_PTM_TO_UNIMOD else k: v for k, v in residues.items()}

        residue_set = ResidueSet(residues)

        model = cls(
            residue_set=residue_set,
            dim_model=config["dim_model"],
            n_head=config["n_head"],
            dim_feedforward=config["dim_feedforward"],
            encoder_layers=config.get("encoder_layers", config.get("n_layers", 9)),
            decoder_layers=config.get("decoder_layers", config.get("n_layers", 9)),
            dropout=config["dropout"],
            max_charge=config["max_charge"],
            use_flash_attention=config.get("use_flash_attention", False),
            conv_peak_encoder=config.get("conv_peak_encoder", False),
            peak_embedding_dtype=config.get("peak_embedding_dtype", torch.float64),
        )
        model.load_state_dict(ckpt["state_dict"])

        return model, config
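    # Usage sketch for `load` (the checkpoint path below is hypothetical):
    #
    #     model, config = InstaNovo.load("checkpoints/instanovo.ckpt")
    #     model = model.eval()
    #
    # The checkpoint must contain a "config" entry plus a "state_dict" (or
    # legacy "model") entry; residues are read from the checkpoint, or from
    # the config for legacy checkpoints, and are remapped to UniMod notation
    # when `update_residues_to_unimod` is True.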

    @classmethod
    def from_pretrained(
        cls, model_id: str, update_residues_to_unimod: bool = True, override_config: DictConfig | dict | None = None
    ) -> tuple["InstaNovo", "DictConfig"]:
        """Download and load by model id or model path.

        Args:
            model_id (str): Model id or model path.
            update_residues_to_unimod (bool): Update residues to unimod, defaults to True.
            override_config (DictConfig | dict | None): Optional override config values with a DictConfig or dict, defaults to None.

        Returns:
            tuple[InstaNovo, DictConfig]: Tuple of model and config.
        """
        # TODO Refactor to use across methods
        # Check if model_id is a local file path
        if "/" in model_id or "\\" in model_id or model_id.endswith(".ckpt"):
            if os.path.isfile(model_id):
                return cls.load(model_id, update_residues_to_unimod=update_residues_to_unimod, override_config=override_config)
            else:
                raise FileNotFoundError(f"No file found at path: {model_id}")

        # Load the models.json file
        with resources.files("instanovo").joinpath("models.json").open("r", encoding="utf-8") as f:
            models_config = json.load(f)

        # Find the model in the config
        if MODEL_TYPE not in models_config or model_id not in models_config[MODEL_TYPE]:
            raise ValueError(f"Model {model_id} not found in models.json, options are [{', '.join(models_config[MODEL_TYPE].keys())}]")

        model_info = models_config[MODEL_TYPE][model_id]
        url = model_info["remote"]

        # Create cache directory if it doesn't exist
        cache_dir = Path.home() / ".cache" / "instanovo"
        cache_dir.mkdir(parents=True, exist_ok=True)

        # Generate a filename for the cached model
        file_name = urlsplit(url).path.split("/")[-1]
        cached_file = cache_dir / file_name

        # Check if the file is already cached
        if not cached_file.exists():
            download_file(url, cached_file, model_id, file_name)
        else:
            logger.info(f"Model {model_id} already cached at {cached_file}")

        try:
            # Load and return the model
            logger.info(f"Loading model {model_id} (remote)")
            return cls.load(str(cached_file), update_residues_to_unimod=update_residues_to_unimod, override_config=override_config)
        except Exception as e:
            logger.warning(f"Failed to load cached model {model_id}, it may be corrupted. Deleting and re-downloading. Error: {e}")
            if cached_file.exists():
                cached_file.unlink()

            download_file(url, cached_file, model_id, file_name)
            logger.info(f"Loading newly downloaded model {model_id}")
            return cls.load(str(cached_file), update_residues_to_unimod=update_residues_to_unimod, override_config=override_config)
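    # Usage sketch for `from_pretrained`; the model id is illustrative, use
    # `InstaNovo.get_pretrained()` to list the ids actually shipped in
    # models.json:
    #
    #     ids = InstaNovo.get_pretrained()
    #     model, config = InstaNovo.from_pretrained(ids[0])
    #
    # Checkpoints are cached under ~/.cache/instanovo; a cached file that
    # fails to load is deleted and downloaded again.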

    def forward(
        self,
        x: Float[Spectrum, " batch"],
        p: Float[PrecursorFeatures, " batch"],
        y: Integer[Peptide, " batch"],
        x_mask: Optional[Bool[SpectrumMask, " batch"]] = None,
        y_mask: Optional[Bool[PeptideMask, " batch"]] = None,
        add_bos: bool = True,
        return_encoder_output: bool = False,
    ) -> Float[ResidueLogits, "batch token+1"]:
        """Model forward pass.

        Args:
            x: Spectra, float Tensor (batch, n_peaks, 2)
            p: Precursors, float Tensor (batch, 3)
            y: Peptide token indices, long Tensor (batch, seq_len)
            x_mask: Spectra padding mask, True for padded indices, bool Tensor (batch, n_peaks)
            y_mask: Peptide padding mask, bool Tensor (batch, seq_len)
            add_bos: Force add a <s> prefix to y, bool
            return_encoder_output: Also return the encoder output, bool

        Returns:
            logits: float Tensor (batch, n, vocab_size),
                (batch, n+1, vocab_size) if add_bos==True.
                If return_encoder_output is True, a tuple of (logits, encoder output)
                is returned instead (non-flash path only).
        """
        if self.use_flash_attention:
            x, x_mask = self._flash_encoder(x, p, x_mask)
            return self._flash_decoder(x, y, x_mask, y_mask, add_bos)

        x, x_mask = self._encoder(x, p, x_mask)
        y = self._decoder(x, y, x_mask, y_mask, add_bos)
        if return_encoder_output:
            return y, x
        return y
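    # Shape sketch for `forward` (illustrative sizes; random values only show
    # tensor layouts and will not produce meaningful predictions):
    #
    #     x = torch.rand(2, 100, 2)                        # (batch, n_peaks, 2) peak features
    #     p = torch.tensor([[800.0, 2.0, 401.0],
    #                       [900.0, 3.0, 301.0]])          # (batch, 3); columns 0 and 1 are
    #                                                      # precursor mass and charge (see `_encoder`)
    #     y = torch.randint(1, model.vocab_size, (2, 20))  # (batch, seq_len) token indices
    #     logits = model(x, p, y)                          # (batch, seq_len + 1, vocab_size) with add_bos=True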

    def init(
        self,
        spectra: Float[Spectrum, " batch"],
        precursors: Float[PrecursorFeatures, " batch"],
        spectra_mask: Optional[Bool[SpectrumMask, " batch"]] = None,
    ) -> Tuple[
        Tuple[Float[Spectrum, " batch"], Bool[SpectrumMask, " batch"]],
        Float[ResidueLogProbabilities, "batch token"],
    ]:
        """Initialise model encoder."""
        if self.use_flash_attention:
            spectra, _ = self._encoder(spectra, precursors, None)
            logits = self._decoder(spectra, None, None, None, add_bos=False)
            return (
                spectra,
                torch.zeros(spectra.shape[0], spectra.shape[1]).to(spectra.device),
            ), torch.log_softmax(logits[:, -1, :], -1)

        spectra, spectra_mask = self._encoder(spectra, precursors, spectra_mask)
        logits = self._decoder(spectra, None, spectra_mask, None, add_bos=False)
        return (spectra, spectra_mask), torch.log_softmax(logits[:, -1, :], -1)

    def score_candidates(
        self,
        sequences: Integer[Peptide, " batch"],
        precursor_mass_charge: Float[PrecursorFeatures, " batch"],
        spectra: Float[Spectrum, " batch"],
        spectra_mask: Bool[SpectrumMask, " batch"],
    ) -> Float[ResidueLogProbabilities, "batch token"]:
        """Score a set of candidate sequences."""
        if self.use_flash_attention:
            logits = self._flash_decoder(spectra, sequences, None, None, add_bos=True)
        else:
            logits = self._decoder(spectra, sequences, spectra_mask, None, add_bos=True)

        return torch.log_softmax(logits[:, -1, :], -1)

    def get_residue_masses(self, mass_scale: int) -> Integer[DiscretizedMass, " residue"]:
        """Get the scaled masses of all residues."""
        residue_masses = torch.zeros(len(self.residue_set), dtype=torch.int64)
        for index, residue in self.residue_set.index_to_residue.items():
            if residue in self.residue_set.residue_masses:
                residue_masses[index] = round(mass_scale * self.residue_set.get_mass(residue))
        return residue_masses

    def get_eos_index(self) -> int:
        """Get the EOS token ID."""
        return int(self.residue_set.EOS_INDEX)

    def get_empty_index(self) -> int:
        """Get the PAD token ID."""
        return int(self.residue_set.PAD_INDEX)

    def decode(self, sequence: Peptide) -> list[str]:
        """Decode a single sequence of AA IDs."""
        # Note: Sequence is reversed as InstaNovo predicts right-to-left.
        # We reverse the sequence again when decoding to ensure
        # the decoder outputs forward sequences.
        return self.residue_set.decode(sequence, reverse=True)  # type: ignore

    def idx_to_aa(self, idx: Peptide) -> list[str]:
        """Decode a single sample of indices to aa list."""
        idx = idx.cpu().numpy()
        t = []
        for i in idx:
            if i == self.eos_id:
                break
            if i == self.bos_id or i == self.pad_id:
                continue
            t.append(i)
        return [self.i2s[x.item()] for x in t]

    def batch_idx_to_aa(self, idx: Integer[Peptide, " batch"], reverse: bool) -> list[list[str]]:
        """Decode a batch of indices to aa lists."""
        return [self.residue_set.decode(i, reverse=reverse) for i in idx]

    def score_sequences(
        self,
        peptides: Integer[Peptide, " batch"] | list[str] | list[list[str]],
        peptides_mask: Bool[PeptideMask, " batch"] | None = None,
        spectra: Float[Spectrum, " batch"] | None = None,
        precursors: Float[PrecursorFeatures, " batch"] | None = None,
        spectra_mask: Bool[SpectrumMask, " batch"] | None = None,
        spectra_embedding: Float[SpectrumEmbedding, " batch"] | None = None,
        max_batch_size: int = 256,
    ) -> Float[ResidueLogProbabilities, "batch token"]:
        """Score a set of peptides."""
        if (spectra is None and precursors is None) and spectra_embedding is None:
            raise ValueError("Either spectra and precursors or spectra_embedding must be provided")

        if not isinstance(peptides, Tensor):
            peptides = [
                self.residue_set.encode(
                    self.residue_set.tokenize(x)[::-1],  # type: ignore # ensure reversed
                    add_eos=True,
                    return_tensor="pt",
                )
                for x in peptides
            ]

            ll = torch.tensor([x.shape[0] for x in peptides], dtype=torch.long)  # type: ignore
            peptides = nn.utils.rnn.pad_sequence(peptides, batch_first=True)
            peptides_mask = (
                torch.arange(peptides.shape[1], dtype=torch.long)[None, :] >= ll[:, None]  # type: ignore
            )

        device = spectra.device if spectra is not None else spectra_embedding.device  # type: ignore

        peptides = peptides.to(device)
        peptides_mask = peptides_mask.to(device)

        # Automatically handle batching if the number of peptides is too large
        if peptides.shape[0] > max_batch_size:
            sequence_scores = []
            for i in range(0, peptides.shape[0], max_batch_size):
                sub_batch = (
                    x[i : i + max_batch_size] if x is not None else None
                    for x in (
                        peptides,
                        peptides_mask,
                        spectra,
                        precursors,
                        spectra_mask,
                        spectra_embedding,
                    )
                )
                sequence_scores.append(self.score_sequences(*sub_batch))  # type: ignore
            return torch.cat(sequence_scores, dim=0)

        with torch.no_grad():
            if spectra_embedding is None:
                if self.use_flash_attention:
                    spectra_embedding, spectra_mask = self._flash_encoder(spectra, precursors, spectra_mask)
                else:
                    spectra_embedding, spectra_mask = self._encoder(spectra, precursors, spectra_mask)

            if self.use_flash_attention:
                logits = self._flash_decoder(spectra_embedding, peptides, spectra_mask, peptides_mask, add_bos=True)
            else:
                logits = self._decoder(spectra_embedding, peptides, spectra_mask, peptides_mask, add_bos=True)

            # Get log probabilities for all positions
            log_probs = torch.log_softmax(logits, -1)

            # Gather log probabilities for each token in the sequence
            sequence_log_prob = torch.gather(log_probs, -1, peptides.unsqueeze(-1)).squeeze(-1)

            # Zero out masked positions
            if peptides_mask is not None:
                sequence_log_prob = sequence_log_prob.masked_fill(peptides_mask, 0.0)

            # Sum log probabilities across sequence length
            sequence_log_prob = sequence_log_prob.sum(dim=-1)

        return sequence_log_prob.cpu()
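    # Usage sketch for `score_sequences`; the peptide strings are illustrative
    # and must only contain residues known to the model's ResidueSet:
    #
    #     scores = model.score_sequences(
    #         peptides=["PEPTIDEK", "ACDEFGHIK"],
    #         spectra=spectra,        # (batch, n_peaks, 2), one spectrum per peptide
    #         precursors=precursors,  # (batch, 3)
    #     )
    #
    # Each entry of `scores` is the summed log-probability of the matching
    # peptide given its spectrum, returned on the CPU. Alternatively, a
    # precomputed `spectra_embedding` can be passed instead of `spectra` and
    # `precursors`.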

    def _encoder(
        self,
        x: Float[Spectrum, " batch"],
        p: Float[PrecursorFeatures, " batch"] | None = None,
        x_mask: Optional[Bool[SpectrumMask, " batch"]] = None,
    ) -> Tuple[Float[SpectrumEmbedding, " batch"], Bool[SpectrumMask, " batch"]]:
        if self.conv_peak_encoder:
            x = self.conv_encoder(x)
            x_mask = torch.zeros((x.shape[0], x.shape[1]), device=x.device).bool()
        else:
            if x_mask is None:
                x_mask = ~x.sum(dim=2).bool()
            x = self.peak_encoder(x)

        # Self-attention on latent spectra AND peaks
        latent_spectra = self.latent_spectrum.expand(x.shape[0], -1, -1)
        x = torch.cat([latent_spectra, x], dim=1)
        latent_mask = torch.zeros((x_mask.shape[0], 1), dtype=bool, device=x_mask.device)
        x_mask = torch.cat([latent_mask, x_mask], dim=1)

        x = self.encoder(x, src_key_padding_mask=x_mask)

        # Prepare precursors
        if p is not None:
            masses = self.peak_encoder.encode_mass(p[:, None, [0]])
            charges = self.charge_encoder(p[:, 1].int() - 1)
            precursors = masses + charges[:, None, :]

            # Concatenate precursors
            x = torch.cat([precursors, x], dim=1)
            prec_mask = torch.zeros((x_mask.shape[0], 1), dtype=bool, device=x_mask.device)
            x_mask = torch.cat([prec_mask, x_mask], dim=1)

        return x, x_mask
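    # Shape flow through `_encoder` (d = dim_model):
    #     input spectra x      (batch, n_peaks, 2)
    #     peak embedding       (batch, n_peaks, d)
    #     + latent spectrum    (batch, n_peaks + 1, d)
    #     + precursor token    (batch, n_peaks + 2, d)   when `p` is provided
    # The returned `x_mask` is True at padded peak positions; the latent and
    # precursor positions are always unmasked.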

    def _decoder(
        self,
        x: Float[Spectrum, " batch"],
        y: Integer[Peptide, " batch"],
        x_mask: Bool[SpectrumMask, " batch"],
        y_mask: Optional[Bool[PeptideMask, " batch"]] = None,
        add_bos: bool = True,
    ) -> Float[ResidueLogits, " batch"]:
        if y is None:
            y = torch.full((x.shape[0], 1), self.residue_set.SOS_INDEX, device=x.device)
        elif add_bos:
            bos = torch.ones((y.shape[0], 1), dtype=y.dtype, device=y.device) * self.residue_set.SOS_INDEX
            y = torch.cat([bos, y], dim=1)

            if y_mask is not None:
                bos_mask = torch.zeros((y_mask.shape[0], 1), dtype=bool, device=y_mask.device)
                y_mask = torch.cat([bos_mask, y_mask], dim=1)

        y = self.aa_embed(y)
        if y_mask is None:
            y_mask = ~y.sum(axis=2).bool()

        # Add positional encoding
        y = self.aa_pos_embed(y)

        c_mask = self._get_causal_mask(y.shape[1]).to(y.device)

        y_hat = self.decoder(
            y,
            x,
            tgt_mask=c_mask,
            tgt_key_padding_mask=y_mask,
            memory_key_padding_mask=x_mask,
        )

        return self.head(y_hat)

    def _flash_encoder(self, x: Tensor, p: Tensor, x_mask: Tensor = None) -> tuple[Tensor, Tensor]:
        # Special mask for zero-indices
        # One is padded, zero is normal
        x_mask = (~x.sum(dim=2).bool()).float()

        x = self.peak_encoder(x[:, :, [0]], x[:, :, [1]])
        pad_spectrum = self.pad_spectrum.expand(x.shape[0], x.shape[1], -1)

        # torch.compile doesn't allow dynamic sizes (returned by mask indexing)
        # x[x_mask] = pad_spectrum[x_mask].to(x.dtype)
        x = x * (1 - x_mask[:, :, None]) + pad_spectrum * (x_mask[:, :, None])

        # Self-attention on latent spectra AND peaks
        latent_spectra = self.latent_spectrum.expand(x.shape[0], -1, -1)
        x = torch.cat([latent_spectra, x], dim=1).contiguous()

        try:
            from torch.nn.attention import SDPBackend, sdpa_kernel
        except ImportError:
            raise ImportError(
                "Training InstaNovo with Flash attention enabled requires at least pytorch v2.3. Please upgrade your pytorch version"
            ) from None

        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            x = self.encoder(x)

        # Prepare precursors
        masses = self.peak_encoder.encode_mass(p[:, None, [0]])
        charges = self.charge_encoder(p[:, 1].int() - 1)
        precursors = masses + charges[:, None, :]

        # Concatenate precursors
        x = torch.cat([precursors, x], dim=1).contiguous()

        return x, None

    def _flash_decoder(
        self,
        x: Tensor,
        y: Tensor,
        x_mask: Tensor,
        y_mask: Tensor = None,
        add_bos: bool = True,
    ) -> Tensor:
        if y is None:
            y = torch.full((x.shape[0], 1), self.residue_set.SOS_INDEX, device=x.device)
        elif add_bos:
            bos = torch.ones((y.shape[0], 1), dtype=y.dtype, device=y.device) * self.residue_set.SOS_INDEX
            y = torch.cat([bos, y], dim=1)

        y = self.aa_embed(y)

        # Add positional encoding
        y = self.aa_pos_embed(y)

        c_mask = self._get_causal_mask(y.shape[1]).to(y.device)

        try:
            from torch.nn.attention import SDPBackend, sdpa_kernel
        except ImportError:
            raise ImportError(
                "Training InstaNovo with Flash attention enabled requires at least pytorch v2.3. Please upgrade your pytorch version"
            ) from None

        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            y_hat = self.decoder(y, x, tgt_mask=c_mask)

        return self.head(y_hat)

def _whitelist_torch_omegaconf() -> None:
    """Whitelist specific modules for loading configs from checkpoints."""
    # This is done to safeguard against arbitrary code execution from checkpoints.
    from collections import defaultdict
    from typing import Any

    from omegaconf.base import ContainerMetadata, Metadata
    from omegaconf.listconfig import ListConfig
    from omegaconf.nodes import AnyNode

    torch.serialization.add_safe_globals(
        [
            DictConfig,
            ContainerMetadata,
            Metadata,
            ListConfig,
            AnyNode,
            Any,  # Only used for type hinting in omegaconf.
            defaultdict,
            dict,
            list,
            int,
        ]
    )
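# Note on `_whitelist_torch_omegaconf`: `InstaNovo.load` calls `torch.load`
# with `weights_only=True`, which refuses to unpickle classes that are not on
# the allow list maintained via `torch.serialization.add_safe_globals`. The
# entries above cover the OmegaConf containers stored under the checkpoint's
# "config" key; a checkpoint whose config pickles additional types (an
# assumption, depending on how it was written) would need those types added
# here before it can be loaded.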