Layers

layers

logger = ColorLog(console, __name__).logger module-attribute

TransformerEncoder(dim_model: int = 128, n_head: int = 8, dim_feedforward: int = 1024, n_layers: int = 1, dropout: float = 0.0, use_flash_attention: bool = False, conv_peak_encoder: bool = False, peak_embedding_dtype: torch.dtype | str = torch.float64)

Bases: Module

A Transformer encoder for input mass spectra.

Parameters

dim_model : int, optional
    The latent dimensionality used to represent peaks in the mass spectrum.
n_head : int, optional
    The number of attention heads in each layer. dim_model must be divisible by n_head.
dim_feedforward : int, optional
    The dimensionality of the fully connected layers in the Transformer layers of the model.
n_layers : int, optional
    The number of Transformer layers.
dropout : float, optional
    The dropout probability for all layers.
use_flash_attention : bool, optional
    Whether to use flash attention in the encoder layers.
conv_peak_encoder : bool, optional
    Whether to use the convolutional peak encoder instead of the multi-scale peak embedding.
peak_embedding_dtype : torch.dtype | str, optional
    The floating-point dtype used by the peak embedding.

Initialise a TransformerEncoder.
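The parameter list above notes that dim_model must be divisible by n_head: multi-head attention splits the model dimension evenly across heads. A minimal check of that constraint (the variable names are illustrative, not part of the library):

```python
# Multi-head attention splits dim_model across the heads, so dim_model
# must be divisible by n_head; each head then attends in a subspace of
# dim_model // n_head dimensions.
dim_model, n_head = 128, 8
assert dim_model % n_head == 0, "dim_model must be divisible by n_head"
head_dim = dim_model // n_head
print(head_dim)  # 16
```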

use_flash_attention = use_flash_attention instance-attribute

conv_peak_encoder = conv_peak_encoder instance-attribute

latent_spectrum = nn.Parameter(torch.randn(1, 1, dim_model)) instance-attribute

pad_spectrum = nn.Parameter(torch.randn(1, 1, dim_model)) instance-attribute

peak_encoder = MultiScalePeakEmbedding(dim_model, dropout=dropout, float_dtype=peak_embedding_dtype) instance-attribute
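MultiScalePeakEmbedding itself is not documented on this page. As an illustrative assumption only, a common way to build such an embedding is a sinusoidal encoding of each peak's m/z over a geometric range of wavelengths; the function name, wavelength bounds, and interleaving below are hypothetical, not the library's actual implementation:

```python
import math

def sinusoidal_mz_embedding(mz: float, dim: int,
                            min_wavelength: float = 0.001,
                            max_wavelength: float = 10000.0) -> list[float]:
    """Hypothetical sketch of a multi-scale m/z encoding: interleaved
    sines and cosines whose wavelengths span a geometric range, so both
    fine and coarse m/z differences are captured."""
    half = dim // 2
    base = min_wavelength / (2 * math.pi)
    scale = max_wavelength / min_wavelength
    emb = []
    for i in range(half):
        wavelength = base * scale ** (i / (half - 1))
        emb.append(math.sin(mz / wavelength))
        emb.append(math.cos(mz / wavelength))
    return emb

emb = sinusoidal_mz_embedding(500.277, dim=128)
print(len(emb))  # 128
```

Each m/z value is thus mapped to a dim_model-dimensional vector with every component bounded in [-1, 1].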

conv_encoder = ConvPeakEmbedding(dim_model, dropout=dropout) instance-attribute

encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers) instance-attribute

forward(x: Float[Spectrum, ' batch'], x_mask: Optional[Bool[SpectrumMask, ' batch']] = None) -> Tuple[Float[SpectrumEmbedding, ' batch'], Bool[SpectrumMask, ' batch']]

The forward pass.

Parameters

x : torch.Tensor of shape (n_spectra, n_peaks, 2)
    The spectra to embed. Axis 0 indexes the mass spectra, axis 1 the peaks within each spectrum, and axis 2 holds the (m/z, intensity) pair for each peak. Spectra should be zero-padded so that all spectra in the batch have the same length.
x_mask : torch.Tensor of shape (n_spectra, n_peaks), optional
    Boolean padding mask for the spectra, with True marking padded positions.
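The zero-padding and padding mask described above can be sketched without torch; the lists below stand in for tensors, and the helper name is illustrative:

```python
def pad_spectra(spectra):
    """Zero-pad a batch of (m/z, intensity) peak lists to equal length and
    build the boolean padding mask (True marks padded positions)."""
    n_peaks = max(len(s) for s in spectra)
    x, x_mask = [], []
    for s in spectra:
        pad = n_peaks - len(s)
        x.append(list(s) + [(0.0, 0.0)] * pad)            # zero-padded peaks
        x_mask.append([False] * len(s) + [True] * pad)    # True where padded
    return x, x_mask

spectra = [
    [(101.1, 0.5), (202.2, 1.0)],  # two peaks
    [(303.3, 0.8)],                # one peak -> one padded slot
]
x, x_mask = pad_spectra(spectra)
print(x_mask)  # [[False, False], [False, True]]
```

In practice x and x_mask would be torch tensors of shape (n_spectra, n_peaks, 2) and (n_spectra, n_peaks).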

Returns:

latent : torch.Tensor of shape (n_spectra, n_peaks + 1, dim_model)
    The latent representations for the spectrum and each of its peaks.
mem_mask : torch.Tensor
    The memory mask specifying which elements of x were padding.
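The extra position in latent (n_peaks + 1 rather than n_peaks) is consistent with the learned latent_spectrum parameter listed above being prepended to each spectrum's peak embeddings before encoding. That is an inference from the attribute list, not documented behavior; the shape bookkeeping can be checked with plain lists:

```python
n_spectra, n_peaks, dim_model = 4, 100, 128
# Peak embeddings: one dim_model-vector per peak (zeros as placeholders).
peaks = [[[0.0] * dim_model for _ in range(n_peaks)] for _ in range(n_spectra)]
latent_token = [0.0] * dim_model
# Prepend the shared latent token to every spectrum's sequence.
latent = [[latent_token] + spec for spec in peaks]
print(len(latent), len(latent[0]), len(latent[0][0]))  # 4 101 128
```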