Coverage for instanovo/utils/residues.py: 96%

68 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

1from __future__ import annotations 

2 

3import re 

4 

5import numpy as np 

6import torch 

7 

8from instanovo.constants import H2O_MASS, PROTON_MASS_AMU, SpecialTokens 

9 

10 

class ResidueSet:
    """A class for managing sets of residues.

    The vocabulary is laid out as the special tokens (PAD, SOS, EOS, in that
    order) followed by the keys of `residue_masses` in their insertion order.
    Indices into this vocabulary are what `encode` produces and `decode` consumes.

    Args:
        residue_masses (dict[str, float]):
            Dictionary of residues mapping to corresponding mass values.
        residue_remapping (dict[str, str] | None, optional):
            Dictionary of residues mapping to keys in `residue_masses`.
            This is used for dataset specific residue naming conventions.
            Residue remapping may be many-to-one. The mapping is copied, so
            later calls to `update_remapping` never mutate the caller's dict.
    """

    def __init__(
        self,
        residue_masses: dict[str, float],
        residue_remapping: dict[str, str] | None = None,
    ) -> None:
        self.residue_masses = residue_masses
        # Copy the mapping: `update_remapping` mutates `self.residue_remapping`
        # in place, and that must not leak into a dict owned by the caller.
        self.residue_remapping: dict[str, str] = dict(residue_remapping) if residue_remapping else {}

        # Special tokens come first
        self.special_tokens = [
            SpecialTokens.PAD_TOKEN.value,
            SpecialTokens.SOS_TOKEN.value,
            SpecialTokens.EOS_TOKEN.value,
        ]

        self.vocab = self.special_tokens + list(self.residue_masses.keys())

        # Create mappings
        self.residue_to_index = {residue: index for index, residue in enumerate(self.vocab)}
        self.index_to_residue = dict(enumerate(self.vocab))
        # Split on amino acids allowing for modifications eg. AM(ox)Z -> [A, M(ox), Z]
        # Supports brackets or unimod notation
        self.tokenizer_regex = (
            # First capture group: matches standalone modifications
            # These would represent n-terminal modifications:
            # - Anything in square brackets, like [UNIMOD:35], [GLY:123456], [+.98]
            # - Anything in parentheses, like (ox), (-17.02), (+.98), (p)
            # - Raw numeric modifications with optional sign, like +15.99, -17.02, 0.98, .98
            r"(\[[^\]]+\]"  # Square-bracketed mod, e.g. [UNIMOD:123]
            r"|\([^)]+\)"  # Parentheses mod, e.g. (+15.99)
            r"|[+-]?\d+(?:\.\d+)?"  # Number with optional +/-, e.g. +15.99, -17.02, 42.02
            r"|[+-]?\.\d+"  # Decimal without leading digit, e.g. .98
            r")|"
            # Second capture group: matches amino acids with optional attached modifications
            # - Starts with a capital letter A-Z (standard amino acid codes)
            # - Optionally followed by any of the above modification formats
            r"([A-Z]"  # Amino acid residue, e.g. A, R, C
            r"(?:\[[^\]]+\]"  # Optional square-bracketed mod
            r"|\([^)]+\)"  # Optional parentheses mod
            r"|[+-]?\d+(?:\.\d+)?"  # Optional numeric mod
            r"|[+-]?\.\d+)?"
            r")"
        )

        self.PAD_INDEX: int = self.residue_to_index[SpecialTokens.PAD_TOKEN.value]
        self.SOS_INDEX: int = self.residue_to_index[SpecialTokens.SOS_TOKEN.value]
        self.EOS_INDEX: int = self.residue_to_index[SpecialTokens.EOS_TOKEN.value]

        # TODO: Add support for specifying which residues are n-terminal only.

    def update_remapping(self, mapping: dict[str, str]) -> None:
        """Update the residue remapping for specific datasets.

        Existing entries with the same key are overwritten.

        Args:
            mapping (dict[str, str]):
                The mapping from residues specific to a dataset
                to residues in the original `residue_masses`.
        """
        self.residue_remapping.update(mapping)

    def get_mass(self, residue: str) -> float:
        """Get the mass of a residue.

        The residue is first passed through `residue_remapping` (if it has an
        entry there) and then looked up in `residue_masses`.

        Args:
            residue (str):
                The residue whose mass to fetch. This residue
                must be in the residue set or this will raise
                a `KeyError`.

        Returns:
            float: The mass of the residue in Daltons.
        """
        return self.residue_masses[self.residue_remapping.get(residue, residue)]

    def get_sequence_mass(self, sequence: str | list[str], charge: int | None = None) -> float:
        """Get the mass of a residue sequence.

        Args:
            sequence (str | list[str]):
                The residue sequence whose mass to calculate. A `str` is iterated
                character by character, so it only works when every residue is a
                single character; pass a tokenized list (see `tokenize`) for
                modified residues. All residues must be in the residue set or
                this will raise a `KeyError`.
            charge (int | None, optional):
                Charge of the sequence to calculate the mass. Defaults to `None`.
                A falsy charge (`None` or `0`) yields the uncharged mass.

        Returns:
            float: The mass of the residues plus one water (`H2O_MASS`) in Daltons.
                If a charge is specified, returns m/z computed as
                `mass / charge + PROTON_MASS_AMU`.
        """
        mass = sum(self.get_mass(residue) for residue in sequence) + H2O_MASS
        if charge:
            mass = (mass / charge) + PROTON_MASS_AMU
        return float(mass)

    def tokenize(self, sequence: str | list[str] | None) -> list[str]:
        """Split a peptide represented as a string into a list of residues.

        Characters not matched by `tokenizer_regex` (e.g. lowercase letters or
        whitespace) are skipped by `re.findall` rather than raising.

        Args:
            sequence (str | list[str] | None): The peptide to be split. `None`
                yields an empty list; a list is returned unchanged (same object).

        Returns:
            list[str]: The sequence of residues forming the peptide.
        """
        # TODO: find a way to handle N-terminal PTMs appearing at any position
        if sequence is None:
            return []
        if isinstance(sequence, list):
            return sequence
        # findall returns one (standalone_mod, residue) tuple per match; exactly
        # one of the two groups is non-empty, so keep only the non-empty parts.
        return [item for groups in re.findall(self.tokenizer_regex, sequence) for item in groups if item]

    def detokenize(self, sequence: list[str]) -> str:
        """Join a list of residues into a string representing the peptide.

        Args:
            sequence (list[str]):
                The sequence of residues.

        Returns:
            str:
                The string representing the peptide.
        """
        return "".join(sequence)

    def encode(
        self,
        sequence: list[str],
        add_eos: bool = False,
        return_tensor: str | None = None,
        pad_length: int | None = None,
    ) -> torch.LongTensor | np.ndarray | list[int]:
        """Map a sequence of residues to their indices and optionally pad them to a fixed length.

        Args:
            sequence (list[str]):
                The sequence of residues. Each residue is remapped through
                `residue_remapping` when possible; unknown residues raise `KeyError`.
            add_eos (bool):
                Add an EOS token when encoding.
                Defaults to `False`.
            return_tensor (str | None, optional):
                Return type of encoded tensor. Returns a list of integers
                if no return type is specified. Options: None, pt, np.
                Any other value also returns a list.
            pad_length (int | None, optional):
                An optional fixed length to pad the encoded sequence to.
                If this is `None` (or 0), no padding is done. Sequences already
                longer than `pad_length` are not truncated.

        Returns:
            torch.LongTensor | np.ndarray | list[int]:
                The indices of the residues: a `torch.long` tensor for "pt",
                an `np.int32` array for "np", otherwise a list of ints.
        """
        encoded_list = [
            # remap the residue if possible
            self.residue_to_index[self.residue_remapping.get(residue, residue)]
            for residue in sequence
        ]

        if add_eos:
            encoded_list.append(self.EOS_INDEX)

        if pad_length:
            # A negative repeat count yields [], so over-long sequences are left as-is.
            encoded_list.extend((pad_length - len(encoded_list)) * [self.PAD_INDEX])

        if return_tensor == "pt":
            return torch.tensor(encoded_list, dtype=torch.long)
        if return_tensor == "np":
            return np.array(encoded_list, dtype=np.int32)
        return encoded_list

    def decode(self, sequence: torch.LongTensor | list[int], reverse: bool = False) -> list[str]:
        """Map a sequence of indices to the corresponding sequence of residues.

        Decoding stops at the first EOS index; SOS and PAD indices are skipped.

        Args:
            sequence (torch.LongTensor | list[int]):
                The sequence of residue indices. Tensors are moved to CPU first.
            reverse (bool):
                Optionally reverse the decoded sequence.

        Returns:
            list[str]:
                The corresponding sequence of residue strings.
        """
        if isinstance(sequence, torch.Tensor):
            sequence = sequence.cpu().numpy()

        skipped = (self.SOS_INDEX, self.PAD_INDEX)
        residue_sequence = []
        for index in sequence:
            if index == self.EOS_INDEX:
                break
            if index in skipped:
                continue
            residue_sequence.append(index)

        if reverse:
            residue_sequence = residue_sequence[::-1]

        return [self.index_to_residue[index] for index in residue_sequence]

    def __len__(self) -> int:
        """Return the vocabulary size, including the special tokens."""
        return len(self.index_to_residue)

    def __eq__(self, other: object) -> bool:
        """Compare two residue sets by vocabulary (order-sensitive) only.

        Masses and remappings are not compared. Because `__eq__` is defined
        without `__hash__`, instances are unhashable.
        """
        if not isinstance(other, ResidueSet):
            return NotImplemented
        return self.vocab == other.vocab

    def __contains__(self, residue: str) -> bool:
        """Check if a residue is in the residue set.

        Only the keys of `residue_masses` are checked; remapped aliases and
        special tokens are not considered members.

        Args:
            residue (str): The residue to check.

        Returns:
            bool: True if the residue is in the residue set, False otherwise.
        """
        return residue in self.residue_masses