Coverage for instanovo/utils/residues.py: 96%
68 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
1from __future__ import annotations
3import re
5import numpy as np
6import torch
8from instanovo.constants import H2O_MASS, PROTON_MASS_AMU, SpecialTokens
class ResidueSet:
    """A class for managing sets of residues.

    The vocabulary is the three special tokens (PAD, SOS, EOS) followed by the
    keys of `residue_masses` in their insertion order, so residue indices
    start right after the special tokens.

    Args:
        residue_masses (dict[str, float]):
            Dictionary of residues mapping to corresponding mass values.
        residue_remapping (dict[str, str] | None, optional):
            Dictionary of residues mapping to keys in `residue_masses`.
            This is used for dataset specific residue naming conventions.
            Residue remapping may be many-to-one.
    """

    # Split on amino acids allowing for modifications eg. AM(ox)Z -> [A, M(ox), Z]
    # Supports brackets or unimod notation.
    # The pattern does not depend on instance state, so it is built once here.
    _TOKENIZER_REGEX: str = (
        # First capture group: matches standalone modifications
        # These would represent n-terminal modifications:
        # - Anything in square brackets, like [UNIMOD:35], [GLY:123456], [+.98]
        # - Anything in parentheses, like (ox), (-17.02), (+.98), (p)
        # - Raw numeric modifications with optional sign, like +15.99, -17.02, 0.98, .98
        r"(\[[^\]]+\]"  # Square-bracketed mod, e.g. [UNIMOD:123]
        r"|\([^)]+\)"  # Parentheses mod, e.g. (+15.99)
        r"|[+-]?\d+(?:\.\d+)?"  # Number with optional +/-, e.g. +15.99, -17.02, 42.02
        r"|[+-]?\.\d+"  # Decimal without leading digit, e.g. .98
        r")|"
        # Second capture group: matches amino acids with optional attached modifications
        # - Starts with a capital letter A-Z (standard amino acid codes)
        # - Optionally followed by any of the above modification formats
        r"([A-Z]"  # Amino acid residue, e.g. A, R, C
        r"(?:\[[^\]]+\]"  # Optional square-bracketed mod
        r"|\([^)]+\)"  # Optional parentheses mod
        r"|[+-]?\d+(?:\.\d+)?"  # Optional numeric mod
        r"|[+-]?\.\d+)?"
        r")"
    )

    def __init__(
        self,
        residue_masses: dict[str, float],
        residue_remapping: dict[str, str] | None = None,
    ) -> None:
        self.residue_masses = residue_masses
        # Store a copy so that `update_remapping` never mutates the mapping
        # object owned by the caller.
        self.residue_remapping: dict[str, str] = dict(residue_remapping) if residue_remapping else {}

        # Special tokens come first
        self.special_tokens = [
            SpecialTokens.PAD_TOKEN.value,
            SpecialTokens.SOS_TOKEN.value,
            SpecialTokens.EOS_TOKEN.value,
        ]

        self.vocab = self.special_tokens + list(self.residue_masses.keys())

        # Create mappings
        self.residue_to_index = {residue: index for index, residue in enumerate(self.vocab)}
        self.index_to_residue = dict(enumerate(self.vocab))

        # Kept as a plain pattern string on the instance; `tokenize` reads it from here.
        self.tokenizer_regex = self._TOKENIZER_REGEX

        self.PAD_INDEX: int = self.residue_to_index[SpecialTokens.PAD_TOKEN.value]
        self.SOS_INDEX: int = self.residue_to_index[SpecialTokens.SOS_TOKEN.value]
        self.EOS_INDEX: int = self.residue_to_index[SpecialTokens.EOS_TOKEN.value]

        # TODO: Add support for specifying which residues are n-terminal only.

    def update_remapping(self, mapping: dict[str, str]) -> None:
        """Update the residue remapping for specific datasets.

        Existing entries with the same key are overwritten.

        Args:
            mapping (dict[str, str]):
                The mapping from residues specific to a dataset
                to residues in the original `residue_masses`.
        """
        self.residue_remapping.update(mapping)

    def get_mass(self, residue: str) -> float:
        """Get the mass of a residue.

        The residue is first translated through the residue remapping
        (if it has an entry there) and then looked up in `residue_masses`.

        Args:
            residue (str):
                The residue whose mass to fetch. This residue
                must be in the residue set or this will raise
                a `KeyError`.

        Returns:
            float: The mass of the residue in Daltons.
        """
        return self.residue_masses[self.residue_remapping.get(residue, residue)]

    def get_sequence_mass(self, sequence: str | list[str], charge: int | None = None) -> float:
        """Get the mass of a residue sequence.

        Args:
            sequence (str | list[str]):
                The residue sequence whose mass to calculate.
                All residues must be in the residue set or
                this will raise a `KeyError`. A string is iterated
                character by character (it is not tokenized), so
                sequences containing modifications should be passed
                as a list of residue tokens.
            charge (int | None, optional):
                Charge of the sequence to calculate the mass.
                `None` or `0` returns the uncharged mass.

        Returns:
            float: The mass of the residue sequence in Daltons
                (sum of residue masses plus one water).
                If a charge is specified, returns m/z.
        """
        mass = sum(self.get_mass(residue) for residue in sequence) + H2O_MASS
        if charge:
            mass = (mass / charge) + PROTON_MASS_AMU
        return float(mass)

    def tokenize(self, sequence: str | list[str] | None) -> list[str]:
        """Split a peptide represented as a string into a list of residues.

        Args:
            sequence (str | list[str] | None): The peptide to be split.
                `None` yields an empty list and a list is returned unchanged.

        Returns:
            list[str]: The sequence of residues forming the peptide.
                Characters not matched by the tokenizer pattern are skipped.
        """
        # TODO: find a way to handle N-terminal PTMs appearing at any position
        if sequence is None:
            return []
        if isinstance(sequence, list):
            return sequence
        # `findall` yields one (standalone_mod, residue) tuple per match and
        # exactly one of the two groups is non-empty; keep that one.
        return [token for groups in re.findall(self.tokenizer_regex, sequence) for token in groups if token]

    def detokenize(self, sequence: list[str]) -> str:
        """Join a list of residues into a string representing the peptide.

        Args:
            sequence (list[str]):
                The sequence of residues.

        Returns:
            str:
                The string representing the peptide.
        """
        return "".join(sequence)

    def encode(
        self,
        sequence: list[str],
        add_eos: bool = False,
        return_tensor: str | None = None,
        pad_length: int | None = None,
    ) -> torch.LongTensor | np.ndarray | list[int]:
        """Map a sequence of residues to their indices and optionally pad them to a fixed length.

        Args:
            sequence (list[str]):
                The sequence of residues. Residues with an entry in the
                residue remapping are remapped before the index lookup.
                Unknown residues raise a `KeyError`.
            add_eos (bool):
                Add an EOS token when encoding.
                Defaults to `False`.
            return_tensor (str | None, optional):
                Return type of encoded tensor. Returns a list of integers
                if no return type is specified. Options: None, pt, np
            pad_length (int | None, optional):
                An optional fixed length to pad the encoded sequence to.
                If this is `None`, no padding is done. Sequences already
                at least this long are left as they are (no truncation).

        Returns:
            torch.LongTensor | np.ndarray | list[int]:
                The indices of the residues, as a `torch.long` tensor for
                `"pt"`, an `int32` array for `"np"`, or a plain list otherwise.
        """
        encoded_list = [
            # remap the residue if possible
            self.residue_to_index[self.residue_remapping.get(residue, residue)]
            for residue in sequence
        ]

        if add_eos:
            encoded_list.append(self.EOS_INDEX)

        if pad_length:
            # A non-positive multiplier produces an empty list, so nothing is
            # appended when the sequence is already long enough.
            encoded_list.extend([self.PAD_INDEX] * (pad_length - len(encoded_list)))

        if return_tensor == "pt":
            return torch.tensor(encoded_list, dtype=torch.long)
        if return_tensor == "np":
            return np.array(encoded_list, dtype=np.int32)
        return encoded_list

    def decode(self, sequence: torch.LongTensor | list[int], reverse: bool = False) -> list[str]:
        """Map a sequence of indices to the corresponding sequence of residues.

        Decoding stops at the first EOS index; SOS and PAD indices are dropped.

        Args:
            sequence (torch.LongTensor | list[int]):
                The sequence of residue indices.
            reverse (bool):
                Optionally reverse the decoded sequence.

        Returns:
            list[str]:
                The corresponding sequence of residue strings.
        """
        if isinstance(sequence, torch.Tensor):
            # Plain Python ints make straightforward dictionary keys below.
            sequence = sequence.cpu().tolist()

        residue_indices = []
        for index in sequence:
            if index == self.EOS_INDEX:
                break
            if index == self.SOS_INDEX or index == self.PAD_INDEX:
                continue
            residue_indices.append(index)

        if reverse:
            residue_indices.reverse()

        return [self.index_to_residue[index] for index in residue_indices]

    def __len__(self) -> int:
        """Return the vocabulary size, special tokens included."""
        return len(self.index_to_residue)

    def __eq__(self, other: object) -> bool:
        """Two residue sets are equal when their vocabularies match in content and order.

        Masses and remappings are not compared. Because `__eq__` is defined
        without `__hash__`, instances are unhashable.
        """
        if not isinstance(other, ResidueSet):
            return NotImplemented
        return self.vocab == other.vocab

    def __contains__(self, residue: str) -> bool:
        """Check if a residue is in the residue set.

        Only the keys of `residue_masses` are consulted; special tokens and
        remapped aliases are not considered members.

        Args:
            residue (str): The residue to check.

        Returns:
            bool: True if the residue is in the residue set, False otherwise.
        """
        return residue in self.residue_masses