Coverage for instanovo/transformer/data.py: 88%

111 statements  

coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

from __future__ import annotations

from typing import Any, Dict

import numpy as np
import spectrum_utils.spectrum as sus
import torch
from jaxtyping import Float
from torch import Tensor

from instanovo.__init__ import console
from instanovo.common import DataProcessor
from instanovo.constants import ANNOTATED_COLUMN, PROTON_MASS_AMU, MSColumns
from instanovo.utils.colorlogging import ColorLog
from instanovo.utils.residues import ResidueSet

logger = ColorLog(console, __name__).logger


class TransformerDataProcessor(DataProcessor):
    """Transformer implementation of the DataProcessor class.

    Includes methods to process spectra and peptides for auto-regressive
    de novo peptide sequencing.
    """

    def __init__(
        self,
        residue_set: ResidueSet,
        n_peaks: int = 200,
        min_mz: float = 50.0,
        max_mz: float = 2500.0,
        min_intensity: float = 0.01,
        remove_precursor_tol: float = 2.0,
        reverse_peptide: bool = True,
        annotated: bool = True,
        return_str: bool = False,
        add_eos: bool = True,
        use_spectrum_utils: bool = True,
        metadata_columns: list[str] | None = None,
    ) -> None:
        """Initialize the data processor.

        Args:
            residue_set (ResidueSet): The residue set to use.
            n_peaks (int): The number of peaks to keep in the spectrum.
            min_mz (float): The minimum m/z to keep in the spectrum.
            max_mz (float): The maximum m/z to keep in the spectrum.
            min_intensity (float): The minimum intensity to keep in the spectrum.
            remove_precursor_tol (float): The tolerance to remove the precursor peak in Da.
            reverse_peptide (bool): Whether to reverse the peptide.
            annotated (bool): Whether the dataset is annotated.
            return_str (bool): Whether to return the peptide as a string.
            add_eos (bool): Whether to add the end-of-sequence token.
            use_spectrum_utils (bool): Whether to use the spectrum_utils library to process the spectra.
            metadata_columns (list[str] | None): The metadata columns to add to the dataset.
        """
        super().__init__(metadata_columns=metadata_columns)
        self.residue_set = residue_set
        self.n_peaks = n_peaks
        self.min_mz = min_mz
        self.max_mz = max_mz
        self.min_intensity = min_intensity
        self.remove_precursor_tol = remove_precursor_tol
        self.reverse_peptide = reverse_peptide
        self.annotated = annotated
        self.return_str = return_str
        self.add_eos = add_eos
        self.use_spectrum_utils = use_spectrum_utils
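
    # Usage sketch (illustrative, not part of the original module): given a
    # `ResidueSet` built elsewhere, the processor is configured entirely through
    # these keyword arguments, e.g.
    #   processor = TransformerDataProcessor(residue_set, n_peaks=150, use_spectrum_utils=False)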

    def _process_spectrum(
        self,
        mz_array: Float[Tensor, " peak"],
        int_array: Float[Tensor, " peak"],
        precursor_mz: float,
        precursor_charge: int,
    ) -> torch.Tensor:
        """Process a single spectrum.

        Args:
            mz_array (Float[Tensor, " peak"]): The m/z array of the spectrum.
            int_array (Float[Tensor, " peak"]): The intensity array of the spectrum.
            precursor_mz (float): The precursor m/z.
            precursor_charge (int): The precursor charge.

        Returns:
            torch.Tensor: The processed spectrum.
        """
        if self.use_spectrum_utils:
            spectrum = sus.MsmsSpectrum(
                "",
                precursor_mz,
                precursor_charge,
                np.asarray(mz_array).astype(np.float32),
                np.asarray(int_array).astype(np.float32),
            )
            try:
                spectrum.set_mz_range(self.min_mz, self.max_mz)
                if len(spectrum.mz) == 0:
                    raise ValueError
                spectrum.remove_precursor_peak(self.remove_precursor_tol, "Da")
                if len(spectrum.mz) == 0:
                    raise ValueError
                spectrum.filter_intensity(self.min_intensity, self.n_peaks)
                if len(spectrum.mz) == 0:
                    raise ValueError
                spectrum.scale_intensity("root", 1)
                intensities = spectrum.intensity / np.linalg.norm(spectrum.intensity)
                return torch.tensor(np.asarray([spectrum.mz, intensities])).T.float()
            except ValueError:
                # Replace invalid spectra by a dummy spectrum.
                return torch.tensor([[0, 1]]).float()

        # Fallback implementation that matches spectrum_utils functionality
        try:
            # 1. Set m/z range
            mask = (mz_array >= self.min_mz) & (mz_array <= self.max_mz)
            mz_array = mz_array[mask]
            int_array = int_array[mask]

            if len(mz_array) == 0:
                raise ValueError

            # 2. Remove precursor peak
            precursor_mask = torch.abs(mz_array - precursor_mz) > self.remove_precursor_tol
            mz_array = mz_array[precursor_mask]
            int_array = int_array[precursor_mask]

            if len(mz_array) == 0:
                raise ValueError

            # 3. Filter by intensity and keep top n_peaks
            intensity_mask = int_array >= self.min_intensity
            mz_array = mz_array[intensity_mask]
            int_array = int_array[intensity_mask]

            if len(mz_array) == 0:
                raise ValueError

            # Get top n_peaks by intensity
            if len(mz_array) > self.n_peaks:
                _, indices = torch.topk(int_array, self.n_peaks)
                mz_array = mz_array[indices]
                int_array = int_array[indices]

            # 4. Scale intensity (root scaling)
            int_array = torch.sqrt(int_array)

            # 5. Normalize intensities
            int_array = int_array / torch.linalg.norm(int_array)

            return torch.stack([mz_array, int_array], dim=1).float()

        except ValueError:
            # Replace invalid spectra by a dummy spectrum
            return torch.tensor([[0.0, 1.0]], dtype=torch.float32)
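
    # Worked example (illustrative): for raw intensities [4.0, 1.0], root scaling gives
    # [2.0, 1.0]; dividing by the L2 norm sqrt(5) ~= 2.236 yields ~[0.894, 0.447]. Both
    # branches therefore return rows of (m/z, intensity) with the intensity column
    # normalised to unit L2 norm, or the dummy spectrum [[0, 1]] if every peak is filtered out.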

    def process_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Process a single row of data for auto-regressive de novo peptide sequencing.

        Args:
            row (dict[str, Any]): The row of data to process in dict format.

        Returns:
            dict[str, Any]: The processed row with resulting columns.
        """
        processed = {}

        # Spectra processing
        spectra = self._process_spectrum(
            torch.tensor(row[MSColumns.MZ_ARRAY.value]),
            torch.tensor(row[MSColumns.INTENSITY_ARRAY.value]),
            row[MSColumns.PRECURSOR_MZ.value],
            row[MSColumns.PRECURSOR_CHARGE.value],
        )

        processed["spectra"] = spectra

        # Peptide processing
        if self.annotated:
            if ANNOTATED_COLUMN not in row:
                raise KeyError(f"Annotated column {ANNOTATED_COLUMN} not found in dataset.")
            peptide = row[ANNOTATED_COLUMN]
            if not self.return_str:
                peptide_tokenized = self.residue_set.tokenize(peptide)

                if self.reverse_peptide:
                    peptide_tokenized = peptide_tokenized[::-1]

                peptide_encoding = self.residue_set.encode(peptide_tokenized, add_eos=self.add_eos, return_tensor="pt")

                processed["peptide"] = peptide_encoding
            else:
                processed["peptide"] = peptide

        return processed
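
    # Input/output sketch (illustrative values): `row` is expected to carry the raw peak
    # arrays and precursor metadata under the MSColumns keys, e.g.
    #   {MSColumns.MZ_ARRAY.value: [...], MSColumns.INTENSITY_ARRAY.value: [...],
    #    MSColumns.PRECURSOR_MZ.value: 450.7, MSColumns.PRECURSOR_CHARGE.value: 2, ...}
    # The returned dict holds "spectra" (a (peaks, 2) tensor) plus, when `annotated` is
    # True, "peptide" (an encoded tensor, or the raw string if `return_str` is True).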

    def _get_expected_columns(self) -> list[str]:
        """Get the expected columns.

        These are the columns that will be returned by the `process_row` method.

        Returns:
            list[str]: The expected columns.
        """
        expected_columns = ["spectra", "precursor_mz", "precursor_charge"]
        if self.annotated:
            expected_columns.append("peptide")
        return expected_columns
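
    # Note: `process_row` itself only adds "spectra" (and optionally "peptide"), so
    # "precursor_mz" and "precursor_charge" are presumably preserved from the source
    # data by the base DataProcessor, since `_collate_batch` reads them from each row.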

    def _collate_batch(self, batch: list[dict[str, Any]]) -> dict[str, torch.Tensor | Any]:
        """Logic for collating a batch.

        Args:
            batch (list[dict[str, Any]]): The batch to collate.

        Returns:
            dict[str, Any]: The collated batch.
        """
        data_batch = [
            (
                row["spectra"],
                row["precursor_mz"],
                row["precursor_charge"],
            )
            for row in batch
        ]

        spectra, precursor_mzs, precursor_charges = zip(*data_batch, strict=True)

        # Pad spectra
        spectra, spectra_mask = DataProcessor._pad_and_mask(spectra)

        precursor_mzs = torch.tensor(precursor_mzs)
        precursor_charges = torch.tensor(precursor_charges)
        precursor_masses = (precursor_mzs - PROTON_MASS_AMU) * precursor_charges
        precursors = torch.vstack([precursor_masses, precursor_charges, precursor_mzs]).T.float()

        # Force all input data to be contiguous
        precursors = precursors.contiguous()
        spectra = spectra.contiguous()

        return_batch = {
            "spectra": spectra,
            "precursors": precursors,
            "spectra_mask": spectra_mask,
        }

        # Add peptide if annotated
        if self.annotated:
            peptides_batch = [row["peptide"] for row in batch]

            # Pad peptide
            if not isinstance(peptides_batch[0], str):
                peptides, peptides_mask = self._pad_and_mask(peptides_batch)
                peptides = peptides.contiguous()
            else:
                peptides = peptides_batch
                peptides_mask = None

            return_batch.update(
                {
                    "peptides": peptides,
                    "peptides_mask": peptides_mask,
                }
            )

        return return_batch
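
# Collated batch sketch (illustrative shapes, assuming the base class's `_pad_and_mask`
# returns a padded tensor plus a matching boolean mask):
#   "spectra":       float tensor of shape (batch, max_peaks, 2)  # (m/z, intensity) pairs
#   "spectra_mask":  mask of shape (batch, max_peaks)
#   "precursors":    float tensor of shape (batch, 3)  # [neutral mass, charge, m/z]
#   "peptides":      padded token tensor of shape (batch, max_len), or a list of strings
#   "peptides_mask": mask of shape (batch, max_len), or None when peptides are strings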