Coverage for instanovo/transformer/model.py: 55% (264 statements)
coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
from __future__ import annotations

import json
import os
from importlib import resources
from pathlib import Path
from typing import Optional, Tuple
from urllib.parse import urlsplit

import torch
from jaxtyping import Bool, Float, Integer
from omegaconf import DictConfig, OmegaConf, open_dict
from torch import Tensor, nn

from instanovo.__init__ import console
from instanovo.constants import LEGACY_PTM_TO_UNIMOD, MAX_SEQUENCE_LENGTH
from instanovo.inference import Decodable
from instanovo.transformer.layers import (
    ConvPeakEmbedding,
    MultiScalePeakEmbedding,
    PositionalEncoding,
)
from instanovo.types import (
    DiscretizedMass,
    Peptide,
    PeptideMask,
    PrecursorFeatures,
    ResidueLogits,
    ResidueLogProbabilities,
    Spectrum,
    SpectrumEmbedding,
    SpectrumMask,
)
from instanovo.utils.colorlogging import ColorLog
from instanovo.utils.file_downloader import download_file
from instanovo.utils.residues import ResidueSet

MODEL_TYPE = "transformer"

logger = ColorLog(console, __name__).logger


class InstaNovo(nn.Module, Decodable):
    """The InstaNovo model."""

    def __init__(
        self,
        residue_set: ResidueSet,
        dim_model: int = 768,
        n_head: int = 16,
        dim_feedforward: int = 2048,
        encoder_layers: int = 9,
        decoder_layers: int = 9,
        dropout: float = 0.1,
        max_charge: int = 5,
        use_flash_attention: bool = False,
        conv_peak_encoder: bool = False,
        peak_embedding_dtype: torch.dtype | str = torch.float64,
    ) -> None:
        super().__init__()
        self._residue_set = residue_set
        self.vocab_size = len(residue_set)
        self.use_flash_attention = use_flash_attention
        self.conv_peak_encoder = conv_peak_encoder

        self.latent_spectrum = nn.Parameter(torch.randn(1, 1, dim_model))

        if self.use_flash_attention:
            # All input spectra are padded to some max length.
            # The pad spectrum embedding replaces zeros in the input spectra;
            # this is required for flash attention (no masks allowed).
            self.pad_spectrum = nn.Parameter(torch.randn(1, 1, dim_model))

        # Encoder
        self.peak_encoder = MultiScalePeakEmbedding(dim_model, dropout=dropout, float_dtype=peak_embedding_dtype)
        if self.conv_peak_encoder:
            self.conv_encoder = ConvPeakEmbedding(dim_model, dropout=dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim_model,
            nhead=n_head,
            dim_feedforward=dim_feedforward,
            batch_first=True,
            dropout=0 if self.use_flash_attention else dropout,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=encoder_layers,
            # enable_nested_tensor=False, TODO: Figure out the correct way to handle this
        )

        # Decoder
        self.aa_embed = nn.Embedding(self.vocab_size, dim_model, padding_idx=0)

        self.aa_pos_embed = PositionalEncoding(dim_model, dropout, max_len=MAX_SEQUENCE_LENGTH)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=dim_model,
            nhead=n_head,
            dim_feedforward=dim_feedforward,
            batch_first=True,
            dropout=0 if self.use_flash_attention else dropout,
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer,
            num_layers=decoder_layers,
        )

        self.head = nn.Linear(dim_model, self.vocab_size)
        self.charge_encoder = nn.Embedding(max_charge, dim_model)
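
    # Usage sketch (illustrative, not part of the original file): constructing a small model
    # directly from a ResidueSet. The two-residue dictionary is a hypothetical minimal
    # vocabulary; pretrained checkpoints define the full residue set and hyperparameters.
    #
    #   residue_set = ResidueSet({"G": 57.021464, "A": 71.037114})
    #   model = InstaNovo(residue_set, dim_model=256, n_head=8, encoder_layers=2, decoder_layers=2)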

    @property
    def residue_set(self) -> ResidueSet:
        """Every model must have a `residue_set` attribute."""
        return self._residue_set

    @staticmethod
    def _get_causal_mask(seq_len: int, return_float: bool = False) -> PeptideMask:
        mask = (torch.triu(torch.ones(seq_len, seq_len)) == 1).transpose(0, 1)
        if return_float:
            return mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
        return ~mask.bool()
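
    # Worked example (illustrative): for seq_len=3 the boolean mask returned above is
    #
    #   [[False,  True,  True],
    #    [False, False,  True],
    #    [False, False, False]]
    #
    # i.e. True marks positions a token may NOT attend to, so each token only sees itself
    # and earlier positions. With return_float=True the same pattern is expressed as
    # additive 0.0 / -inf attention scores.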

    @staticmethod
    def get_pretrained() -> list[str]:
        """Get a list of pretrained model ids."""
        # Load the models.json file
        with resources.files("instanovo").joinpath("models.json").open("r", encoding="utf-8") as f:
            models_config = json.load(f)

        if MODEL_TYPE not in models_config:
            return []

        return list(models_config[MODEL_TYPE].keys())
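
    # Usage sketch (illustrative): listing the checkpoint ids available for download.
    #
    #   available = InstaNovo.get_pretrained()
    #   print(available)  # model ids defined under "transformer" in instanovo/models.json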

    @classmethod
    def load(
        cls, path: str, update_residues_to_unimod: bool = True, override_config: DictConfig | dict | None = None
    ) -> tuple["InstaNovo", "DictConfig"]:
        """Load model from checkpoint path.

        Args:
            path (str): Path to checkpoint file.
            update_residues_to_unimod (bool): Update residues to unimod, defaults to True.
            override_config (DictConfig | dict | None): Optional override config values with a DictConfig or dict, defaults to None.

        Returns:
            tuple[InstaNovo, DictConfig]: Tuple of model and config.
        """
        # Add OmegaConf types to the torch safe-loading allow list
        _whitelist_torch_omegaconf()
        ckpt = torch.load(path, map_location="cpu", weights_only=True)

        config = ckpt["config"]

        if override_config is not None:
            if not isinstance(config, DictConfig):
                config = OmegaConf.create(config)
            with open_dict(config):
                config.update(override_config)

        # TODO: Remove
        if "state_dict" not in ckpt:
            ckpt["state_dict"] = ckpt["model"]

        # Check if this is a PyTorch Lightning checkpoint and strip the "model." prefix
        if all(x.startswith("model") for x in ckpt["state_dict"].keys()):
            ckpt["state_dict"] = {k.replace("model.", ""): v for k, v in ckpt["state_dict"].items()}

        if "residues" not in ckpt:
            # Legacy format
            residues = dict(config["residues"])
        else:
            # TODO: Remove
            # residues = dict(ckpt["residues"].get("residues", {}))
            residues = ckpt["residues"]

        if update_residues_to_unimod:
            residues = {LEGACY_PTM_TO_UNIMOD[k] if k in LEGACY_PTM_TO_UNIMOD else k: v for k, v in residues.items()}

        residue_set = ResidueSet(residues)

        model = cls(
            residue_set=residue_set,
            dim_model=config["dim_model"],
            n_head=config["n_head"],
            dim_feedforward=config["dim_feedforward"],
            encoder_layers=config.get("encoder_layers", config.get("n_layers", 9)),
            decoder_layers=config.get("decoder_layers", config.get("n_layers", 9)),
            dropout=config["dropout"],
            max_charge=config["max_charge"],
            use_flash_attention=config.get("use_flash_attention", False),
            conv_peak_encoder=config.get("conv_peak_encoder", False),
            peak_embedding_dtype=config.get("peak_embedding_dtype", torch.float64),
        )
        model.load_state_dict(ckpt["state_dict"])

        return model, config
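
    # Usage sketch (illustrative; the checkpoint path is a hypothetical local file):
    #
    #   model, config = InstaNovo.load("checkpoints/instanovo.ckpt")
    #   model.eval()
    #   print(config["dim_model"], len(model.residue_set))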

    @classmethod
    def from_pretrained(
        cls, model_id: str, update_residues_to_unimod: bool = True, override_config: DictConfig | dict | None = None
    ) -> tuple["InstaNovo", "DictConfig"]:
        """Download and load by model id or model path.

        Args:
            model_id (str): Model id or model path.
            update_residues_to_unimod (bool): Update residues to unimod, defaults to True.
            override_config (DictConfig | dict | None): Optional override config values with a DictConfig or dict, defaults to None.

        Returns:
            tuple[InstaNovo, DictConfig]: Tuple of model and config.
        """
        # TODO Refactor to use across methods
        # Check if model_id is a local file path
        if "/" in model_id or "\\" in model_id or model_id.endswith(".ckpt"):
            if os.path.isfile(model_id):
                return cls.load(model_id, update_residues_to_unimod=update_residues_to_unimod, override_config=override_config)
            else:
                raise FileNotFoundError(f"No file found at path: {model_id}")

        # Load the models.json file
        with resources.files("instanovo").joinpath("models.json").open("r", encoding="utf-8") as f:
            models_config = json.load(f)

        # Find the model in the config
        if MODEL_TYPE not in models_config or model_id not in models_config[MODEL_TYPE]:
            raise ValueError(f"Model {model_id} not found in models.json, options are [{', '.join(models_config[MODEL_TYPE].keys())}]")

        model_info = models_config[MODEL_TYPE][model_id]
        url = model_info["remote"]

        # Create cache directory if it doesn't exist
        cache_dir = Path.home() / ".cache" / "instanovo"
        cache_dir.mkdir(parents=True, exist_ok=True)

        # Generate a filename for the cached model
        file_name = urlsplit(url).path.split("/")[-1]
        cached_file = cache_dir / file_name

        # Check if the file is already cached
        if not cached_file.exists():
            download_file(url, cached_file, model_id, file_name)
        else:
            logger.info(f"Model {model_id} already cached at {cached_file}")

        try:
            # Load and return the model
            logger.info(f"Loading model {model_id} (remote)")
            return cls.load(str(cached_file), update_residues_to_unimod=update_residues_to_unimod, override_config=override_config)
        except Exception as e:
            logger.warning(f"Failed to load cached model {model_id}, it may be corrupted. Deleting and re-downloading. Error: {e}")
            if cached_file.exists():
                cached_file.unlink()

            download_file(url, cached_file, model_id, file_name)
            logger.info(f"Loading newly downloaded model {model_id}")
            return cls.load(str(cached_file), update_residues_to_unimod=update_residues_to_unimod, override_config=override_config)
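
    # Usage sketch (illustrative; the model id below is a placeholder, use
    # InstaNovo.get_pretrained() to list the real options):
    #
    #   model, config = InstaNovo.from_pretrained("instanovo-v1.1.0")
    #   model = model.eval()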

    def forward(
        self,
        x: Float[Spectrum, " batch"],
        p: Float[PrecursorFeatures, " batch"],
        y: Integer[Peptide, " batch"],
        x_mask: Optional[Bool[SpectrumMask, " batch"]] = None,
        y_mask: Optional[Bool[PeptideMask, " batch"]] = None,
        add_bos: bool = True,
        return_encoder_output: bool = False,
    ) -> Float[ResidueLogits, "batch token+1"]:
        """Model forward pass.

        Args:
            x: Spectra, float Tensor (batch, n_peaks, 2)
            p: Precursors, float Tensor (batch, 3)
            y: Peptide token indices, long Tensor (batch, seq_len)
            x_mask: Spectra padding mask, True for padded indices, bool Tensor (batch, n_peaks)
            y_mask: Peptide padding mask, bool Tensor (batch, seq_len)
            add_bos: Force add a <s> prefix to y, bool
            return_encoder_output: Also return the encoder output, bool

        Returns:
            logits: float Tensor (batch, n, vocab_size),
                (batch, n+1, vocab_size) if add_bos==True.
                If return_encoder_output is True, a tuple of (logits, encoder output) is returned.
        """
        if self.use_flash_attention:
            x, x_mask = self._flash_encoder(x, p, x_mask)
            return self._flash_decoder(x, y, x_mask, y_mask, add_bos)

        x, x_mask = self._encoder(x, p, x_mask)
        y = self._decoder(x, y, x_mask, y_mask, add_bos)
        if return_encoder_output:
            return y, x
        return y
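
    # Usage sketch (illustrative shapes only, not from the original file; `model` is an
    # already-loaded InstaNovo instance and the tensors below are random placeholders):
    #
    #   x = torch.rand(2, 100, 2)                    # spectra: (batch, n_peaks, [m/z, intensity])
    #   p = torch.tensor([[800.4, 2.0, 401.2]] * 2)  # precursors: index 0 = mass, index 1 = charge
    #   y = torch.randint(1, len(model.residue_set), (2, 10))  # peptide token ids
    #   logits = model(x, p, y)                      # (2, 11, vocab_size) since add_bos=True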

    def init(
        self,
        spectra: Float[Spectrum, " batch"],
        precursors: Float[PrecursorFeatures, " batch"],
        spectra_mask: Optional[Bool[SpectrumMask, " batch"]] = None,
    ) -> Tuple[
        Tuple[Float[Spectrum, " batch"], Bool[SpectrumMask, " batch"]],
        Float[ResidueLogProbabilities, "batch token"],
    ]:
        """Initialise model encoder."""
        if self.use_flash_attention:
            spectra, _ = self._encoder(spectra, precursors, None)
            logits = self._decoder(spectra, None, None, None, add_bos=False)
            return (
                spectra,
                torch.zeros(spectra.shape[0], spectra.shape[1]).to(spectra.device),
            ), torch.log_softmax(logits[:, -1, :], -1)

        spectra, spectra_mask = self._encoder(spectra, precursors, spectra_mask)
        logits = self._decoder(spectra, None, spectra_mask, None, add_bos=False)
        return (spectra, spectra_mask), torch.log_softmax(logits[:, -1, :], -1)
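
    # Usage sketch (illustrative): `init` runs the encoder once and returns the cached
    # encoder state together with log-probabilities for the first predicted residue.
    # Decoders in instanovo.inference drive the model through this Decodable interface.
    #
    #   (spectra_emb, spectra_mask), first_log_probs = model.init(x, p)
    #   next_token = first_log_probs.argmax(-1)  # greedy choice for the first residue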

    def score_candidates(
        self,
        sequences: Integer[Peptide, " batch"],
        precursor_mass_charge: Float[PrecursorFeatures, " batch"],
        spectra: Float[Spectrum, " batch"],
        spectra_mask: Bool[SpectrumMask, " batch"],
    ) -> Float[ResidueLogProbabilities, "batch token"]:
        """Score a set of candidate sequences."""
        if self.use_flash_attention:
            logits = self._flash_decoder(spectra, sequences, None, None, add_bos=True)
        else:
            logits = self._decoder(spectra, sequences, spectra_mask, None, add_bos=True)

        return torch.log_softmax(logits[:, -1, :], -1)

    def get_residue_masses(self, mass_scale: int) -> Integer[DiscretizedMass, " residue"]:
        """Get the scaled masses of all residues."""
        residue_masses = torch.zeros(len(self.residue_set), dtype=torch.int64)
        for index, residue in self.residue_set.index_to_residue.items():
            if residue in self.residue_set.residue_masses:
                residue_masses[index] = round(mass_scale * self.residue_set.get_mass(residue))
        return residue_masses
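
    # Worked example (illustrative): with mass_scale=10000, a residue of mass 57.021464
    # (glycine) is discretized to round(10000 * 57.021464) = 570215. Entries whose residue
    # is not present in residue_masses (e.g. special tokens) keep a scaled mass of 0.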

    def get_eos_index(self) -> int:
        """Get the EOS token ID."""
        return int(self.residue_set.EOS_INDEX)

    def get_empty_index(self) -> int:
        """Get the PAD token ID."""
        return int(self.residue_set.PAD_INDEX)

    def decode(self, sequence: Peptide) -> list[str]:
        """Decode a single sequence of AA IDs."""
        # Note: Sequence is reversed as InstaNovo predicts right-to-left.
        # We reverse the sequence again when decoding to ensure
        # the decoder outputs forward sequences.
        return self.residue_set.decode(sequence, reverse=True)  # type: ignore

    def idx_to_aa(self, idx: Peptide) -> list[str]:
        """Decode a single sample of indices to an amino acid list."""
        idx = idx.cpu().numpy()
        t = []
        for i in idx:
            if i == self.residue_set.EOS_INDEX:
                break
            if i == self.residue_set.SOS_INDEX or i == self.residue_set.PAD_INDEX:
                continue
            t.append(i)
        return [self.residue_set.index_to_residue[x.item()] for x in t]

    def batch_idx_to_aa(self, idx: Integer[Peptide, " batch"], reverse: bool) -> list[list[str]]:
        """Decode a batch of indices to amino acid lists."""
        return [self.residue_set.decode(i, reverse=reverse) for i in idx]
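
    # Usage sketch (illustrative): converting a batch of predicted token ids back into
    # residue strings. reverse=True because InstaNovo predicts sequences right-to-left.
    # The ids below are hypothetical placeholders.
    #
    #   predictions = torch.tensor([[5, 7, 3, 0, 0]])
    #   peptides = model.batch_idx_to_aa(predictions, reverse=True)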

    def score_sequences(
        self,
        peptides: Integer[Peptide, " batch"] | list[str] | list[list[str]],
        peptides_mask: Bool[PeptideMask, " batch"] | None = None,
        spectra: Float[Spectrum, " batch"] | None = None,
        precursors: Float[PrecursorFeatures, " batch"] | None = None,
        spectra_mask: Bool[SpectrumMask, " batch"] | None = None,
        spectra_embedding: Float[SpectrumEmbedding, " batch"] | None = None,
        max_batch_size: int = 256,
    ) -> Float[ResidueLogProbabilities, "batch token"]:
        """Score a set of peptides."""
        if (spectra is None and precursors is None) and spectra_embedding is None:
            raise ValueError("Either spectra and precursors or spectra_embedding must be provided")

        if not isinstance(peptides, Tensor):
            peptides = [
                self.residue_set.encode(
                    self.residue_set.tokenize(x)[::-1],  # type: ignore # ensure reversed
                    add_eos=True,
                    return_tensor="pt",
                )
                for x in peptides
            ]

            ll = torch.tensor([x.shape[0] for x in peptides], dtype=torch.long)  # type: ignore
            peptides = nn.utils.rnn.pad_sequence(peptides, batch_first=True)
            peptides_mask = (
                torch.arange(peptides.shape[1], dtype=torch.long)[None, :] >= ll[:, None]  # type: ignore
            )

        device = spectra.device if spectra is not None else spectra_embedding.device  # type: ignore

        peptides = peptides.to(device)
        peptides_mask = peptides_mask.to(device)

        # Automatically handle batching if the number of peptides is too large
        if peptides.shape[0] > max_batch_size:
            sequence_scores = []
            for i in range(0, peptides.shape[0], max_batch_size):
                sub_batch = (
                    x[i : i + max_batch_size] if x is not None else None
                    for x in (
                        peptides,
                        peptides_mask,
                        spectra,
                        precursors,
                        spectra_mask,
                        spectra_embedding,
                    )
                )
                sequence_scores.append(self.score_sequences(*sub_batch))  # type: ignore
            return torch.cat(sequence_scores, dim=0)

        with torch.no_grad():
            if spectra_embedding is None:
                if self.use_flash_attention:
                    spectra_embedding, spectra_mask = self._flash_encoder(spectra, precursors, spectra_mask)
                else:
                    spectra_embedding, spectra_mask = self._encoder(spectra, precursors, spectra_mask)

            if self.use_flash_attention:
                logits = self._flash_decoder(spectra_embedding, peptides, spectra_mask, peptides_mask, add_bos=True)
            else:
                logits = self._decoder(spectra_embedding, peptides, spectra_mask, peptides_mask, add_bos=True)

        # Get log probabilities for all positions
        log_probs = torch.log_softmax(logits, -1)

        # Gather log probabilities for each token in the sequence
        sequence_log_prob = torch.gather(log_probs, -1, peptides.unsqueeze(-1)).squeeze(-1)

        # Zero out masked positions
        if peptides_mask is not None:
            sequence_log_prob = sequence_log_prob.masked_fill(peptides_mask, 0.0)

        # Sum log probabilities across sequence length
        sequence_log_prob = sequence_log_prob.sum(dim=-1)

        return sequence_log_prob.cpu()
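
    # Usage sketch (illustrative; the peptide strings are hypothetical and there must be
    # one peptide per spectrum in the batch):
    #
    #   scores = model.score_sequences(["PEPTIDE", "PEPTIDES"], spectra=x, precursors=p)
    #   # scores[i] is the summed log-probability of peptide i given spectrum i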

    def _encoder(
        self,
        x: Float[Spectrum, " batch"],
        p: Float[PrecursorFeatures, " batch"] | None = None,
        x_mask: Optional[Bool[SpectrumMask, " batch"]] = None,
    ) -> Tuple[Float[SpectrumEmbedding, " batch"], Bool[SpectrumMask, " batch"]]:
        if self.conv_peak_encoder:
            x = self.conv_encoder(x)
            x_mask = torch.zeros((x.shape[0], x.shape[1]), device=x.device).bool()
        else:
            if x_mask is None:
                x_mask = ~x.sum(dim=2).bool()
            x = self.peak_encoder(x)

        # Self-attention on latent spectra AND peaks
        latent_spectra = self.latent_spectrum.expand(x.shape[0], -1, -1)
        x = torch.cat([latent_spectra, x], dim=1)
        latent_mask = torch.zeros((x_mask.shape[0], 1), dtype=bool, device=x_mask.device)
        x_mask = torch.cat([latent_mask, x_mask], dim=1)

        x = self.encoder(x, src_key_padding_mask=x_mask)

        # Prepare precursors
        if p is not None:
            masses = self.peak_encoder.encode_mass(p[:, None, [0]])
            charges = self.charge_encoder(p[:, 1].int() - 1)
            precursors = masses + charges[:, None, :]

            # Concatenate precursors
            x = torch.cat([precursors, x], dim=1)
            prec_mask = torch.zeros((x_mask.shape[0], 1), dtype=bool, device=x_mask.device)
            x_mask = torch.cat([prec_mask, x_mask], dim=1)

        return x, x_mask
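
    # Note (added for clarity): with the default peak encoder and p provided, the returned
    # embedding has n_peaks + 2 positions per spectrum, ordered [precursor, latent spectrum,
    # peaks...], and the returned x_mask marks padded peak positions as True.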

    def _decoder(
        self,
        x: Float[Spectrum, " batch"],
        y: Integer[Peptide, " batch"],
        x_mask: Bool[SpectrumMask, " batch"],
        y_mask: Optional[Bool[PeptideMask, " batch"]] = None,
        add_bos: bool = True,
    ) -> Float[ResidueLogits, " batch"]:
        if y is None:
            y = torch.full((x.shape[0], 1), self.residue_set.SOS_INDEX, device=x.device)
        elif add_bos:
            bos = torch.ones((y.shape[0], 1), dtype=y.dtype, device=y.device) * self.residue_set.SOS_INDEX
            y = torch.cat([bos, y], dim=1)

            if y_mask is not None:
                bos_mask = torch.zeros((y_mask.shape[0], 1), dtype=bool, device=y_mask.device)
                y_mask = torch.cat([bos_mask, y_mask], dim=1)

        y = self.aa_embed(y)
        if y_mask is None:
            y_mask = ~y.sum(axis=2).bool()

        # Add positional encoding
        y = self.aa_pos_embed(y)

        c_mask = self._get_causal_mask(y.shape[1]).to(y.device)

        y_hat = self.decoder(
            y,
            x,
            tgt_mask=c_mask,
            tgt_key_padding_mask=y_mask,
            memory_key_padding_mask=x_mask,
        )

        return self.head(y_hat)

    def _flash_encoder(self, x: Tensor, p: Tensor, x_mask: Tensor = None) -> tuple[Tensor, Tensor]:
        # Special mask for zero-indices
        # One is padded, zero is normal
        x_mask = (~x.sum(dim=2).bool()).float()

        x = self.peak_encoder(x[:, :, [0]], x[:, :, [1]])
        pad_spectrum = self.pad_spectrum.expand(x.shape[0], x.shape[1], -1)

        # torch.compile doesn't allow dynamic sizes (returned by mask indexing)
        # x[x_mask] = pad_spectrum[x_mask].to(x.dtype)
        x = x * (1 - x_mask[:, :, None]) + pad_spectrum * (x_mask[:, :, None])

        # Self-attention on latent spectra AND peaks
        latent_spectra = self.latent_spectrum.expand(x.shape[0], -1, -1)
        x = torch.cat([latent_spectra, x], dim=1).contiguous()

        try:
            from torch.nn.attention import SDPBackend, sdpa_kernel
        except ImportError:
            raise ImportError(
                "Training InstaNovo with Flash attention enabled requires at least pytorch v2.3. Please upgrade your pytorch version"
            ) from None

        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            x = self.encoder(x)

        # Prepare precursors
        masses = self.peak_encoder.encode_mass(p[:, None, [0]])
        charges = self.charge_encoder(p[:, 1].int() - 1)
        precursors = masses + charges[:, None, :]

        # Concatenate precursors
        x = torch.cat([precursors, x], dim=1).contiguous()

        return x, None

    def _flash_decoder(
        self,
        x: Tensor,
        y: Tensor,
        x_mask: Tensor,
        y_mask: Tensor = None,
        add_bos: bool = True,
    ) -> Tensor:
        if y is None:
            y = torch.full((x.shape[0], 1), self.residue_set.SOS_INDEX, device=x.device)
        elif add_bos:
            bos = torch.ones((y.shape[0], 1), dtype=y.dtype, device=y.device) * self.residue_set.SOS_INDEX
            y = torch.cat([bos, y], dim=1)

        y = self.aa_embed(y)

        # Add positional encoding
        y = self.aa_pos_embed(y)

        c_mask = self._get_causal_mask(y.shape[1]).to(y.device)

        try:
            from torch.nn.attention import SDPBackend, sdpa_kernel
        except ImportError:
            raise ImportError(
                "Training InstaNovo with Flash attention enabled requires at least pytorch v2.3. Please upgrade your pytorch version"
            ) from None

        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            y_hat = self.decoder(y, x, tgt_mask=c_mask)

        return self.head(y_hat)


def _whitelist_torch_omegaconf() -> None:
    """Whitelist specific modules for loading configs from checkpoints."""
    # This is done to safeguard against arbitrary code execution from checkpoints.
    from collections import defaultdict
    from typing import Any

    from omegaconf.base import ContainerMetadata, Metadata
    from omegaconf.listconfig import ListConfig
    from omegaconf.nodes import AnyNode

    torch.serialization.add_safe_globals(
        [
            DictConfig,
            ContainerMetadata,
            Metadata,
            ListConfig,
            AnyNode,
            Any,  # Only used for type hinting in omegaconf.
            defaultdict,
            dict,
            list,
            int,
        ]
    )
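
# Note (added for clarity, a minimal sketch rather than project code): the allow list above
# is what lets InstaNovo.load() call torch.load(..., weights_only=True) on checkpoints whose
# "config" entry is an OmegaConf DictConfig; without it, the safe unpickler rejects the
# OmegaConf classes. A standalone load would look roughly like:
#
#   _whitelist_torch_omegaconf()
#   ckpt = torch.load("checkpoints/instanovo.ckpt", map_location="cpu", weights_only=True)
#   print(type(ckpt["config"]))  # expected to be an OmegaConf DictConfig
#
# The checkpoint path here is hypothetical.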