Coverage for instanovo/transformer/layers.py: 88%

49 statements  

coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

from __future__ import annotations

import math

import numpy as np
import torch
from jaxtyping import Float
from torch import Tensor, nn

from instanovo.types import Spectrum, SpectrumEmbedding


class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x: Float[Tensor, "batch token embedding"]) -> Float[Tensor, "batch token embedding"]:
        """Positional encoding forward pass.

        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``
        """
        # pe has shape [1, max_len, d_model], so the slice selects the first
        # seq_len positions and broadcasts over the batch dimension.
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)

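A minimal usage sketch for PositionalEncoding; the sizes below are illustrative and not taken from the InstaNovo code base.

# Hypothetical usage of PositionalEncoding; d_model=64, batch=2, seq_len=10 are made-up values.
pos_enc = PositionalEncoding(d_model=64, dropout=0.0, max_len=512)
tokens = torch.zeros(2, 10, 64)   # [batch_size, seq_len, embedding_dim]
encoded = pos_enc(tokens)         # same shape, with sinusoidal offsets added, then dropout
assert encoded.shape == (2, 10, 64)
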

class MultiScalePeakEmbedding(nn.Module):
    """Multi-scale sinusoidal embedding based on Voronov et al."""

    def __init__(self, h_size: int, dropout: float = 0, float_dtype: torch.dtype | str = torch.float64) -> None:
        super().__init__()
        self.h_size = h_size
        self.float_dtype = getattr(torch, float_dtype, None) if isinstance(float_dtype, str) else float_dtype
        if self.float_dtype is None:
            raise ValueError(f"Unknown torch dtype string: {float_dtype}")

        self.mlp = nn.Sequential(
            nn.Linear(h_size, h_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(h_size, h_size),
            nn.Dropout(dropout),
        )

        self.head = nn.Sequential(
            nn.Linear(h_size + 1, h_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(h_size, h_size),
            nn.Dropout(dropout),
        )

        # h_size / 2 angular frequencies with logarithmically spaced wavelengths.
        freqs = 2 * np.pi / torch.logspace(-2, -3, int(h_size / 2), dtype=self.float_dtype)
        self.register_buffer("freqs", freqs)

    # @torch.autocast("cuda", dtype=torch.float32)
    def forward(self, spectra: Float[Spectrum, " batch"]) -> Float[SpectrumEmbedding, " batch"]:
        """Encode peaks."""
        mz_values, intensities = spectra[:, :, [0]], spectra[:, :, [1]]
        x = self.encode_mass(mz_values)
        x = self.mlp(x)
        x = torch.cat([x, intensities], axis=2)
        return self.head(x)

    def encode_mass(self, x: Float[Tensor, " batch"]) -> Float[Tensor, "batch embedding"]:
        """Encode mz."""
        x = self.freqs[None, None, :] * x
        x = torch.cat([torch.sin(x), torch.cos(x)], axis=2)
        return x.float()

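A minimal usage sketch for MultiScalePeakEmbedding: each m/z value is projected onto sine/cosine pairs at the logarithmically spaced frequencies, passed through the MLP, and concatenated with the raw intensity before the final head. The spectrum tensor below is made up for illustration.

# Hypothetical usage of MultiScalePeakEmbedding; the (m/z, intensity) peaks are fabricated.
embed = MultiScalePeakEmbedding(h_size=64, dropout=0.0)
spectra = torch.tensor([[[147.11, 0.8], [204.09, 0.3], [0.0, 0.0]]])  # [batch, n_peaks, 2]
peaks = embed(spectra)                                                # [batch, n_peaks, h_size]
assert peaks.shape == (1, 3, 64)
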

class ConvPeakEmbedding(nn.Module):
    """Convolutional peak embedding."""

    def __init__(self, h_size: int, dropout: float = 0) -> None:
        super().__init__()
        self.h_size = h_size

        self.conv = nn.Sequential(
            nn.Conv1d(1, h_size // 4, kernel_size=40_000, stride=100, padding=40_000 // 2 - 1),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Conv1d(h_size // 4, h_size, kernel_size=5, stride=1, padding=1),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Conv peak embedding."""
        # Add a channel dimension, convolve, and move channels last: [batch, length, h_size].
        x = x.unsqueeze(1)
        return self.conv(x).transpose(-1, -2)
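
A minimal usage sketch for ConvPeakEmbedding; the module only requires a ``[batch, length]`` tensor, so the 10,000-bin input below is an assumption rather than the format used elsewhere in InstaNovo.

# Hypothetical usage of ConvPeakEmbedding; the densely binned spectrum is fabricated.
conv_embed = ConvPeakEmbedding(h_size=64, dropout=0.0)
binned_spectrum = torch.rand(1, 10_000)   # [batch, length]: intensities over a binned m/z axis
out = conv_embed(binned_spectrum)         # [batch, reduced_length, h_size]
assert out.shape[0] == 1 and out.shape[2] == 64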