Coverage for instanovo/common/dataset.py: 88%

52 statements  

coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

from __future__ import annotations

import re
from abc import ABCMeta, abstractmethod
from typing import Any

import torch
from datasets import Dataset
from torch import nn


class DataProcessor(metaclass=ABCMeta):
    """Data processor abstract class.

    This class is used to process the data before it is used in the model.
    It is designed to be used with the `Dataset` class from the HuggingFace `datasets` library.

    It includes two main methods:
    - `process_row`: Processes a row of data.
    - `collate_fn`: Collates a batch of data. To be passed to the `DataLoader` class.

    Additionally, it includes a way to pass metadata columns that will be kept after
    processing a dataset. These metadata columns also bypass the `collate_fn` logic.
    """

    @property
    def metadata_columns(self) -> set[str]:
        """Get the metadata columns.

        These columns are kept after processing a dataset.

        Returns:
            set[str]: The metadata columns.
        """
        return self._metadata_columns

    def __init__(self, metadata_columns: list[str] | set[str] | None = None):
        """Initialize the data processor.

        Args:
            metadata_columns: The metadata columns to add to the expected columns.
        """
        self._metadata_columns: set[str] = set(metadata_columns or [])

    @abstractmethod
    def _get_expected_columns(self) -> list[str]:
        """Get the expected columns.

        These are the columns that will be returned by the `process_row` method.

        Returns:
            list[str]: The expected columns.
        """
        ...

    @abstractmethod
    def process_row(self, row: dict[str, Any]) -> dict[str, Any]:
        """Process a single row of data.

        Args:
            row (dict[str, Any]): The row of data to process in dict format.

        Returns:
            dict[str, Any]: The processed row with the resulting columns.
        """
        ...

    @abstractmethod
    def _collate_batch(self, batch: list[dict[str, Any]]) -> dict[str, Any]:
        """Logic for collating a batch.

        Args:
            batch (list[dict[str, Any]]): The batch to collate.

        Returns:
            dict[str, Any]: The collated batch.
        """
        ...

    def process_dataset(self, dataset: Dataset, return_format: str | None = "torch") -> Dataset:
        """Process a dataset by mapping the `process_row` method.

        The resulting dataset has the columns expected by the `collate_fn` method.

        Args:
            dataset (Dataset): The dataset to process.
            return_format (str | None): The format to return the dataset in.
                Defaults to "torch".

        Returns:
            Dataset: The processed dataset.
        """
        dataset = dataset.map(self.process_row)
        dataset.set_format(type=return_format, columns=self.get_expected_columns())
        return dataset

    def collate_fn(self, batch: list[dict[str, Any]]) -> dict[str, Any]:
        """Collate a batch.

        Metadata columns are added after collation.

        Args:
            batch (list[dict[str, Any]]): The batch to collate.

        Returns:
            dict[str, Any]: The collated batch with metadata.
        """
        return_batch: dict[str, Any] = self._collate_batch(batch)

        # Add metadata columns as plain Python lists, bypassing collation;
        # rows missing a metadata column contribute None.
        metadata = {}
        for col in self.metadata_columns:
            if col in return_batch:
                continue
            metadata[col] = [row[col] if col in row else None for row in batch]

        return_batch.update(metadata)

        return return_batch
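
    # Illustrative trace (assumed values, not from the source): with
    # metadata_columns = {"scan_id"} and
    # batch = [{"x": t1, "scan_id": "A"}, {"x": t2}],
    # the collated batch gains {"scan_id": ["A", None]} as an
    # uncollated Python list alongside the tensors from `_collate_batch`.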


    def get_expected_columns(self) -> list[str]:
        """Get the expected columns to be kept in the dataset after processing.

        These columns are expected by the `collate_fn` method and include
        both data and metadata columns.

        Returns:
            list[str]: The expected columns.
        """
        return self._get_expected_columns() + list(self.metadata_columns)

    def add_metadata_columns(self, columns: list[str] | set[str]) -> None:
        """Add expected metadata columns.

        Args:
            columns (list[str] | set[str]): The columns to add.
        """
        self._metadata_columns.update(set(columns))

    @staticmethod
    def _pad_and_mask(
        tensor_list: list[torch.Tensor] | tuple[torch.Tensor, ...],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pad and mask a list of tensors.

        Args:
            tensor_list (list[torch.Tensor] | tuple[torch.Tensor, ...]): The tensors to pad and mask.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The padded tensor and the padding mask,
                where `True` marks padded positions.
        """
        lengths = torch.tensor([y.shape[0] for y in tensor_list], dtype=torch.long)
        padded_tensor = nn.utils.rnn.pad_sequence(tensor_list, batch_first=True)
        attention_mask = torch.arange(padded_tensor.shape[1], dtype=torch.long)[None, :] >= lengths[:, None]
        return padded_tensor, attention_mask
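
    # Worked example (illustrative): for two 1-D tensors of lengths [2, 3],
    # `padded_tensor` has shape (2, 3) and the mask is
    # [[False, False, True], [False, False, False]],
    # i.e. `True` flags the padded tail of the shorter sequence.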


    @staticmethod
    def remove_modifications(peptide: str, replace_isoleucine_with_leucine: bool = True) -> str:
        """Remove modifications and optionally replace isoleucine with leucine.

        Args:
            peptide (str): The peptide to remove modifications from.
            replace_isoleucine_with_leucine (bool): Whether to replace isoleucine with leucine.

        Returns:
            str: The peptide with modifications removed.
        """

        # Strip the "UNIMOD" token first so its uppercase letters are not
        # picked up by the residue regex below.
        peptide = peptide.replace("UNIMOD", "")
        # Keep only the uppercase residue letters A-Z, dropping brackets,
        # digits, and any other modification syntax.
        residues = re.findall(r"[A-Z]", peptide)
        # Optionally canonicalise isoleucine to leucine; the two residues have
        # identical mass and are indistinguishable by standard mass spectrometry.
        if replace_isoleucine_with_leucine:
            residues = ["L" if aa == "I" else aa for aa in residues]
        return "".join(residues)
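
Below is a minimal usage sketch, not part of the measured file: a hypothetical `SpectrumProcessor` subclass (its name, the `mz_array`/`scan_id` columns, and `my_dataset` are illustrative assumptions) showing how `process_dataset` and `collate_fn` are intended to plug into a PyTorch `DataLoader`.

from torch.utils.data import DataLoader


class SpectrumProcessor(DataProcessor):
    """Hypothetical concrete processor, for illustration only."""

    def _get_expected_columns(self) -> list[str]:
        return ["mz_array"]

    def process_row(self, row: dict[str, Any]) -> dict[str, Any]:
        # Assumed column: keep the m/z values as a 1-D float tensor.
        row["mz_array"] = torch.tensor(row["mz_array"], dtype=torch.float32)
        return row

    def _collate_batch(self, batch: list[dict[str, Any]]) -> dict[str, Any]:
        spectra, mask = self._pad_and_mask([row["mz_array"] for row in batch])
        return {"spectra": spectra, "spectra_mask": mask}


processor = SpectrumProcessor(metadata_columns=["scan_id"])
dataset = processor.process_dataset(my_dataset)  # `my_dataset`: a HuggingFace Dataset
loader = DataLoader(dataset, batch_size=8, collate_fn=processor.collate_fn)
# Each batch: {"spectra": Tensor, "spectra_mask": Tensor, "scan_id": list}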