Coverage for instanovo/diffusion/data.py: 88%
51 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
1from __future__ import annotations
3from typing import Any, Dict
5import torch
7from instanovo.__init__ import console
8from instanovo.constants import ANNOTATED_COLUMN, REFINEMENT_COLUMN, MSColumns
9from instanovo.transformer.data import TransformerDataProcessor
10from instanovo.utils.colorlogging import ColorLog
11from instanovo.utils.residues import ResidueSet
# Module-level logger routed through the shared rich console (see instanovo.__init__).
logger = ColorLog(console, __name__).logger
class DiffusionDataProcessor(TransformerDataProcessor):
    """Diffusion implementation of the DataProcessor class.

    Includes methods to process spectra and peptides for diffusion de novo peptide sequencing.
    """

    def __init__(
        self,
        residue_set: ResidueSet,
        n_peaks: int = 200,
        min_mz: float = 50.0,
        max_mz: float = 2500.0,
        min_intensity: float = 0.01,
        remove_precursor_tol: float = 2.0,
        reverse_peptide: bool = True,
        annotated: bool = True,
        return_str: bool = False,
        add_eos: bool = True,
        use_spectrum_utils: bool = True,
        peptide_pad_length: int = 40,
        peptide_pad_value: int = 0,
        truncate_max_length: int | None = 40,
        metadata_columns: list[str] | None = None,
    ) -> None:
        """Initialize the data processor.

        Args:
            residue_set (ResidueSet): The residue set to use.
            n_peaks (int): The number of peaks to keep in the spectrum.
            min_mz (float): The minimum m/z to keep in the spectrum.
            max_mz (float): The maximum m/z to keep in the spectrum.
            min_intensity (float): The minimum intensity to keep in the spectrum.
            remove_precursor_tol (float): The tolerance to remove the precursor peak in Da.
            reverse_peptide (bool): Whether to reverse the peptide.
            annotated (bool): Whether the dataset is annotated.
            return_str (bool): Whether to return the peptide as a string.
            add_eos (bool): Whether to add the end of sequence token.
            use_spectrum_utils (bool): Whether to use the spectrum_utils library to process the spectra.
            peptide_pad_length (int): The fixed length to pad the peptide to.
            peptide_pad_value (int): The value to pad the peptide with.
            truncate_max_length (int | None): The maximum length to truncate the peptide to.
            metadata_columns (list[str] | None): The metadata columns to add to the dataset.
        """
        # Diffusion-specific padding/truncation settings; set before super().__init__
        # so they are available if the base class ever calls overridden hooks.
        self.peptide_pad_length = peptide_pad_length
        self.peptide_pad_value = peptide_pad_value
        self.truncate_max_length = truncate_max_length
        super().__init__(
            residue_set=residue_set,
            n_peaks=n_peaks,
            min_mz=min_mz,
            max_mz=max_mz,
            min_intensity=min_intensity,
            remove_precursor_tol=remove_precursor_tol,
            reverse_peptide=reverse_peptide,
            annotated=annotated,
            return_str=return_str,
            add_eos=add_eos,
            use_spectrum_utils=use_spectrum_utils,
            metadata_columns=metadata_columns,
        )

    def _encode_and_pad(self, sequence: str) -> torch.Tensor:
        """Tokenize, encode, optionally truncate, and pad a sequence to a fixed length.

        Shared pipeline for both the annotated peptide and the refinement column.

        Args:
            sequence (str): The peptide (or refinement) sequence to encode.

        Returns:
            torch.Tensor: A 1-D tensor of length
                ``max(peptide_pad_length, encoded length)`` padded with
                ``peptide_pad_value``.
        """
        tokens = self.residue_set.tokenize(sequence)
        if self.reverse_peptide:
            tokens = tokens[::-1]

        encoding = self.residue_set.encode(tokens, add_eos=self.add_eos, return_tensor="pt")

        if self.truncate_max_length:
            encoding = encoding[: self.truncate_max_length]

        # Diffusion always pads to a fixed length; the pad length only grows if the
        # (possibly truncated) encoding is still longer than peptide_pad_length.
        padded = torch.full(
            (max(self.peptide_pad_length, encoding.shape[0]),),
            fill_value=self.peptide_pad_value,
            dtype=encoding.dtype,
            device=encoding.device,
        )
        padded[: encoding.shape[0]] = encoding
        return padded

    def process_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Process a single row of data for diffusion de novo peptide sequencing.

        Args:
            row (dict[str, Any]): The row of data to process in dict format.

        Returns:
            dict[str, Any]: The processed row with resulting columns.
        """
        processed: Dict[str, Any] = {}

        # Spectra processing
        spectra = self._process_spectrum(
            torch.tensor(row[MSColumns.MZ_ARRAY.value]),
            torch.tensor(row[MSColumns.INTENSITY_ARRAY.value]),
            row[MSColumns.PRECURSOR_MZ.value],
            row[MSColumns.PRECURSOR_CHARGE.value],
        )
        processed["spectra"] = spectra

        # Peptide processing: either keep the raw string or encode + pad.
        if self.annotated:
            peptide = row[ANNOTATED_COLUMN]
            if self.return_str:
                processed["peptide"] = peptide
            else:
                processed["peptide"] = self._encode_and_pad(peptide)

        # Optional refinement column gets the identical encode/pad treatment.
        if REFINEMENT_COLUMN in row:
            processed[REFINEMENT_COLUMN] = self._encode_and_pad(row[REFINEMENT_COLUMN])

        return processed

    def _collate_batch(self, batch: list[dict[str, Any]]) -> dict[str, Any]:
        """Logic for collating a batch.

        Args:
            batch (list[dict[str, Any]]): The batch to collate.

        Returns:
            dict[str, Any]: The collated batch, with the refinement column padded
                and appended when present.
        """
        return_batch = super()._collate_batch(batch)

        # No refinement data in this batch: the base collation is sufficient.
        if REFINEMENT_COLUMN not in batch[0]:
            return return_batch  # type: ignore

        refinement_peptide = [row[REFINEMENT_COLUMN] for row in batch]
        refinement_peptide, _ = self._pad_and_mask(refinement_peptide)

        return {
            **return_batch,
            REFINEMENT_COLUMN: refinement_peptide,
        }