Coverage for instanovo/inference/knapsack.py: 95%

82 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

1from __future__ import annotations 

2 

3import bisect 

4import os 

5import pickle 

6from dataclasses import dataclass 

7 

8import numpy as np 

9 

10from instanovo.__init__ import console 

11from instanovo.constants import CARBON_MASS_DELTA 

12from instanovo.types import KnapsackChart, MassArray 

13from instanovo.utils.colorlogging import ColorLog 

14 

15logger = ColorLog(console, __name__).logger 

16 

17 

@dataclass
class Knapsack:
    """A precomputed knapsack chart of realisable masses.

    Masses are stored as integers obtained by multiplying a mass in
    Daltons by `mass_scale` and rounding.

    Args:
        max_mass (float):
            The maximum mass in Daltons up to which the chart is
            calculated.

        mass_scale (int):
            The factor masses are multiplied by before being rounded
            to integers. For example, a scale of 10000 stores masses
            as integers in units of 1e-4 Da.

        max_isotope (int):
            The largest isotope count whose carbon mass offset was
            added to the starting masses when the chart was built.

        residues (list[str]):
            The list of residues that are considered in knapsack
            decoding. When built by `construct_knapsack` this is the
            empty string followed by the keys of `residue_masses`.

        residue_indices (dict[str, int]):
            A mapping from residues as strings to column indices in
            the knapsack chart.

        masses (numpy.ndarray[number of masses]):
            The set of realisable scaled masses in ascending order.

        chart (numpy.ndarray[number of masses, number of residues]):
            The boolean chart of scaled masses and residues.
            `chart[mass, residue]` is `True` when `mass` was reached
            during construction by adding the residue with column
            index `residue`. Columns of negative-mass residues are
            `True` for every mass.
    """

    max_mass: float
    mass_scale: int
    max_isotope: int
    residues: list[str]
    residue_indices: dict[str, int]
    masses: MassArray
    chart: KnapsackChart

    def save(self, path: str) -> None:
        """Save the knapsack to a new directory.

        The directory receives `parameters.pkl` (the scalar parameters
        and residue tables), `masses.npy` and `chart.npy`.

        Args:
            path (str):
                The path of the directory to create.

        Raises:
            FileExistsError: If `path` already exists.
        """
        if os.path.exists(path):
            raise FileExistsError(f"Knapsack directory already exists: {path}")

        os.mkdir(path=path)
        parameters = (
            self.max_mass,
            self.mass_scale,
            self.max_isotope,
            self.residues,
            self.residue_indices,
        )
        # Use a context manager so the handle is flushed and closed
        # deterministically instead of being left to the garbage collector.
        with open(os.path.join(path, "parameters.pkl"), "wb") as parameters_file:
            pickle.dump(parameters, parameters_file)
        np.save(os.path.join(path, "masses.npy"), self.masses)
        np.save(os.path.join(path, "chart.npy"), self.chart)

    @classmethod
    def construct_knapsack(
        cls,
        residue_masses: dict[str, float],
        residue_indices: dict[str, int],
        max_mass: float,
        mass_scale: int,
        max_isotope: int = 2,
    ) -> "Knapsack":
        """Construct a knapsack chart using depth-first search.

        Previous construction algorithms have used dynamic
        programming, but its space and time complexity
        scale linearly with mass resolution since every
        `possible` mass is iterated over rather than only the
        `feasible` masses.

        Graph search algorithms only iterate over `feasible`
        masses, which become a smaller and smaller share of
        possible masses as the mass resolution increases.

        This implementation uses depth-first search since
        its agenda is a stack which can be implemented
        using python lists whose operations have amortized
        constant time complexity.

        Args:
            residue_masses (dict[str, float]):
                A mapping from considered residues
                to their masses in Daltons.

            residue_indices (dict[str, int]):
                A mapping from residues to their column
                index in the chart.

            max_mass (float):
                The maximum mass in Daltons up to which the
                chart is calculated.

            mass_scale (int):
                The factor masses are multiplied by before being
                rounded to integers. For example, a scale of 10000
                stores masses as integers in units of 1e-4 Da.

            max_isotope (int):
                Starting masses are generated for every isotope
                count from 0 to `max_isotope` inclusive.

        Returns:
            Knapsack: The knapsack holding the constructed chart.
        """
        # Convert the maximum mass to units of the mass scale
        scaled_max_mass = round(max_mass * mass_scale)

        logger.info("Scaling masses.")
        # Load residue information into appropriate data structures.
        # The leading 0.0 makes "no negative-mass residue" one of the options.
        residues, scaled_residue_masses = [""], {}
        negative_residue_masses = [0.0]

        for residue, mass in residue_masses.items():
            residues.append(residue)
            if mass < 0:
                negative_residue_masses.append(mass)
            if abs(mass) > 0:
                scaled_residue_masses[residue] = round(mass * mass_scale)

        # Initialize the chart and the search agenda
        mass_dim = scaled_max_mass + 1
        residue_dim = max(residue_indices.values()) + 1
        chart = np.full((mass_dim, residue_dim), False)
        logger.info("Initializing chart.")
        agenda, visited = [], set()
        for residue, mass in scaled_residue_masses.items():
            if mass < 0:
                # Negative-mass residues are marked feasible at every mass
                # and are never used as search steps.
                chart[:, residue_indices[residue]] = True
                continue
            for negative_mass in negative_residue_masses:
                for isotope in range(max_isotope + 1):
                    offset = round((CARBON_MASS_DELTA * isotope + negative_mass) * mass_scale)
                    start_mass = mass + offset
                    # Skip starting masses outside the chart: a negative
                    # index would silently wrap around to the end of the
                    # array and one above the maximum would raise
                    # IndexError. This mirrors the bound used in the search.
                    if not 0 <= start_mass <= scaled_max_mass:
                        continue
                    agenda.append(start_mass)
                    chart[start_mass, residue_indices[residue]] = True

        # Perform depth-first search
        logger.info("Performing search.")
        while agenda:
            current_mass = agenda.pop()

            # The same mass can be pushed several times; expand it once.
            if current_mass in visited:
                continue

            for residue, mass in scaled_residue_masses.items():
                if mass < 0:
                    continue
                next_mass = current_mass + mass
                if next_mass <= scaled_max_mass:
                    agenda.append(next_mass)
                    chart[next_mass, residue_indices[residue]] = True
            visited.add(current_mass)

        masses = np.array(sorted(visited))
        return cls(
            max_mass=max_mass,
            mass_scale=mass_scale,
            max_isotope=max_isotope,
            residues=residues,
            residue_indices=residue_indices,
            masses=masses,
            chart=chart,
        )

    @classmethod
    def from_file(cls, path: str) -> "Knapsack":
        """Load a knapsack saved to a directory by `save`.

        Args:
            path (str):
                The path to the directory.

        Returns:
            Knapsack: The knapsack rebuilt from `parameters.pkl`,
                `masses.npy` and `chart.npy` in `path`.
        """
        # NOTE(review): unpickling can execute arbitrary code, so only load
        # knapsack directories that come from a trusted source.
        with open(os.path.join(path, "parameters.pkl"), "rb") as parameters_file:
            max_mass, mass_scale, max_isotope, residues, residue_indices = pickle.load(
                parameters_file
            )
        masses = np.load(os.path.join(path, "masses.npy"))
        chart = np.load(os.path.join(path, "chart.npy"))
        return cls(
            max_mass=max_mass,
            mass_scale=mass_scale,
            max_isotope=max_isotope,
            residues=residues,
            residue_indices=residue_indices,
            masses=masses,
            chart=chart,
        )

    def get_feasible_masses(self, target_mass: float, tolerance: float) -> list[int]:
        """Find the feasible masses around a target mass using binary search.

        Args:
            target_mass (float):
                The mass to be decoded in Daltons.

            tolerance (float):
                The mass tolerance in Daltons.

        Returns:
            list[int]:
                The scaled feasible masses lying strictly between
                the scaled `target_mass - tolerance` and the scaled
                `target_mass + tolerance`, in ascending order.
        """
        scaled_min_mass = round(self.mass_scale * (target_mass - tolerance))
        scaled_max_mass = round(self.mass_scale * (target_mass + tolerance))

        # bisect_right on the lower bound and bisect_left on the upper bound
        # exclude masses that are exactly equal to either scaled bound.
        left_endpoint = bisect.bisect_right(self.masses, scaled_min_mass)
        right_endpoint = bisect.bisect_left(self.masses, scaled_max_mass)

        feasible_masses: list[int] = self.masses[left_endpoint:right_endpoint].tolist()
        return feasible_masses