Coverage for instanovo/diffusion/model.py: 96%
68 statements
coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
from __future__ import annotations

from typing import Optional

import torch
from jaxtyping import Bool, Float, Integer
from omegaconf import DictConfig
from torch import Tensor, nn
from transfusion.model import Pogfuse, TransFusion, timestep_embedding

from instanovo.diffusion.layers import TransformerEncoder
from instanovo.types import (
    Peptide,
    PeptideEmbedding,
    PeptideMask,
    PrecursorFeatures,
    ResidueLogits,
    Spectrum,
    SpectrumEmbedding,
    SpectrumMask,
    TimeEmbedding,
    TimeStep,
)


class MassSpectrumTransformer(Pogfuse):
    """A transformer model specialised for encoding mass spectra."""

    def forward(
        self,
        x: Float[PeptideEmbedding, " batch"],
        t_emb: Float[TimeEmbedding, " batch"],
        precursor_emb: Float[Tensor, "..."],
        cond_emb: Optional[Float[SpectrumEmbedding, " batch"]] = None,
        x_padding_mask: Optional[Bool[PeptideMask, " batch"]] = None,
        cond_padding_mask: Optional[Bool[SpectrumMask, " batch"]] = None,
        pos_bias: Optional[Float[Tensor, "..."]] = None,
    ) -> tuple[Float[Tensor, "batch token embedding"], Optional[Float[Tensor, "..."]]]:
        """Compute encodings with the model.

        Forward with `x` (bs, seq_len, dim), summing `t_emb` (bs, dim) before the transformer
        layer, and appending `cond_emb` (bs, seq_len2, dim) to the key/value pairs of the
        attention. `precursor_emb` (bs, 1, dim) is also summed with the timestep embedding.

        Optionally specify a key/value padding mask for the input `x` with `x_padding_mask`
        (bs, seq_len), and a key/value padding mask for the conditioning embedding with
        `cond_padding_mask` (bs, seq_len2). By default no padding is used. It is generally a
        good idea to use the conditioning padding mask but not the `x` padding mask.

        `pos_bias` is the positional bias for wavlm-style gated relative position attention.

        Returns `x` of the same shape (bs, seq_len, dim) together with `pos_bias`.
        """
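        # Per-layer data flow (shapes): the timestep token `t` (bs, 1, dim) and the pooled
        # precursor token `p` (bs, 1, dim) are broadcast-added to `x` (bs, seq_len, dim);
        # self-attention then optionally attends over the projected conditioning sequence
        # `c` (bs, seq_len2, dim) before the residual layer-norm and feed-forward blocks.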
        # -----------------------
        # 1. Get and add timestep embedding
        t = self.t_layers(t_emb)[:, None]  # (bs, 1, dim)
        p = self.cond_pooled_layers(precursor_emb)  # (bs, 1, dim)
        x += t + p  # (bs, seq_len, dim)
        # -----------------------
        # 2. Get and append conditioning embeddings
        if self.add_cond_seq:
            c = self.cond_layers(cond_emb)  # (bs, seq_len2, dim)
        else:
            c = None
        # -----------------------
        # 3. Do transformer layer
        # -- Self-attention block
        x1, pos_bias = self._sa_block(
            x,
            c,
            x_padding_mask=x_padding_mask,
            c_padding_mask=cond_padding_mask,
            pos_bias=pos_bias,
        )

        # -- Layer-norm with residual connection
        x = self.norm1(x + x1)

        # -- Layer-norm with feedforward block and residual connection
        x = self.norm2(x + self._ff_block(x))

        return x, pos_bias


class MassSpectrumTransFusion(TransFusion):
    """Diffusion reconstruction model conditioned on mass spectra."""

    def __init__(
        self,
        cfg: DictConfig,  # ModelConfig
        max_transcript_len: int = 200,
    ) -> None:
        super().__init__(cfg, max_transcript_len)
        layers = []
        for i in range(cfg.layers):
            add_cond_cross_attn = i in list(self.cfg.cond_cross_attn_layers)
            layer = MassSpectrumTransformer(
                self.cfg.dim,
                self.cfg.t_emb_dim,
                self.cfg.cond_emb_dim,
                self.cfg.nheads,
                add_cond_seq=add_cond_cross_attn,
                dropout=self.cfg.dropout,
                use_wavlm_attn=cfg.attention_type == "wavlm" and not add_cond_cross_attn,
                wavlm_num_bucket=cfg.wavlm_num_bucket,
                wavlm_max_dist=cfg.wavlm_max_dist,
                has_rel_attn_bias=(cfg.attention_type == "wavlm" and i == 1),
            )
            # Add the relative attention bias at i == 1, as that is the first attention
            # layer where we do not use cross-attention.
            layers.append(layer)
        self.layers = nn.ModuleList(layers)

        self.conditioning_pos_emb = None

        self.encoder = TransformerEncoder(
            dim_model=cfg.dim,
            n_head=cfg.nheads,
            dim_feedforward=cfg.dim_feedforward,
            n_layers=cfg.get("encoder_layers", cfg.get("layers", None)),
            dropout=cfg.dropout,
            use_flash_attention=cfg.get("use_flash_attention", False),
            conv_peak_encoder=cfg.get("conv_peak_encoder", False),
            peak_embedding_dtype=cfg.get("peak_embedding_dtype", torch.float64),
        )
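        # The spectrum encoder produces per-peak conditioning embeddings and a matching
        # padding mask; the diffusion layers above consume these as cross-attention
        # key/value pairs (see `forward` below).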
        # precursor embedding
        self.charge_encoder = torch.nn.Embedding(cfg.max_charge, cfg.dim)
        self.peak_encoder = self.encoder.peak_encoder

        self.cache_spectra = None
        self.cache_cond_emb = None
        self.cache_cond_padding_mask = None
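        # Cache for the spectrum encoding: iterative sampling at eval time tends to call
        # forward() repeatedly with the same spectra, so the encoder output is reused
        # whenever the incoming batch matches the cached one exactly (see `forward`).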

    def forward(
        self,
        x: Integer[Peptide, " batch"],
        t: Integer[TimeStep, " batch"],
        spectra: Float[Spectrum, " batch"],
        spectra_padding_mask: Bool[SpectrumMask, " batch"],
        precursors: Float[PrecursorFeatures, " batch"],
        x_padding_mask: Optional[Bool[PeptideMask, " batch"]] = None,
    ) -> Float[ResidueLogits, "batch token"]:
        """Transformer forward pass with conditioning cross-attention.

        - `x`: (bs, seq_len) long tensor of character indices,
          or (bs, seq_len, vocab_size) if cfg.diffusion_type == 'continuous'
        - `t`: (bs, ) long tensor of timestep indices
        - `spectra`: (bs, n_peaks, 2) float tensor of spectrum peaks
        - `spectra_padding_mask`: (bs, n_peaks) bool mask, True where peaks are padding
        - `precursors`: float tensor of precursor features, with the precursor mass at
          index 0 and the charge at index 1 along the last dimension
        - `x_padding_mask`: (bs, seq_len) optional bool mask, True where `x` is padding

        Returns logits (bs, seq_len, vocab_size).
        """
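        # The forward pass runs in the five numbered stages below: (1) embed residues,
        # positions and the timestep, (2) pick which batch entries have their conditioning
        # dropped for classifier-free guidance, (3) encode (or reuse cached) spectrum and
        # precursor conditioning, (4) run the conditioned diffusion layers, and
        # (5) project to residue logits.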
        # 1. Base: character, timestep embeddings and zeroing
        bs = x.shape[0]
        x = self.char_embedding(x)  # (bs, seq_len, dim)

        if self.cfg.pos_encoding == "relative":
            x = self.pos_embedding(x)
        else:
            # (seq_len, dim) --> (bs, seq_len, dim)
            pos_emb = self.pos_embedding.weight[None].expand(bs, -1, -1)
            x = x + pos_emb

        t_emb = timestep_embedding(
            t, self.cfg.t_emb_dim, self.cfg.t_emb_max_period, dtype=spectra.dtype
        )  # (bs, t_dim)
        # 2. Classifier-free guidance: with probability cfg.drop_cond_prob, zero out the
        # conditioning (i.e. drop the conditional information)
        if self.training:
            zero_cond_inds = torch.rand_like(t, dtype=spectra.dtype) < self.cfg.drop_cond_prob
        else:
            # never randomly zero when in eval mode
            zero_cond_inds = torch.zeros_like(t, dtype=torch.bool)
        if spectra_padding_mask.all():
            # BUT, if all conditioning information is padded then we are clearly doing
            # unconditional synthesis, so force zero_cond_inds to be all ones
            zero_cond_inds = torch.ones_like(zero_cond_inds)
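        # The per-entry dropping above is what enables classifier-free guidance at
        # sampling time: the model learns an unconditional denoiser alongside the
        # conditional one, and the two outputs can be blended with a guidance weight
        # outside this module (this comment describes the general CFG recipe, not code
        # in this file).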

        # 3. DENOVO: calculate the spectrum embedding here
        if self.training:
            cond_emb, cond_padding_mask = self.encoder(spectra, spectra_padding_mask)
        else:
            if self.cache_spectra is not None and torch.equal(self.cache_spectra, spectra):
                cond_emb, cond_padding_mask = (
                    self.cache_cond_emb,
                    self.cache_cond_padding_mask,
                )
            else:
                cond_emb, cond_padding_mask = self.encoder(spectra, spectra_padding_mask)
                self.cache_spectra = spectra
                self.cache_cond_emb = cond_emb
                self.cache_cond_padding_mask = cond_padding_mask

        masses = self.peak_encoder.encode_mass(precursors[:, None, [0]])
        charges = self.charge_encoder(precursors[:, 1].int() - 1)
        precursor_emb = masses + charges[:, None, :]
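        # `masses` is (bs, 1, dim) and the charge embedding broadcasts to (bs, 1, dim), so
        # `precursor_emb` is a single pooled conditioning token that every layer adds to
        # its timestep embedding (see MassSpectrumTransformer.forward).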
        # set the mask for the dropped conditional entries to True everywhere
        # (i.e. mask them out) and zero their embeddings
        cond_padding_mask[zero_cond_inds] = True
        cond_emb[zero_cond_inds] = 0

        # 4. Iterate through layers
        pos_bias = None
        for layer in self.layers:
            x, pos_bias = layer(
                x,
                t_emb,
                precursor_emb,
                cond_emb,
                x_padding_mask,
                cond_padding_mask,
                pos_bias=pos_bias,
            )
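        # `pos_bias` is produced by the layer constructed with has_rel_attn_bias=True and
        # then threaded through the remaining layers, so the relative position bias does
        # not need to be recomputed for every layer.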
        # 5. Pass through head to get logits
        x = self.head(x)  # (bs, seq_len, vocab_size)

        return x
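

# Usage sketch (illustrative only, not part of the module): given an OmegaConf model
# config `cfg` containing the keys referenced above (dim, nheads, layers,
# cond_cross_attn_layers, t_emb_dim, cond_emb_dim, dropout, attention_type,
# wavlm_num_bucket, wavlm_max_dist, dim_feedforward, max_charge, pos_encoding,
# drop_cond_prob, t_emb_max_period, ...) plus whatever the parent TransFusion expects,
# a single denoising step looks roughly like:
#
#   model = MassSpectrumTransFusion(cfg, max_transcript_len=200)
#   logits = model(
#       x=noisy_peptide,                    # (bs, seq_len) residue indices
#       t=timesteps,                        # (bs,) diffusion timestep per sample
#       spectra=spectra,                    # padded peak tensor
#       spectra_padding_mask=spectra_mask,  # True where peaks are padding
#       precursors=precursors,              # precursor mass/charge features
#   )                                       # -> (bs, seq_len, vocab_size) logits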