Coverage for instanovo/inference/knapsack.py: 95%
82 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
1from __future__ import annotations
3import bisect
4import os
5import pickle
6from dataclasses import dataclass
8import numpy as np
10from instanovo.__init__ import console
11from instanovo.constants import CARBON_MASS_DELTA
12from instanovo.types import KnapsackChart, MassArray
13from instanovo.utils.colorlogging import ColorLog
15logger = ColorLog(console, __name__).logger
@dataclass
class Knapsack:
    """A class that precomputes and stores a knapsack chart.

    Args:
        max_mass (float):
            The maximum mass up to which the chart is
            calculated.

        mass_scale (int):
            The factor by which masses in Daltons are
            multiplied before being rounded to integers.
            For example, a scale of 10000 stores masses
            as integer multiples of 1e-4 Da.

        max_isotope (int):
            The largest isotope count used when the chart
            is seeded in `construct_knapsack`.

        residues (list[str]):
            The list of residues that are considered
            in knapsack decoding. The order of this
            list is the inverse of `residue_indices`.

        residue_indices (dict[str, int]):
            A mapping from residues as strings
            to indices in the knapsack chart.
            This is the inverse of `residues`.

        masses (numpy.ndarray[number of masses]):
            The set of realisable masses in ascending order.

        chart (numpy.ndarray[number of masses, number of residues]):
            The chart of realisable masses and residues that
            can lead to these masses.
            `chart[mass, residue]` is `True` if and only if
            a sequence of `mass` can be generated starting with
            the residue with index `residue`.
    """

    max_mass: float
    mass_scale: int
    max_isotope: int
    residues: list[str]
    residue_indices: dict[str, int]
    masses: MassArray
    chart: KnapsackChart

    def save(self, path: str) -> None:
        """Save the knapsack to a new directory.

        The directory is created and three files are written into it:
        `parameters.pkl` (a pickled tuple of `max_mass`, `mass_scale`,
        `max_isotope`, `residues` and `residue_indices`), `masses.npy`
        and `chart.npy`.

        Args:
            path (str):
                The path to the directory to create.

        Raises:
            FileExistsError: If `path` already exists.
        """
        if os.path.exists(path):
            raise FileExistsError(f"Knapsack path already exists: {path}")

        os.mkdir(path=path)
        parameters = (
            self.max_mass,
            self.mass_scale,
            self.max_isotope,
            self.residues,
            self.residue_indices,
        )
        # Use a context manager so the handle is flushed and closed
        # deterministically instead of being left to the garbage collector.
        with open(os.path.join(path, "parameters.pkl"), "wb") as parameters_file:
            pickle.dump(parameters, parameters_file)
        np.save(os.path.join(path, "masses.npy"), self.masses)
        np.save(os.path.join(path, "chart.npy"), self.chart)

    @classmethod
    def construct_knapsack(
        cls,
        residue_masses: dict[str, float],
        residue_indices: dict[str, int],
        max_mass: float,
        mass_scale: int,
        max_isotope: int = 2,
    ) -> "Knapsack":
        """Construct a knapsack chart using depth-first search.

        Previous construction algorithms have used dynamic
        programming, but its space and time complexity
        scale linearly with mass resolution since every
        `possible` mass is iterated over rather than only the
        `feasible` masses.

        Graph search algorithms only
        iterate over `feasible` masses which become a
        smaller and smaller share of possible masses as the
        mass resolution increases. This leads to dramatic
        performance improvements.

        This implementation uses depth-first search since
        its agenda is a stack which can be implemented
        using python lists whose operations have amortized
        constant time complexity.

        Args:
            residue_masses (dict[str, float]):
                A mapping from considered residues
                to their masses.

            residue_indices (dict[str, int]):
                A mapping from residues to their column
                index in the chart.

            max_mass (float):
                The maximum mass up to which the chart is
                calculated.

            mass_scale (int):
                The factor by which masses in Daltons are
                multiplied before being rounded to integers.
                For example, a scale of 10000 stores masses
                as integer multiples of 1e-4 Da.

            max_isotope (int):
                Seed masses are shifted by `0..max_isotope`
                multiples of `CARBON_MASS_DELTA`. Defaults to 2.

        Returns:
            Knapsack:
                The knapsack holding the visited masses and
                the filled chart.
        """
        # Convert the maximum mass to units of the mass scale
        scaled_max_mass = round(max_mass * mass_scale)

        logger.info("Scaling masses.")
        # Load residue information into appropriate data structures
        residues, scaled_residue_masses = [""], {}
        negative_residue_masses = [0.0]

        for residue, mass in residue_masses.items():
            residues.append(residue)
            if mass < 0:
                negative_residue_masses.append(mass)
            if abs(mass) > 0:
                scaled_residue_masses[residue] = round(mass * mass_scale)

        # Initialize the search agenda
        mass_dim = scaled_max_mass + 1
        residue_dim = max(residue_indices.values()) + 1
        chart = np.full((mass_dim, residue_dim), False)
        logger.info("Initializing chart.")

        # The seed offsets do not depend on the residue, so compute them once.
        # The 0.0 entry in `negative_residue_masses` yields the unshifted seeds.
        seed_offsets = [
            round((CARBON_MASS_DELTA * isotope + negative_mass) * mass_scale)
            for negative_mass in negative_residue_masses
            for isotope in range(max_isotope + 1)
        ]

        agenda, visited = [], set()
        for residue, mass in scaled_residue_masses.items():
            if mass < 0:
                # Negative-mass residues are marked as allowed at every mass.
                chart[:, residue_indices[residue]] = True
                continue
            for offset in seed_offsets:
                seed_mass = mass + offset
                if seed_mass > scaled_max_mass:
                    # Outside the chart; indexing would raise IndexError.
                    continue
                agenda.append(seed_mass)
                if seed_mass >= 0:
                    # A negative index would silently wrap around to the
                    # end of the chart and mark an unrelated mass.
                    chart[seed_mass, residue_indices[residue]] = True

        # Perform depth-first search
        logger.info("Performing search.")
        while agenda:
            current_mass = agenda.pop()

            if current_mass in visited:
                continue

            for residue, mass in scaled_residue_masses.items():
                if mass < 0:
                    continue
                next_mass = current_mass + mass
                if next_mass <= scaled_max_mass:
                    agenda.append(next_mass)
                    if next_mass >= 0:
                        # Same wrap-around guard as for the seeds; negative
                        # sums are still explored but never written.
                        chart[next_mass, residue_indices[residue]] = True
            visited.add(current_mass)

        masses = np.array(sorted(visited))
        return cls(
            max_mass=max_mass,
            mass_scale=mass_scale,
            max_isotope=max_isotope,
            residues=residues,
            residue_indices=residue_indices,
            masses=masses,
            chart=chart,
        )

    @classmethod
    def from_file(cls, path: str) -> "Knapsack":
        """Load a knapsack saved to a directory.

        Reads the `parameters.pkl`, `masses.npy` and `chart.npy`
        files written by `save`.

        Args:
            path (str):
                The path to the directory.

        Returns:
            Knapsack:
                The knapsack rebuilt from the stored files.
        """
        # NOTE: unpickling can execute arbitrary code; only load
        # knapsack directories that come from a trusted source.
        with open(os.path.join(path, "parameters.pkl"), "rb") as parameters_file:
            max_mass, mass_scale, max_isotope, residues, residue_indices = pickle.load(
                parameters_file
            )
        masses = np.load(os.path.join(path, "masses.npy"))
        chart = np.load(os.path.join(path, "chart.npy"))
        return cls(
            max_mass=max_mass,
            mass_scale=mass_scale,
            max_isotope=max_isotope,
            residues=residues,
            residue_indices=residue_indices,
            masses=masses,
            chart=chart,
        )

    def get_feasible_masses(self, target_mass: float, tolerance: float) -> list[int]:
        """Find a set of feasible masses for a given target mass and tolerance using binary search.

        Args:
            target_mass (float):
                The masses to be decoded in Daltons.

            tolerance (float):
                The mass tolerance in Daltons.

        Returns:
            list[int]:
                The entries of `masses` that lie strictly between
                the scaled lower and upper bounds.
        """
        scaled_min_mass = round(self.mass_scale * (target_mass - tolerance))
        scaled_max_mass = round(self.mass_scale * (target_mass + tolerance))

        # bisect_right on the lower bound and bisect_left on the upper bound
        # make both ends exclusive: masses equal to either bound are dropped.
        # NOTE(review): confirm the exclusive bounds are intended; a tolerance
        # of zero can never match.
        left_endpoint = bisect.bisect_right(self.masses, scaled_min_mass)
        right_endpoint = bisect.bisect_left(self.masses, scaled_max_mass)

        feasible_masses: list[int] = self.masses[left_endpoint:right_endpoint].tolist()
        return feasible_masses