Coverage for instanovo/inference/interfaces.py: 81%

43 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

1from __future__ import annotations 

2 

3from abc import ABCMeta, abstractmethod 

4from dataclasses import dataclass 

5from typing import Any 

6 

7import torch 

8from jaxtyping import Float, Integer 

9 

10from instanovo.types import Peptide, PrecursorFeatures, Spectrum 

11from instanovo.utils.residues import ResidueSet 

12 

13 

@dataclass
class ScoredSequence:
    """A decoded residue sequence together with its scores.

    Holds the residue sequence, its mass error, the log probability of the
    whole sequence and the log probabilities of its tokens.
    """

    # Residue strings making up the sequence.
    sequence: list[str]
    # Mass error of the sequence.
    # NOTE(review): reference mass and units are not visible here - confirm.
    mass_error: float
    # Log probability of the sequence as a whole.
    sequence_log_probability: float
    # Log probabilities of the tokens.
    # NOTE(review): presumably one entry per element of `sequence` - confirm.
    token_log_probabilities: list[float]

22 

23 

class Decodable(metaclass=ABCMeta):
    """An interface for models that can be decoded.

    Algorithms should conform to the search interface.

    Every member below is abstract: a concrete model must implement all of
    them before it can be instantiated.
    """

    @property
    @abstractmethod
    def residue_set(self) -> ResidueSet:
        """Every model must have a `residue_set` attribute."""

    @abstractmethod
    def init(  # type:ignore
        self,
        spectra: Float[Spectrum, " batch"],
        precursors: Float[PrecursorFeatures, " batch"],
        *args,
        **kwargs,
    ) -> Any:
        """Initialize the search state.

        Args:
            spectra (torch.FloatTensor):
                The spectra to be sequenced.

            precursors (torch.FloatTensor[batch size, 3]):
                The precursor mass, charge and mass-to-charge ratio.

        Returns:
            Any:
                The initial search state; its concrete type is left to
                the implementing model.
        """

    @abstractmethod
    def score_candidates(  # type:ignore
        self,
        sequences: Integer[Peptide, "..."],
        precursor_mass_charge: Float[PrecursorFeatures, "..."],
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        """Generate and score the next set of candidates.

        Args:
            sequences (torch.LongTensor):
                Partial residue sequences generated in
                the course of decoding.

            precursor_mass_charge (torch.FloatTensor[batch size, 3]):
                The precursor mass, charge and mass-to-charge ratio.

        Returns:
            torch.FloatTensor:
                Scores for the next set of candidates.
                NOTE(review): shape and score scale are not fixed by
                this interface - confirm against implementations.
        """

    @abstractmethod
    def get_residue_masses(self, mass_scale: int) -> torch.LongTensor:
        """Get residue masses for the model's residue vocabulary.

        Args:
            mass_scale (int):
                The scale in Daltons at which masses are
                calculated and rounded off. For example,
                a scale of 10000 would represent masses
                at a scale of 1e4 Da.

        Returns:
            torch.LongTensor:
                The residue masses expressed as integers at the
                requested scale.
        """

    @abstractmethod
    def decode(self, sequence: Integer[Peptide, "..."]) -> list[str]:
        """Map sequences of indices to residues using the model's residue vocabulary.

        Args:
            sequence (torch.LongTensor):
                The sequence of residue indices to be mapped
                to the corresponding residue strings.

        Returns:
            list[str]:
                The residue strings corresponding to the indices.
        """

    @abstractmethod
    def get_eos_index(self) -> int:
        """Get the end of sequence token's index in the model's residue vocabulary."""

    @abstractmethod
    def get_empty_index(self) -> int:
        """Get the empty token's index in the model's residue vocabulary."""

108 

109 

class Decoder(metaclass=ABCMeta):
    """A class that implements some search algorithm for decoding.

    Model should conform to the `Decodable` interface.

    Args:
        model (Decodable):
            The model to predict residue sequences
            from using the implemented search
            algorithm.
    """

    def __init__(self, model: Decodable) -> None:
        """Store the model that the search algorithm will query."""
        self.model = model

    @abstractmethod
    def decode(  # type:ignore
        self,
        spectra: Float[Spectrum, "..."],
        precursors: Float[PrecursorFeatures, "..."],
        *args,
        **kwargs,
    ) -> dict[str, Any]:
        """Generate the predicted residue sequence using the decoder's search algorithm.

        Args:
            spectra (torch.FloatTensor):
                The spectra to be sequenced.

            precursors (torch.FloatTensor):
                The precursor mass, charge and mass-to-charge ratio.

        Returns:
            dict[str, Any]:
                Required keys:
                    - "sequence": list[str]
                    - "mass_error": float
                    - "sequence_log_probability": float
                    - "token_log_probabilities": list[float]
                Optional keys:
                    - "encoder_output": list[float]
                Example additional keys:
                    - "sequence_beam_0": list[str]
        """