Coverage for instanovo/diffusion/data.py: 88%

51 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

1from __future__ import annotations 

2 

3from typing import Any, Dict 

4 

5import torch 

6 

7from instanovo.__init__ import console 

8from instanovo.constants import ANNOTATED_COLUMN, REFINEMENT_COLUMN, MSColumns 

9from instanovo.transformer.data import TransformerDataProcessor 

10from instanovo.utils.colorlogging import ColorLog 

11from instanovo.utils.residues import ResidueSet 

12 

# Module-level logger, obtained from the project's ColorLog wrapper bound to the shared console.
logger = ColorLog(console, __name__).logger

14 

15 

class DiffusionDataProcessor(TransformerDataProcessor):
    """Diffusion implementation of the DataProcessor class.

    Includes methods to process spectra and peptides for diffusion de novo peptide sequencing.
    Unlike the transformer processor, every encoded peptide (and refinement peptide) is
    right-padded with ``peptide_pad_value`` to at least ``peptide_pad_length`` tokens, since
    diffusion always works on fixed-length padded sequences.
    """

    def __init__(
        self,
        residue_set: ResidueSet,
        n_peaks: int = 200,
        min_mz: float = 50.0,
        max_mz: float = 2500.0,
        min_intensity: float = 0.01,
        remove_precursor_tol: float = 2.0,
        reverse_peptide: bool = True,
        annotated: bool = True,
        return_str: bool = False,
        add_eos: bool = True,
        use_spectrum_utils: bool = True,
        peptide_pad_length: int = 40,
        peptide_pad_value: int = 0,
        truncate_max_length: int | None = 40,
        metadata_columns: list[str] | None = None,
    ) -> None:
        """Initialize the data processor.

        Args:
            residue_set (ResidueSet): The residue set to use.
            n_peaks (int): The number of peaks to keep in the spectrum.
            min_mz (float): The minimum m/z to keep in the spectrum.
            max_mz (float): The maximum m/z to keep in the spectrum.
            min_intensity (float): The minimum intensity to keep in the spectrum.
            remove_precursor_tol (float): The tolerance to remove the precursor peak in Da.
            reverse_peptide (bool): Whether to reverse the peptide.
            annotated (bool): Whether the dataset is annotated.
            return_str (bool): Whether to return the peptide as a string.
            add_eos (bool): Whether to add the end of sequence token.
            use_spectrum_utils (bool): Whether to use the spectrum_utils library to process the spectra.
            peptide_pad_length (int): The minimum length every encoded peptide is padded to.
                Longer encodings are kept at their own length (not cut to this value).
            peptide_pad_value (int): The token value used for padding.
            truncate_max_length (int | None): The maximum number of encoded tokens to keep.
                A falsy value (``None`` or ``0``) disables truncation.
            metadata_columns (list[str] | None): The metadata columns to add to the dataset.
        """
        # Set the diffusion-specific attributes before delegating to the parent initializer.
        self.peptide_pad_length = peptide_pad_length
        self.peptide_pad_value = peptide_pad_value
        self.truncate_max_length = truncate_max_length
        super().__init__(
            residue_set=residue_set,
            n_peaks=n_peaks,
            min_mz=min_mz,
            max_mz=max_mz,
            min_intensity=min_intensity,
            remove_precursor_tol=remove_precursor_tol,
            reverse_peptide=reverse_peptide,
            annotated=annotated,
            return_str=return_str,
            add_eos=add_eos,
            use_spectrum_utils=use_spectrum_utils,
            metadata_columns=metadata_columns,
        )

    def _encode_and_pad(self, sequence: str) -> torch.Tensor:
        """Tokenize, encode, truncate and pad one peptide string into a 1-D token tensor.

        Shared by the annotated-peptide and refinement-peptide paths of ``process_row`` so both
        are always encoded identically.

        Args:
            sequence (str): The peptide string to encode.

        Returns:
            torch.Tensor: The encoded tokens, right-padded with ``peptide_pad_value`` to
                ``max(peptide_pad_length, encoded_length)`` elements, with the encoding's dtype
                and device.
        """
        tokens = self.residue_set.tokenize(sequence)
        if self.reverse_peptide:
            tokens = tokens[::-1]

        encoding = self.residue_set.encode(tokens, add_eos=self.add_eos, return_tensor="pt")

        # A falsy limit (None or 0) disables truncation.
        # NOTE(review): truncation happens after ``add_eos``, so an over-long peptide loses its
        # EOS token here — confirm this is intended.
        if self.truncate_max_length:
            encoding = encoding[: self.truncate_max_length]

        # Diffusion always padded to fixed length
        length = encoding.shape[0]
        padded = torch.full(
            (max(self.peptide_pad_length, length),),
            fill_value=self.peptide_pad_value,
            dtype=encoding.dtype,
            device=encoding.device,
        )
        padded[:length] = encoding
        return padded

    def process_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Process a single row of data for diffusion de novo peptide sequencing.

        Args:
            row (dict[str, Any]): The row of data to process in dict format. It must contain the
                m/z array, intensity array, precursor m/z and precursor charge columns. The
                annotated peptide column is required when ``self.annotated`` is set, and the
                refinement column is optional.

        Returns:
            dict[str, Any]: The processed row. It always contains ``"spectra"`` (the result of
                ``_process_spectrum``). It contains ``"peptide"`` when the processor is
                annotated: the padded token tensor, or the raw string when ``return_str`` is set.
                It contains the refinement column as a padded token tensor whenever that column
                is present in ``row``.
        """
        processed: Dict[str, Any] = {}

        # Spectra processing
        processed["spectra"] = self._process_spectrum(
            torch.tensor(row[MSColumns.MZ_ARRAY.value]),
            torch.tensor(row[MSColumns.INTENSITY_ARRAY.value]),
            row[MSColumns.PRECURSOR_MZ.value],
            row[MSColumns.PRECURSOR_CHARGE.value],
        )

        # Peptide processing
        if self.annotated:
            peptide = row[ANNOTATED_COLUMN]
            processed["peptide"] = peptide if self.return_str else self._encode_and_pad(peptide)

        # The refinement peptide is always encoded, independent of ``annotated``/``return_str``.
        if REFINEMENT_COLUMN in row:
            processed[REFINEMENT_COLUMN] = self._encode_and_pad(row[REFINEMENT_COLUMN])

        return processed

    def _collate_batch(self, batch: list[dict[str, Any]]) -> dict[str, Any]:
        """Logic for collating a batch.

        Delegates to the parent collation, then additionally collates the refinement column
        when the batch carries one.

        Args:
            batch (list[dict[str, Any]]): The batch to collate.

        Returns:
            dict[str, Any] | tuple[Any]: The collated batch. When the refinement column is
                present, the result is a dict containing the parent's entries plus the padded
                refinement peptides.
        """
        return_batch = super()._collate_batch(batch)

        # Only the first row is inspected; presumably all rows in a batch share the same columns.
        if REFINEMENT_COLUMN not in batch[0]:
            return return_batch  # type: ignore

        # Each row is already padded to at least ``peptide_pad_length``; ``_pad_and_mask``
        # presumably pads them to a common length for the batch. Its mask is not needed here.
        refinement_peptide, _ = self._pad_and_mask([row[REFINEMENT_COLUMN] for row in batch])

        return {
            **return_batch,
            REFINEMENT_COLUMN: refinement_peptide,
        }