Layers
layers
logger = ColorLog(console, __name__).logger
module-attribute
TransformerEncoder(dim_model: int = 128, n_head: int = 8, dim_feedforward: int = 1024, n_layers: int = 1, dropout: float = 0.0, use_flash_attention: bool = False, conv_peak_encoder: bool = False, peak_embedding_dtype: torch.dtype | str = torch.float64)
Bases: Module
A Transformer encoder for input mass spectra.
Parameters
dim_model : int, optional
The latent dimensionality used to represent peaks in the mass spectrum.
n_head : int, optional
The number of attention heads in each layer. dim_model must be
divisible by n_head.
dim_feedforward : int, optional
The dimensionality of the fully connected layers in the Transformer
layers of the model.
n_layers : int, optional
The number of Transformer layers.
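dropout : float, optional
The dropout probability for all layers.
use_flash_attention : bool, optional
Whether to use flash attention.
conv_peak_encoder : bool, optional
Whether to use a convolutional peak encoder (ConvPeakEmbedding).
peak_embedding_dtype : torch.dtype | str, optional
The floating-point dtype passed to MultiScalePeakEmbedding as
float_dtype.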
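The divisibility requirement on dim_model and n_head can be checked before the encoder is constructed. The following is a minimal sketch, assuming the usual multi-head attention convention in which each head receives dim_model // n_head dimensions. The helper name head_dim is hypothetical and is not part of this module.

```python
def head_dim(dim_model: int = 128, n_head: int = 8) -> int:
    """Return the per-head width, or raise if dim_model does not split evenly.

    Hypothetical helper for illustration only; the defaults mirror the
    TransformerEncoder signature documented above.
    """
    if dim_model % n_head != 0:
        raise ValueError(
            f"dim_model ({dim_model}) must be divisible by n_head ({n_head})"
        )
    return dim_model // n_head


# Default configuration: 128 dimensions shared across 8 heads.
print(head_dim())
# A wider model with fewer heads.
print(head_dim(256, 4))
```

Running a check like this on user-supplied hyperparameters gives a clearer error message than one raised later inside the attention layers.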
Initialise a TransformerEncoder.
use_flash_attention = use_flash_attention
instance-attribute
conv_peak_encoder = conv_peak_encoder
instance-attribute
latent_spectrum = nn.Parameter(torch.randn(1, 1, dim_model))
instance-attribute
pad_spectrum = nn.Parameter(torch.randn(1, 1, dim_model))
instance-attribute
peak_encoder = MultiScalePeakEmbedding(dim_model, dropout=dropout, float_dtype=peak_embedding_dtype)
instance-attribute
conv_encoder = ConvPeakEmbedding(dim_model, dropout=dropout)
instance-attribute
encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
instance-attribute
forward(x: Float[Spectrum, ' batch'], x_mask: Optional[Bool[SpectrumMask, ' batch']] = None) -> Tuple[Float[SpectrumEmbedding, ' batch'], Bool[SpectrumMask, ' batch']]
The forward pass.
Parameters
x : torch.Tensor of shape (n_spectra, n_peaks, 2)
The spectra to embed. Axis 0 indexes the mass spectra, axis 1 indexes
the peaks within each spectrum, and axis 2 holds the (m/z, intensity)
pair for each peak. Spectra should be zero-padded so that all spectra
in the batch have the same length.
x_mask : torch.Tensor of shape (n_spectra, n_peaks), optional
The spectra padding mask: a bool tensor that is True at padded indices.
Returns
latent : torch.Tensor of shape (n_spectra, n_peaks + 1, dim_model)
The latent representations for the spectrum and each of its peaks.
mem_mask : torch.Tensor
The memory mask specifying which elements were padding in x.
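The input layout described above (zero-padded peaks plus a mask that is True at padded indices) can be illustrated without any dependencies. This is a sketch in plain Python, not the library's own preprocessing: pad_spectra and the Peak alias are hypothetical names introduced here for illustration.

```python
from typing import List, Tuple

Peak = Tuple[float, float]  # (m/z, intensity); alias assumed for this sketch


def pad_spectra(
    spectra: List[List[Peak]],
) -> Tuple[List[List[List[float]]], List[List[bool]]]:
    """Zero-pad spectra to a common length and build the padding mask.

    Returns nested lists shaped (n_spectra, n_peaks, 2) and
    (n_spectra, n_peaks); mask entries are True where a peak is padding.
    """
    n_peaks = max((len(s) for s in spectra), default=0)
    x, x_mask = [], []
    for spectrum in spectra:
        n_pad = n_peaks - len(spectrum)
        # Real peaks first, then zero-valued padding peaks.
        x.append(
            [[mz, inten] for mz, inten in spectrum]
            + [[0.0, 0.0] for _ in range(n_pad)]
        )
        x_mask.append([False] * len(spectrum) + [True] * n_pad)
    return x, x_mask


spectra = [
    [(100.0, 0.5), (200.0, 1.0), (300.0, 0.2)],  # three peaks
    [(150.0, 1.0)],                              # one peak, padded twice
]
x, x_mask = pad_spectra(spectra)
print(len(x), len(x[0]), len(x[0][0]))  # n_spectra, n_peaks, 2
print(x_mask)
```

Wrapping the two results with torch.tensor(x) and torch.tensor(x_mask) gives a float tensor of shape (n_spectra, n_peaks, 2) and a bool tensor of shape (n_spectra, n_peaks). For this batch, the documented return shape implies a latent of shape (2, 4, dim_model), since one extra position is added along the peak axis.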