Coverage for instanovo/diffusion/layers.py: 86%
36 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
1from __future__ import annotations
3from typing import Optional, Tuple
5import torch
6import torch.nn as nn
7from jaxtyping import Bool, Float
9from instanovo.__init__ import console
10from instanovo.transformer.layers import ConvPeakEmbedding, MultiScalePeakEmbedding
11from instanovo.types import Spectrum, SpectrumEmbedding, SpectrumMask
12from instanovo.utils.colorlogging import ColorLog
# Module-level logger built from the project's ColorLog helper, bound to the
# shared ``console`` and this module's name.
# NOTE(review): not referenced by TransformerEncoder in this file.
logger = ColorLog(console, __name__).logger
class TransformerEncoder(nn.Module):
    """A Transformer encoder for input mass spectra.

    A learnable "latent spectrum" token is prepended to the embedded peaks and
    the combined sequence is passed through a stack of
    ``torch.nn.TransformerEncoderLayer`` blocks.

    Parameters
    ----------
    dim_model : int, optional
        The latent dimensionality to represent peaks in the mass spectrum.
    n_head : int, optional
        The number of attention heads in each layer. ``dim_model`` must be
        divisible by ``n_head``.
    dim_feedforward : int, optional
        The dimensionality of the fully connected layers in the Transformer
        layers of the model.
    n_layers : int, optional
        The number of Transformer layers.
    dropout : float, optional
        Passed to the peak embedding module(s) and, unless
        ``use_flash_attention`` is set, used as the dropout probability of the
        Transformer layers.
    use_flash_attention : bool, optional
        If True, the Transformer layers are built with dropout 0 and a
        learnable ``pad_spectrum`` parameter is created.
    conv_peak_encoder : bool, optional
        If True, :meth:`forward` embeds spectra with a ``ConvPeakEmbedding``
        and treats every resulting position as non-padding.
    peak_embedding_dtype : torch.dtype or str, optional
        Passed as ``float_dtype`` to ``MultiScalePeakEmbedding``.
    """

    def __init__(
        self,
        dim_model: int = 128,
        n_head: int = 8,
        dim_feedforward: int = 1024,
        n_layers: int = 1,
        dropout: float = 0.0,
        use_flash_attention: bool = False,
        conv_peak_encoder: bool = False,
        peak_embedding_dtype: torch.dtype | str = torch.float64,
    ) -> None:
        """Initialise a TransformerEncoder."""
        super().__init__()
        self.use_flash_attention = use_flash_attention
        self.conv_peak_encoder = conv_peak_encoder

        # Learnable token prepended to every spectrum in ``forward``.
        self.latent_spectrum = nn.Parameter(torch.randn(1, 1, dim_model))

        if self.use_flash_attention:
            # All input spectra are padded to some max length.
            # ``pad_spectrum`` is intended to replace zero (padding) peaks,
            # because flash attention does not accept masks.
            # NOTE(review): not referenced in forward(); confirm where it is used.
            self.pad_spectrum = nn.Parameter(torch.randn(1, 1, dim_model))

        # Peak encoders. The multi-scale embedding is always constructed, even
        # when the convolutional encoder is selected.
        self.peak_encoder = MultiScalePeakEmbedding(dim_model, dropout=dropout, float_dtype=peak_embedding_dtype)
        if self.conv_peak_encoder:
            self.conv_encoder = ConvPeakEmbedding(dim_model, dropout=dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim_model,
            nhead=n_head,
            dim_feedforward=dim_feedforward,
            batch_first=True,
            dropout=0 if self.use_flash_attention else dropout,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=n_layers,
            # TODO: decide how ``enable_nested_tensor`` should be set; it is
            # currently left at the PyTorch default.
        )

    def forward(
        self,
        x: Float[Spectrum, " batch"],
        x_mask: Optional[Bool[SpectrumMask, " batch"]] = None,
    ) -> Tuple[Float[SpectrumEmbedding, " batch"], Bool[SpectrumMask, " batch"]]:
        """Encode a batch of spectra.

        Parameters
        ----------
        x : torch.Tensor of shape (n_spectra, n_peaks, 2)
            The spectra to embed. Axis 0 represents a mass spectrum, axis 1
            contains the peaks in the mass spectrum, and axis 2 is essentially
            a 2-tuple specifying the m/z-intensity pair for each peak. These
            should be zero-padded, such that all of the spectra in the batch
            are the same length.
        x_mask : torch.Tensor of shape (n_spectra, n_peaks), optional
            Spectra padding mask, True for padded indices. If None, a peak is
            treated as padding when its m/z and intensity sum to zero.
            Ignored when ``conv_peak_encoder`` is True.

        Returns
        -------
        latent : torch.Tensor of shape (n_spectra, n_positions + 1, dim_model)
            Encoder output. Position 0 along axis 1 is the latent spectrum
            token; the remaining positions are the embedded peaks
            (``n_positions`` is the sequence length produced by the peak
            encoder).
        mem_mask : torch.Tensor of shape (n_spectra, n_positions + 1)
            Padding mask for ``latent`` (True for padding). The latent
            spectrum token is never masked.
        """
        if self.conv_peak_encoder:
            x = self.conv_encoder(x)
            # The caller's mask is not used on this path (presumably because the
            # convolutional embedding may change the sequence length — confirm);
            # every embedded position is attended to.
            x_mask = torch.zeros(x.shape[:2], dtype=torch.bool, device=x.device)
        else:
            if x_mask is None:
                # Peaks whose m/z and intensity sum to zero are padding.
                x_mask = ~x.sum(dim=2).bool()
            x = self.peak_encoder(x)

        # Prepend the latent spectrum token so it self-attends with all peaks.
        batch_size = x.shape[0]
        latent_spectra = self.latent_spectrum.expand(batch_size, -1, -1)
        x = torch.cat([latent_spectra, x], dim=1)
        latent_mask = torch.zeros((x_mask.shape[0], 1), dtype=torch.bool, device=x_mask.device)
        x_mask = torch.cat([latent_mask, x_mask], dim=1)

        x = self.encoder(x, src_key_padding_mask=x_mask)

        return x, x_mask