Coverage for instanovo/transformer/layers.py: 88%
49 statements
coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
from __future__ import annotations

import math

import numpy as np
import torch
from jaxtyping import Float
from torch import Tensor, nn

from instanovo.types import Spectrum, SpectrumEmbedding

class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)
    def forward(self, x: Float[Tensor, "batch token embedding"]) -> Float[Tensor, "batch token embedding"]:
        """Positional encoding forward pass.

        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``
        """
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)
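

# Usage sketch (illustrative, not part of the original module). Input is
# batch-first, matching the ``x.size(1)`` indexing in forward() above:
#
#     pos_enc = PositionalEncoding(d_model=512, dropout=0.1)
#     tokens = torch.zeros(8, 100, 512)  # [batch_size, seq_len, embedding_dim]
#     out = pos_enc(tokens)              # same shape; encodings added, then dropout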


class MultiScalePeakEmbedding(nn.Module):
    """Multi-scale sinusoidal embedding based on Voronov et al."""

    def __init__(self, h_size: int, dropout: float = 0.0, float_dtype: torch.dtype | str = torch.float64) -> None:
        super().__init__()
        self.h_size = h_size
        self.float_dtype = getattr(torch, float_dtype, None) if isinstance(float_dtype, str) else float_dtype
        if self.float_dtype is None:
            raise ValueError(f"Unknown torch dtype string: {float_dtype}")

        self.mlp = nn.Sequential(
            nn.Linear(h_size, h_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(h_size, h_size),
            nn.Dropout(dropout),
        )

        self.head = nn.Sequential(
            nn.Linear(h_size + 1, h_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(h_size, h_size),
            nn.Dropout(dropout),
        )
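
        # h_size // 2 angular frequencies, log-spaced over wavelengths from 1e-2
        # down to 1e-3; encode_mass pairs each with a sin and a cos channel, so
        # every peak maps to an h_size-dim multi-scale encoding.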
        freqs = 2 * np.pi / torch.logspace(-2, -3, h_size // 2, dtype=self.float_dtype)
        self.register_buffer("freqs", freqs)

    # @torch.autocast("cuda", dtype=torch.float32)
    def forward(self, spectra: Float[Spectrum, " batch"]) -> Float[SpectrumEmbedding, " batch"]:
        """Encode peaks."""
        mz_values, intensities = spectra[:, :, [0]], spectra[:, :, [1]]
        x = self.encode_mass(mz_values)
        x = self.mlp(x)
        x = torch.cat([x, intensities], dim=2)
        return self.head(x)

    def encode_mass(self, x: Float[Tensor, " batch"]) -> Float[Tensor, "batch embedding"]:
        """Encode m/z values."""
        x = self.freqs[None, None, :] * x
        x = torch.cat([torch.sin(x), torch.cos(x)], dim=2)
        return x.float()
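

# Usage sketch (illustrative, not part of the original module). Spectra are
# assumed padded to a fixed peak count, with (m/z, intensity) in the last dim:
#
#     peak_emb = MultiScalePeakEmbedding(h_size=512, dropout=0.1)
#     spectra = torch.rand(8, 200, 2)  # [batch, n_peaks, (mz, intensity)]
#     out = peak_emb(spectra)          # [batch, n_peaks, h_size]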


class ConvPeakEmbedding(nn.Module):
    """Convolutional peak embedding."""

    def __init__(self, h_size: int, dropout: float = 0.0) -> None:
        super().__init__()
        self.h_size = h_size
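
        # A very wide first kernel (40,000 bins, stride 100) slides over what is
        # presumably a densely binned 1-D spectrum, downsampling it ~100x before
        # the second conv mixes channels locally.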
        self.conv = nn.Sequential(
            nn.Conv1d(1, h_size // 4, kernel_size=40_000, stride=100, padding=40_000 // 2 - 1),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Conv1d(h_size // 4, h_size, kernel_size=5, stride=1, padding=1),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Conv peak embedding."""
        x = x.unsqueeze(1)
        return self.conv(x).transpose(-1, -2)
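

# Usage sketch (illustrative, not part of the original module). Assumes the
# input is a dense binned intensity vector; the bin count here is arbitrary:
#
#     conv_emb = ConvPeakEmbedding(h_size=512)
#     binned = torch.rand(8, 100_000)  # [batch, n_bins]
#     out = conv_emb(binned)           # [batch, ~n_bins // 100, h_size]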