Coverage for instanovo/diffusion/model.py: 96%

68 statements · coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

from __future__ import annotations

from typing import Optional

import torch
from jaxtyping import Bool, Float, Integer
from omegaconf import DictConfig
from torch import Tensor, nn
from transfusion.model import Pogfuse, TransFusion, timestep_embedding

from instanovo.diffusion.layers import TransformerEncoder
from instanovo.types import (
    Peptide,
    PeptideEmbedding,
    PeptideMask,
    PrecursorFeatures,
    ResidueLogits,
    Spectrum,
    SpectrumEmbedding,
    SpectrumMask,
    TimeEmbedding,
    TimeStep,
)

class MassSpectrumTransformer(Pogfuse):
    """A transformer model specialised for encoding mass spectra."""

    def forward(
        self,
        x: Float[PeptideEmbedding, " batch"],
        t_emb: Float[TimeEmbedding, " batch"],
        precursor_emb: Float[Tensor, "..."],
        cond_emb: Optional[Float[SpectrumEmbedding, " batch"]] = None,
        x_padding_mask: Optional[Bool[PeptideMask, " batch"]] = None,
        cond_padding_mask: Optional[Bool[SpectrumMask, " batch"]] = None,
        pos_bias: Optional[Float[Tensor, "..."]] = None,
    ) -> tuple[Float[Tensor, "batch token embedding"], Optional[Float[Tensor, "..."]]]:
        """Compute encodings with the model.

        Forward with `x` (bs, seq_len, dim), summing the timestep embedding `t_emb` (bs, dim)
        and the pooled precursor embedding `precursor_emb` (bs, 1, dim) before the transformer
        layer, and appending the conditioning embedding `cond_emb` (bs, seq_len2, dim) to the
        key/value pairs of the attention.

        Optionally specify a key/value padding mask for the input `x` with `x_padding_mask`
        (bs, seq_len), and for the conditioning embedding with `cond_padding_mask`
        (bs, seq_len2). By default no padding is used. Using the conditioning padding mask is
        recommended; the `x` padding mask is usually not needed.

        `pos_bias` is the positional bias for WavLM-style gated relative position attention.

        Returns `x` of the same shape (bs, seq_len, dim) together with the updated `pos_bias`.
        """

        # -----------------------
        # 1. Get and add timestep embedding
        t = self.t_layers(t_emb)[:, None]  # (bs, 1, dim)
        p = self.cond_pooled_layers(precursor_emb)  # (bs, 1, dim)
        x += t + p  # (bs, seq_len, dim)
        # -----------------------
        # 2. Get and append conditioning embeddings
        if self.add_cond_seq:
            c = self.cond_layers(cond_emb)  # (bs, seq_len2, dim)
        else:
            c = None
        # -----------------------
        # 3. Do transformer layer
        # -- Self-attention block
        x1, pos_bias = self._sa_block(
            x,
            c,
            x_padding_mask=x_padding_mask,
            c_padding_mask=cond_padding_mask,
            pos_bias=pos_bias,
        )

        # -- Layer-norm with residual connection
        x = self.norm1(x + x1)

        # -- Layer-norm with feed-forward block and residual connection
        x = self.norm2(x + self._ff_block(x))

        return x, pos_bias
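# Illustrative shape walk-through for a single MassSpectrumTransformer layer (a sketch,
# not part of the original module; the dimensions below are arbitrary assumptions):
#
#   layer = MassSpectrumTransformer(512, 512, 512, 8, add_cond_seq=True)
#   x = torch.rand(2, 30, 512)             # peptide embeddings (bs, seq_len, dim)
#   t_emb = torch.rand(2, 512)             # timestep embedding (bs, t_emb_dim)
#   precursor_emb = torch.rand(2, 1, 512)  # pooled precursor embedding (bs, 1, dim)
#   cond_emb = torch.rand(2, 100, 512)     # encoded spectrum peaks (bs, seq_len2, cond_emb_dim)
#   x, pos_bias = layer(x, t_emb, precursor_emb, cond_emb)
#   # x: (2, 30, 512); pos_bias is None unless WavLM-style relative attention is enabled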

class MassSpectrumTransFusion(TransFusion):
    """Diffusion reconstruction model conditioned on mass spectra."""

    def __init__(
        self,
        cfg: DictConfig,  # ModelConfig
        max_transcript_len: int = 200,
    ) -> None:
        super().__init__(cfg, max_transcript_len)

        layers = []
        for i in range(cfg.layers):
            add_cond_cross_attn = i in list(self.cfg.cond_cross_attn_layers)
            layer = MassSpectrumTransformer(
                self.cfg.dim,
                self.cfg.t_emb_dim,
                self.cfg.cond_emb_dim,
                self.cfg.nheads,
                add_cond_seq=add_cond_cross_attn,
                dropout=self.cfg.dropout,
                use_wavlm_attn=cfg.attention_type == "wavlm" and not add_cond_cross_attn,
                wavlm_num_bucket=cfg.wavlm_num_bucket,
                wavlm_max_dist=cfg.wavlm_max_dist,
                has_rel_attn_bias=(cfg.attention_type == "wavlm" and i == 1),
            )
            # add relative attention bias at i=1 as that is the first attention layer
            # where we do not use cross attention
            layers.append(layer)
        self.layers = nn.ModuleList(layers)

        self.conditioning_pos_emb = None

        self.encoder = TransformerEncoder(
            dim_model=cfg.dim,
            n_head=cfg.nheads,
            dim_feedforward=cfg.dim_feedforward,
            n_layers=cfg.get("encoder_layers", cfg.get("layers", None)),
            dropout=cfg.dropout,
            use_flash_attention=cfg.get("use_flash_attention", False),
            conv_peak_encoder=cfg.get("conv_peak_encoder", False),
            peak_embedding_dtype=cfg.get("peak_embedding_dtype", torch.float64),
        )

        # precursor embedding
        self.charge_encoder = torch.nn.Embedding(cfg.max_charge, cfg.dim)
        self.peak_encoder = self.encoder.peak_encoder

        self.cache_spectra = None
        self.cache_cond_emb = None
        self.cache_cond_padding_mask = None

    def forward(
        self,
        x: Integer[Peptide, " batch"],
        t: Integer[TimeStep, " batch"],
        spectra: Float[Spectrum, " batch"],
        spectra_padding_mask: Bool[SpectrumMask, " batch"],
        precursors: Float[PrecursorFeatures, " batch"],
        x_padding_mask: Optional[Bool[PeptideMask, " batch"]] = None,
    ) -> Float[ResidueLogits, "batch token"]:
        """Transformer with conditioning cross attention.

        - `x`: (bs, seq_len) long tensor of character indices,
          or (bs, seq_len, vocab_size) if cfg.diffusion_type == 'continuous'
        - `t`: (bs,) long tensor of timestep indices
        - `spectra`: (bs, n_peaks, 2) float tensor of spectrum peaks
        - `spectra_padding_mask`: (bs, n_peaks) bool mask, True where peaks are padding
        - `precursors`: (bs, 3) float tensor of precursor features (mass at index 0,
          charge at index 1)
        - `x_padding_mask`: (bs, seq_len) optional bool mask, True where `x` is padding

        Returns logits (bs, seq_len, vocab_size).
        """

        # 1. Base: character, timestep embeddings and zeroing
        bs = x.shape[0]
        x = self.char_embedding(x)  # (bs, seq_len, dim)

        if self.cfg.pos_encoding == "relative":
            x = self.pos_embedding(x)
        else:
            pos_emb = self.pos_embedding.weight[None].expand(bs, -1, -1)  # (seq_len, dim) --> (bs, seq_len, dim)
            x = x + pos_emb

        t_emb = timestep_embedding(t, self.cfg.t_emb_dim, self.cfg.t_emb_max_period, dtype=spectra.dtype)  # (bs, t_dim)
        # 2. Classifier-free guidance: with probability cfg.drop_cond_prob, zero out
        # and drop the conditioning
        if self.training:
            zero_cond_inds = torch.rand_like(t, dtype=spectra.dtype) < self.cfg.drop_cond_prob
        else:
            # never randomly zero when in eval mode
            zero_cond_inds = torch.zeros_like(t, dtype=torch.bool)
            if spectra_padding_mask.all():
                # BUT, if all cond information is padded then we are
                # obviously doing unconditional synthesis,
                # so force zero_cond_inds to be all ones
                zero_cond_inds = ~zero_cond_inds
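        # Why drop the conditioning? A sketch of classifier-free guidance (not code from
        # this module): at sampling time, conditional and unconditional predictions are
        # typically combined as
        #     guided = (1 + w) * cond_logits - w * uncond_logits
        # for some guidance weight `w`; the exact weight and mixing rule used downstream
        # are assumptions here.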

        # 3. DENOVO: calculate spectrum embedding here
        if self.training:
            cond_emb, cond_padding_mask = self.encoder(spectra, spectra_padding_mask)
        else:
            if self.cache_spectra is not None and torch.equal(self.cache_spectra, spectra):
                cond_emb, cond_padding_mask = (
                    self.cache_cond_emb,
                    self.cache_cond_padding_mask,
                )
            else:
                cond_emb, cond_padding_mask = self.encoder(spectra, spectra_padding_mask)
                self.cache_spectra = spectra
                self.cache_cond_emb = cond_emb
                self.cache_cond_padding_mask = cond_padding_mask

        masses = self.peak_encoder.encode_mass(precursors[:, None, [0]])
        charges = self.charge_encoder(precursors[:, 1].int() - 1)
        precursor_emb = masses + charges[:, None, :]

        # set the mask for the dropped conditional entries to True everywhere (i.e. mask them out)
        cond_padding_mask[zero_cond_inds] = True
        cond_emb[zero_cond_inds] = 0

        # 4. Iterate through layers
        pos_bias = None
        for layer in self.layers:
            x, pos_bias = layer(
                x,
                t_emb,
                precursor_emb,
                cond_emb,
                x_padding_mask,
                cond_padding_mask,
                pos_bias=pos_bias,
            )
        # 5. Pass through head to get logits
        x = self.head(x)  # (bs, seq_len, vocab_size)

        return x

217 return x