Coverage for instanovo/common/dataset.py: 88% (52 statements)
from __future__ import annotations

import re
from abc import ABCMeta, abstractmethod
from typing import Any

import torch
from datasets import Dataset
from torch import nn

class DataProcessor(metaclass=ABCMeta):
    """Data processor abstract class.

    This class is used to process the data before it is used in the model.
    It is designed to be used with the `Dataset` class from the HuggingFace `datasets` library.

    It includes two main methods:
    - `process_row`: Processes a single row of data.
    - `collate_fn`: Collates a batch of data. To be passed to the `DataLoader` class.

    Additionally, it includes a way to pass metadata columns that will be kept
    after processing a dataset. These metadata columns bypass the collation
    logic in `_collate_batch` and are appended to the batch as plain lists.
    A minimal example subclass is sketched at the end of this module.
    """
    @property
    def metadata_columns(self) -> set[str]:
        """Get the metadata columns.

        These columns are kept after processing a dataset.

        Returns:
            set[str]: The metadata columns.
        """
        return self._metadata_columns
    def __init__(self, metadata_columns: list[str] | set[str] | None = None):
        """Initialize the data processor.

        Args:
            metadata_columns: The metadata columns to add to the expected columns.
        """
        self._metadata_columns: set[str] = set(metadata_columns or [])
    @abstractmethod
    def _get_expected_columns(self) -> list[str]:
        """Get the expected columns.

        These are the columns that will be returned by the `process_row` method.

        Returns:
            list[str]: The expected columns.
        """
        ...
    @abstractmethod
    def process_row(self, row: dict[str, Any]) -> dict[str, Any]:
        """Process a single row of data.

        Args:
            row (dict[str, Any]): The row of data to process, in dict format.

        Returns:
            dict[str, Any]: The processed row with the resulting columns.
        """
        ...
    @abstractmethod
    def _collate_batch(self, batch: list[dict[str, Any]]) -> dict[str, Any]:
        """Logic for collating a batch.

        Args:
            batch (list[dict[str, Any]]): The batch to collate.

        Returns:
            dict[str, Any]: The collated batch.
        """
        ...
    def process_dataset(self, dataset: Dataset, return_format: str | None = "torch") -> Dataset:
        """Process a dataset by mapping the `process_row` method over it.

        The resulting dataset has the columns expected by the `collate_fn` method.

        Args:
            dataset (Dataset): The dataset to process.
            return_format (str | None): The format to return the dataset in.
                Default is "torch".

        Returns:
            Dataset: The processed dataset.
        """
        dataset = dataset.map(self.process_row)
        dataset.set_format(type=return_format, columns=self.get_expected_columns())
        return dataset
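
    # Typical wiring (illustrative only; `processor` and `processed` are
    # hypothetical names, not part of this module): pass the processed
    # dataset and `collate_fn` to a `torch.utils.data.DataLoader`:
    #   processed = processor.process_dataset(dataset)
    #   loader = torch.utils.data.DataLoader(
    #       processed, batch_size=32, collate_fn=processor.collate_fn
    #   )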
    def collate_fn(self, batch: list[dict[str, Any]]) -> dict[str, Any]:
        """Collate a batch.

        Metadata columns are added after collation.

        Args:
            batch (list[dict[str, Any]]): The batch to collate.

        Returns:
            dict[str, Any]: The collated batch with metadata.
        """
        return_batch: dict[str, Any] = self._collate_batch(batch)

        # Add metadata columns that `_collate_batch` did not produce,
        # defaulting to None for rows where a column is missing.
        metadata = {}
        for col in self.metadata_columns:
            if col in return_batch:
                continue
            metadata[col] = [row.get(col) for row in batch]

        return_batch.update(metadata)

        return return_batch
    def get_expected_columns(self) -> list[str]:
        """Get the expected columns to be kept in the dataset after processing.

        These columns are expected by the `collate_fn` method and include
        both data and metadata columns.

        Returns:
            list[str]: The expected columns.
        """
        return self._get_expected_columns() + list(self.metadata_columns)
    def add_metadata_columns(self, columns: list[str] | set[str]) -> None:
        """Add expected metadata columns.

        Args:
            columns (list[str] | set[str]): The columns to add.
        """
        self._metadata_columns.update(set(columns))
    @staticmethod
    def _pad_and_mask(
        tensor_list: list[torch.Tensor] | tuple[torch.Tensor, ...],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pad a list of variable-length tensors and build a padding mask.

        Args:
            tensor_list (list[torch.Tensor] | tuple[torch.Tensor, ...]): The tensors to pad and mask.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The padded tensor and a boolean
                mask that is True at padded positions.
        """
        lengths = torch.tensor([y.shape[0] for y in tensor_list], dtype=torch.long)
        padded_tensor = nn.utils.rnn.pad_sequence(tensor_list, batch_first=True)
        # True wherever a position lies beyond a sequence's original length.
        attention_mask = torch.arange(padded_tensor.shape[1], dtype=torch.long)[None, :] >= lengths[:, None]
        return padded_tensor, attention_mask
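
    # Worked illustration (hypothetical values): padding sequences of
    # lengths 2 and 3 yields a (2, 3) tensor, with the mask True only at
    # the padded position of the shorter sequence:
    #   padded, mask = DataProcessor._pad_and_mask([torch.ones(2), torch.ones(3)])
    #   # mask == tensor([[False, False,  True],
    #   #                 [False, False, False]])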
    @staticmethod
    def remove_modifications(peptide: str, replace_isoleucine_with_leucine: bool = True) -> str:
        """Remove modifications and optionally replace isoleucine with leucine.

        Args:
            peptide (str): The peptide to remove modifications from.
            replace_isoleucine_with_leucine (bool): Whether to replace isoleucine (I) with leucine (L).

        Returns:
            str: The peptide with modifications removed.
        """
        # Strip the "UNIMOD" token first; its uppercase letters would
        # otherwise be picked up by the regex below.
        peptide = peptide.replace("UNIMOD", "")
        # Keep only the amino-acid letters A-Z, dropping modification
        # syntax such as brackets and mass offsets.
        residues = re.findall(r"[A-Z]", peptide)
        # Replace isoleucine with leucine (the two are isobaric).
        if replace_isoleucine_with_leucine:
            residues = ["L" if aa == "I" else aa for aa in residues]
        return "".join(residues)
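

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). `_ToyProcessor`, its column names, and
# the toy data below are assumptions made for this example; they are not
# part of the module's API.
class _ToyProcessor(DataProcessor):
    """Minimal processor that encodes a peptide string as integer token ids."""

    def _get_expected_columns(self) -> list[str]:
        return ["tokens"]

    def process_row(self, row: dict[str, Any]) -> dict[str, Any]:
        # Hypothetical encoding: one integer per amino-acid letter.
        clean = self.remove_modifications(row["peptide"])
        row["tokens"] = [ord(aa) - ord("A") for aa in clean]
        return row

    def _collate_batch(self, batch: list[dict[str, Any]]) -> dict[str, Any]:
        tokens, mask = self._pad_and_mask(
            [torch.as_tensor(row["tokens"]) for row in batch]
        )
        return {"tokens": tokens, "mask": mask}


if __name__ == "__main__":
    processor = _ToyProcessor(metadata_columns=["scan_id"])
    ds = Dataset.from_dict(
        {"peptide": ["ACDEI", "MK[UNIMOD:35]L"], "scan_id": [1, 2]}
    )
    ds = processor.process_dataset(ds)
    loader = torch.utils.data.DataLoader(ds, batch_size=2, collate_fn=processor.collate_fn)
    batch = next(iter(loader))
    # `tokens` is padded to the longest sequence; `scan_id` bypasses collation.
    print(batch["tokens"].shape, batch["mask"], batch["scan_id"])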