Coverage for instanovo/transformer/data.py: 88%
111 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
1from __future__ import annotations
3from typing import Any, Dict
5import numpy as np
6import spectrum_utils.spectrum as sus
7import torch
8from jaxtyping import Float
9from torch import Tensor
11from instanovo.__init__ import console
12from instanovo.common import DataProcessor
13from instanovo.constants import ANNOTATED_COLUMN, PROTON_MASS_AMU, MSColumns
14from instanovo.utils.colorlogging import ColorLog
15from instanovo.utils.residues import ResidueSet
17logger = ColorLog(console, __name__).logger
class TransformerDataProcessor(DataProcessor):
    """Transformer implementation of the DataProcessor class.

    Includes methods to process spectra and peptides for auto-regressive
    de novo peptide sequencing.
    """

    def __init__(
        self,
        residue_set: ResidueSet,
        n_peaks: int = 200,
        min_mz: float = 50.0,
        max_mz: float = 2500.0,
        min_intensity: float = 0.01,
        remove_precursor_tol: float = 2.0,
        reverse_peptide: bool = True,
        annotated: bool = True,
        return_str: bool = False,
        add_eos: bool = True,
        use_spectrum_utils: bool = True,
        metadata_columns: list[str] | None = None,
    ) -> None:
        """Initialize the data processor.

        Args:
            residue_set (ResidueSet): The residue set to use.
            n_peaks (int): The number of peaks to keep in the spectrum.
            min_mz (float): The minimum m/z to keep in the spectrum.
            max_mz (float): The maximum m/z to keep in the spectrum.
            min_intensity (float): The minimum intensity to keep in the spectrum.
            remove_precursor_tol (float): The tolerance to remove the precursor peak in Da.
            reverse_peptide (bool): Whether to reverse the peptide.
            annotated (bool): Whether the dataset is annotated.
            return_str (bool): Whether to return the peptide as a string.
            add_eos (bool): Whether to add the end of sequence token.
            use_spectrum_utils (bool): Whether to use the spectrum_utils library to process the spectra.
            metadata_columns (list[str] | None): The metadata columns to add to the dataset.
        """
        super().__init__(metadata_columns=metadata_columns)
        self.residue_set = residue_set
        self.n_peaks = n_peaks
        self.min_mz = min_mz
        self.max_mz = max_mz
        self.min_intensity = min_intensity
        self.remove_precursor_tol = remove_precursor_tol
        self.reverse_peptide = reverse_peptide
        self.annotated = annotated
        self.return_str = return_str
        self.add_eos = add_eos
        self.use_spectrum_utils = use_spectrum_utils

    @staticmethod
    def _dummy_spectrum() -> torch.Tensor:
        """Return the single-peak placeholder used in place of an invalid spectrum."""
        return torch.tensor([[0.0, 1.0]], dtype=torch.float32)

    def _process_spectrum(
        self,
        mz_array: Float[Tensor, " peak"],
        int_array: Float[Tensor, " peak"],
        precursor_mz: float,
        precursor_charge: int,
    ) -> torch.Tensor:
        """Process a single spectrum.

        Dispatches to the spectrum_utils implementation when
        `use_spectrum_utils` is set, otherwise to the pure-torch fallback.

        Args:
            mz_array (Float[Tensor, " peak"]): The m/z array of the spectrum.
            int_array (Float[Tensor, " peak"]): The intensity array of the spectrum.
            precursor_mz (float): The precursor m/z.
            precursor_charge (int): The precursor charge.

        Returns:
            torch.Tensor: The processed spectrum as a float tensor with one row
                per peak and two columns (m/z, normalised intensity). Spectra
                that end up without usable peaks are replaced by the dummy
                spectrum `[[0.0, 1.0]]`.
        """
        if self.use_spectrum_utils:
            return self._process_with_spectrum_utils(mz_array, int_array, precursor_mz, precursor_charge)
        return self._process_with_torch(mz_array, int_array, precursor_mz)

    def _process_with_spectrum_utils(
        self,
        mz_array: Float[Tensor, " peak"],
        int_array: Float[Tensor, " peak"],
        precursor_mz: float,
        precursor_charge: int,
    ) -> torch.Tensor:
        """Filter, root-scale and L2-normalise a spectrum using spectrum_utils."""
        spectrum = sus.MsmsSpectrum(
            "",
            precursor_mz,
            precursor_charge,
            np.asarray(mz_array).astype(np.float32),
            np.asarray(int_array).astype(np.float32),
        )
        try:
            spectrum.set_mz_range(self.min_mz, self.max_mz)
            if len(spectrum.mz) == 0:
                raise ValueError
            spectrum.remove_precursor_peak(self.remove_precursor_tol, "Da")
            if len(spectrum.mz) == 0:
                raise ValueError
            spectrum.filter_intensity(self.min_intensity, self.n_peaks)
            if len(spectrum.mz) == 0:
                raise ValueError
            spectrum.scale_intensity("root", 1)
            norm = np.linalg.norm(spectrum.intensity)
            # A zero or non-finite norm would fill the intensity column with NaN/inf.
            if not np.isfinite(norm) or norm == 0:
                raise ValueError
            intensities = spectrum.intensity / norm
            return torch.tensor(np.asarray([spectrum.mz, intensities])).T.float()
        except ValueError:
            # Replace invalid spectra by a dummy spectrum.
            return self._dummy_spectrum()

    def _process_with_torch(
        self,
        mz_array: Float[Tensor, " peak"],
        int_array: Float[Tensor, " peak"],
        precursor_mz: float,
    ) -> torch.Tensor:
        """Filter, root-scale and L2-normalise a spectrum using torch only.

        Fallback used when `use_spectrum_utils` is disabled. Any step that
        leaves no peaks yields the dummy spectrum.
        """
        # 1. Set m/z range
        in_range = (mz_array >= self.min_mz) & (mz_array <= self.max_mz)
        mz_array, int_array = mz_array[in_range], int_array[in_range]
        if len(mz_array) == 0:
            return self._dummy_spectrum()

        # 2. Remove precursor peak
        away_from_precursor = torch.abs(mz_array - precursor_mz) > self.remove_precursor_tol
        mz_array, int_array = mz_array[away_from_precursor], int_array[away_from_precursor]
        if len(mz_array) == 0:
            return self._dummy_spectrum()

        # 3. Filter by intensity and keep top n_peaks
        # NOTE(review): this compares raw intensities against `min_intensity`, while
        # spectrum_utils documents its threshold as a fraction of the base peak
        # intensity - confirm which behaviour is intended for the fallback.
        intense_enough = int_array >= self.min_intensity
        mz_array, int_array = mz_array[intense_enough], int_array[intense_enough]
        if len(mz_array) == 0:
            return self._dummy_spectrum()

        if len(mz_array) > self.n_peaks:
            _, top_idx = torch.topk(int_array, self.n_peaks)
            # topk orders indices by intensity; re-sort them so the kept peaks retain
            # their input order instead of being shuffled only when truncation happens.
            top_idx, _ = torch.sort(top_idx)
            mz_array, int_array = mz_array[top_idx], int_array[top_idx]

        # 4. Scale intensity (root scaling)
        int_array = torch.sqrt(int_array)

        # 5. Normalize intensities
        norm = torch.linalg.norm(int_array)
        # A zero or non-finite norm would fill the intensity column with NaN/inf.
        if not torch.isfinite(norm) or norm == 0:
            return self._dummy_spectrum()
        int_array = int_array / norm

        return torch.stack([mz_array, int_array], dim=1).float()

    def process_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Process a single row of data for auto-regressive de novo peptide sequencing.

        Args:
            row (dict[str, Any]): The row of data to process in dict format.

        Returns:
            dict[str, Any]: The processed row with resulting columns: "spectra",
                plus "peptide" when the processor is annotated.

        Raises:
            KeyError: If the processor is annotated and the annotation column
                is missing from the row.
        """
        processed: dict[str, Any] = {}

        # Spectra processing
        processed["spectra"] = self._process_spectrum(
            torch.tensor(row[MSColumns.MZ_ARRAY.value]),
            torch.tensor(row[MSColumns.INTENSITY_ARRAY.value]),
            row[MSColumns.PRECURSOR_MZ.value],
            row[MSColumns.PRECURSOR_CHARGE.value],
        )

        # Peptide processing
        if self.annotated:
            if ANNOTATED_COLUMN not in row:
                raise KeyError(f"Annotated column {ANNOTATED_COLUMN} not found in dataset.")
            peptide = row[ANNOTATED_COLUMN]
            if self.return_str:
                processed["peptide"] = peptide
            else:
                tokens = self.residue_set.tokenize(peptide)
                if self.reverse_peptide:
                    tokens = tokens[::-1]
                processed["peptide"] = self.residue_set.encode(tokens, add_eos=self.add_eos, return_tensor="pt")

        return processed

    def _get_expected_columns(self) -> list[str]:
        """Get the expected columns.

        These are the row keys that `_collate_batch` reads: the processed
        spectrum, the precursor m/z and charge, and the peptide when the
        processor is annotated.

        Returns:
            list[str]: The expected columns.
        """
        expected_columns = ["spectra", "precursor_mz", "precursor_charge"]
        if self.annotated:
            expected_columns.append("peptide")
        return expected_columns

    def _collate_batch(self, batch: list[dict[str, Any]]) -> dict[str, torch.Tensor | Any]:
        """Logic for collating a batch.

        Args:
            batch (list[dict[str, Any]]): The batch to collate.

        Returns:
            dict[str, Any]: The collated batch with keys "spectra",
                "spectra_mask" and "precursors" (columns: mass, charge, m/z),
                plus "peptides" and "peptides_mask" when annotated.
                "peptides_mask" is None when peptides are kept as strings.

        Raises:
            ValueError: If the batch is empty.
        """
        if not batch:
            raise ValueError("Cannot collate an empty batch.")

        # Pad spectra
        spectra, spectra_mask = self._pad_and_mask([row["spectra"] for row in batch])

        precursor_mzs = torch.tensor([row["precursor_mz"] for row in batch])
        precursor_charges = torch.tensor([row["precursor_charge"] for row in batch])
        precursor_masses = (precursor_mzs - PROTON_MASS_AMU) * precursor_charges
        precursors = torch.vstack([precursor_masses, precursor_charges, precursor_mzs]).T.float()

        # Force all input data to be contiguous
        return_batch: dict[str, torch.Tensor | Any] = {
            "spectra": spectra.contiguous(),
            "precursors": precursors.contiguous(),
            "spectra_mask": spectra_mask,
        }

        # Add peptide if annotated
        if self.annotated:
            peptides_batch = [row["peptide"] for row in batch]

            if isinstance(peptides_batch[0], str):
                # String peptides are passed through untouched and carry no mask.
                peptides, peptides_mask = peptides_batch, None
            else:
                # Pad peptide
                peptides, peptides_mask = self._pad_and_mask(peptides_batch)
                peptides = peptides.contiguous()

            return_batch["peptides"] = peptides
            return_batch["peptides_mask"] = peptides_mask

        return return_batch