Coverage for instanovo/diffusion/layers.py: 86%

36 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

1from __future__ import annotations 

2 

3from typing import Optional, Tuple 

4 

5import torch 

6import torch.nn as nn 

7from jaxtyping import Bool, Float 

8 

9from instanovo.__init__ import console 

10from instanovo.transformer.layers import ConvPeakEmbedding, MultiScalePeakEmbedding 

11from instanovo.types import Spectrum, SpectrumEmbedding, SpectrumMask 

12from instanovo.utils.colorlogging import ColorLog 

13 

# Module logger: a ColorLog wrapper bound to the package console, named after this module.
_color_log = ColorLog(console, __name__)
logger = _color_log.logger

15 

16 

class TransformerEncoder(nn.Module):
    """A Transformer encoder for input mass spectra.

    A learnable latent "spectrum" token is prepended (position 0) to the
    embedded peaks, and the combined sequence is passed through a stack of
    ``nn.TransformerEncoderLayer`` modules.

    Parameters
    ----------
    dim_model : int, optional
        The latent dimensionality to represent peaks in the mass spectrum.
        Default ``128``.
    n_head : int, optional
        The number of attention heads in each layer. ``dim_model`` must be
        divisible by ``n_head``. Default ``8``.
    dim_feedforward : int, optional
        The dimensionality of the fully connected layers in the Transformer
        layers of the model. Default ``1024``.
    n_layers : int, optional
        The number of Transformer layers. Default ``1``.
    dropout : float, optional
        The dropout probability passed to the peak embedding modules and,
        unless ``use_flash_attention`` is set, to the Transformer layers.
        Default ``0.0``.
    use_flash_attention : bool, optional
        If True, a learnable ``pad_spectrum`` parameter is registered and the
        Transformer layers are built with dropout ``0``. Default ``False``.
    conv_peak_encoder : bool, optional
        If True, ``forward`` embeds spectra with ``ConvPeakEmbedding`` instead
        of ``MultiScalePeakEmbedding`` and masks no peak positions.
        Default ``False``.
    peak_embedding_dtype : torch.dtype or str, optional
        Forwarded as ``float_dtype`` to ``MultiScalePeakEmbedding``.
        Default ``torch.float64``.
    """

    def __init__(
        self,
        dim_model: int = 128,
        n_head: int = 8,
        dim_feedforward: int = 1024,
        n_layers: int = 1,
        dropout: float = 0.0,
        use_flash_attention: bool = False,
        conv_peak_encoder: bool = False,
        peak_embedding_dtype: torch.dtype | str = torch.float64,
    ) -> None:
        """Initialise a TransformerEncoder."""
        super().__init__()
        self.use_flash_attention = use_flash_attention
        self.conv_peak_encoder = conv_peak_encoder

        # Learnable token prepended to every spectrum in ``forward`` (position 0).
        self.latent_spectrum = nn.Parameter(torch.randn(1, 1, dim_model))

        if self.use_flash_attention:
            # All input spectra are padded to some max length.
            # Pad spectrum replaces zeros in input spectra.
            # This is for flash attention (no masks allowed).
            # NOTE(review): ``forward`` in this class never reads ``pad_spectrum``
            # and still passes a padding mask to the encoder. The parameter is
            # kept so its state_dict key is unchanged — confirm intended use.
            self.pad_spectrum = nn.Parameter(torch.randn(1, 1, dim_model))

        # Peak encoders. The multi-scale encoder is always constructed, even when
        # ``conv_peak_encoder`` is set; ``forward`` uses exactly one of them.
        self.peak_encoder = MultiScalePeakEmbedding(dim_model, dropout=dropout, float_dtype=peak_embedding_dtype)
        if self.conv_peak_encoder:
            self.conv_encoder = ConvPeakEmbedding(dim_model, dropout=dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim_model,
            nhead=n_head,
            dim_feedforward=dim_feedforward,
            batch_first=True,
            # Attention/FFN dropout is disabled in flash-attention mode.
            dropout=0 if self.use_flash_attention else dropout,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=n_layers,
            # enable_nested_tensor=False, TODO: Figure out the correct way to handle this
        )

    def forward(
        self,
        x: Float[Spectrum, " batch"],
        x_mask: Optional[Bool[SpectrumMask, " batch"]] = None,
    ) -> Tuple[Float[SpectrumEmbedding, " batch"], Bool[SpectrumMask, " batch"]]:
        """The forward pass.

        Parameters
        ----------
        x : torch.Tensor of shape (n_spectra, n_peaks, 2)
            The spectra to embed. Axis 0 represents a mass spectrum, axis 1
            contains the peaks in the mass spectrum, and axis 2 is essentially
            a 2-tuple specifying the m/z-intensity pair for each peak. These
            should be zero-padded, such that all of the spectra in the batch
            are the same length.
        x_mask : torch.Tensor of shape (n_spectra, n_peaks), optional
            Spectra padding mask, True for padded indices. If None, a peak is
            treated as padding when its m/z and intensity sum to zero. A
            non-bool mask is converted with ``Tensor.bool()`` (non-zero means
            padded). Ignored when ``conv_peak_encoder`` is set.

        Returns
        -------
        latent : torch.Tensor of shape (n_spectra, n_peaks + 1, dim_model)
            The latent representations for the spectrum (position 0) and each
            of its peaks. With ``conv_peak_encoder`` the peak axis has whatever
            length ``ConvPeakEmbedding`` produces.
        mem_mask : torch.Tensor of dtype bool
            The padding mask matching ``latent`` (True = padding). Position 0,
            the latent spectrum token, is never masked.
        """
        if self.conv_peak_encoder:
            x = self.conv_encoder(x)
            # No positions are masked after the convolutional embedding; any
            # caller-supplied ``x_mask`` is discarded here.
            x_mask = torch.zeros((x.shape[0], x.shape[1]), dtype=torch.bool, device=x.device)
        else:
            if x_mask is None:
                # A peak whose m/z and intensity sum to zero is treated as padding.
                x_mask = ~x.sum(dim=2).bool()
            else:
                # torch.cat below type-promotes mixed dtypes, so a non-bool mask
                # would reach the encoder as an int mask (rejected) or a float
                # mask (interpreted as additive). Normalise to bool up front.
                x_mask = x_mask.bool()
            x = self.peak_encoder(x)

        # Self-attention on latent spectra AND peaks: prepend the latent token
        # and an always-unmasked entry for it in the padding mask.
        latent_spectra = self.latent_spectrum.expand(x.shape[0], -1, -1)
        x = torch.cat([latent_spectra, x], dim=1)
        latent_mask = torch.zeros((x_mask.shape[0], 1), dtype=torch.bool, device=x_mask.device)
        x_mask = torch.cat([latent_mask, x_mask], dim=1)

        x = self.encoder(x, src_key_padding_mask=x_mask)

        return x, x_mask