Coverage for instanovo/transformer/train.py: 53% (77 statements)
from __future__ import annotations

import os
import shutil
from pathlib import Path
from typing import Any

import hydra
import torch
import torch.nn as nn
from omegaconf import DictConfig, OmegaConf

from instanovo.__init__ import console
from instanovo.common import AccelerateDeNovoTrainer, DataProcessor
from instanovo.inference import Decoder, GreedyDecoder
from instanovo.transformer.data import TransformerDataProcessor
from instanovo.transformer.model import InstaNovo
from instanovo.utils.colorlogging import ColorLog
from instanovo.utils.s3 import S3FileHandler

logger = ColorLog(console, __name__).logger

CONFIG_PATH = Path(__file__).parent.parent / "configs"

class TransformerTrainer(AccelerateDeNovoTrainer):
    """Trainer for the InstaNovo model."""

    def __init__(self, config: DictConfig) -> None:
        super().__init__(config)

        # Index 0 is the padding token, so padded positions are excluded from the loss.
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=0)

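    # ``encoder_layers``/``decoder_layers`` fall back to a shared ``n_layers`` value
    # (default 9) when not set explicitly; peak embeddings are kept in float32 when the
    # ``mps`` flag is set, since Apple's MPS backend does not support float64.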
    def setup_model(self) -> nn.Module:
        """Set up the model."""
        config = self.config.get("model", {})
        model = InstaNovo(
            residue_set=self.residue_set,
            dim_model=config["dim_model"],
            n_head=config["n_head"],
            dim_feedforward=config["dim_feedforward"],
            encoder_layers=config.get("encoder_layers", config.get("n_layers", 9)),
            decoder_layers=config.get("decoder_layers", config.get("n_layers", 9)),
            dropout=config["dropout"],
            max_charge=config["max_charge"],
            use_flash_attention=config.get("use_flash_attention", False),
            conv_peak_encoder=config.get("conv_peak_encoder", False),
            peak_embedding_dtype="float32" if self.config.get("mps", False) else "float64",
        )
        return model

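    # When resuming or fine-tuning from a checkpoint whose residue vocabulary differs from
    # the current one, the output head and amino-acid embedding weights are remapped to the
    # new vocabulary; conflicting residues are handled according to
    # ``residue_conflict_resolution`` (default: "delete").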
    def update_vocab(self, model_state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Update the vocabulary of the model."""
        return self._update_vocab(  # type: ignore
            model_state,
            target_layers=["head.weight", "head.bias", "aa_embed.weight"],
            resolution=self.config.get("residue_conflict_resolution", "delete"),
        )

    def setup_optimizer(self) -> torch.optim.Optimizer:
        """Set up the optimizer."""
        return torch.optim.Adam(
            self.model.parameters(),
            lr=float(self.config["learning_rate"]),
            weight_decay=float(self.config.get("weight_decay", 0.0)),
        )

    def setup_decoder(self) -> Decoder:
        """Set up the decoder."""
        return GreedyDecoder(
            model=self.model,
            float_dtype=torch.float32 if self.config.get("mps", False) else torch.float64,
        )

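    # Training and validation share identical spectrum preprocessing: keep at most
    # ``n_peaks`` peaks, restrict m/z to [``min_mz``, ``max_mz``], drop low-intensity peaks
    # and peaks within ``remove_precursor_tol`` of the precursor, and return peptides as
    # token indices rather than strings (``return_str=False``).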
    def setup_data_processors(self) -> tuple[DataProcessor, DataProcessor]:
        """Set up the training and validation data processors."""
        train_processor = TransformerDataProcessor(
            self.residue_set,
            n_peaks=self.config.model.get("n_peaks", 200),
            min_mz=self.config.model.get("min_mz", 50.0),
            max_mz=self.config.model.get("max_mz", 2500.0),
            min_intensity=self.config.model.get("min_intensity", 0.01),
            remove_precursor_tol=self.config.model.get("remove_precursor_tol", 2.0),
            return_str=False,
            use_spectrum_utils=False,
        )

        valid_processor = TransformerDataProcessor(
            self.residue_set,
            n_peaks=self.config.model.get("n_peaks", 200),
            min_mz=self.config.model.get("min_mz", 50.0),
            max_mz=self.config.model.get("max_mz", 2500.0),
            min_intensity=self.config.model.get("min_intensity", 0.01),
            remove_precursor_tol=self.config.model.get("remove_precursor_tol", 2.0),
            return_str=False,
            use_spectrum_utils=False,
        )

        return train_processor, valid_processor

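    # Hook for subclasses: any extra entries returned here are merged into the checkpoint
    # dictionary written by ``save_model``.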
    def add_checkpoint_state(self) -> dict[str, Any]:
        """Add checkpoint state."""
        return {}

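    # Checkpoints are written by the main process only: a rolling ``model_latest.ckpt`` by
    # default, or per-epoch/step files when ``keep_model_every_interval`` is set. The best
    # validation checkpoint is additionally copied to ``model_best.ckpt``, and all files are
    # mirrored to S3 when running on AIchor.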
    def save_model(self, is_best_checkpoint: bool = False) -> None:
        """Save the model."""
        if not self.accelerator.is_main_process:
            return

        checkpoint_dir = self.config.get("model_save_folder_path", "./checkpoints")
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Save model
        if self.config.get("keep_model_every_interval", False):
            model_path = os.path.join(
                checkpoint_dir, f"model_epoch_{self.epoch:02d}_step_{self.global_step + 1}.ckpt"
            )
        else:
            model_path = os.path.join(checkpoint_dir, "model_latest.ckpt")
            if Path(model_path).exists() and Path(model_path).is_file():
                Path(model_path).unlink()

        unwrapped_model = self.accelerator.unwrap_model(self.model)

        checkpoint_state = {
            "state_dict": unwrapped_model.state_dict(),
            "config": OmegaConf.to_container(self.config.model),
            "residues": self.residue_set.residue_masses,
            "epoch": self.epoch,
            "global_step": self.global_step + 1,
        }
        checkpoint_state.update(self.add_checkpoint_state())

        torch.save(checkpoint_state, model_path)
        logger.info(f"Saved model to {model_path}")

        if S3FileHandler._aichor_enabled():
            self.s3.upload(model_path, S3FileHandler.convert_to_s3_output(model_path))

        if is_best_checkpoint and self.accelerator.is_main_process:
            best_model_path = os.path.join(checkpoint_dir, "model_best.ckpt")
            if Path(best_model_path).exists() and Path(best_model_path).is_file():
                Path(best_model_path).unlink()

            shutil.copy(model_path, best_model_path)

            if S3FileHandler._aichor_enabled():
                self.s3.upload(best_model_path, S3FileHandler.convert_to_s3_output(best_model_path))

        logger.info(f"Saved checkpoint to {model_path}")

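    # Teacher-forced training step: the decoder consumes the ground-truth peptide, and its
    # output includes a prediction for the position following the final target token. That
    # trailing position is dropped so predictions align one-to-one with the targets, and
    # cross-entropy is computed over the flattened token sequence (padding ignored).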
    def forward(self, batch: Any) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Forward pass for the model to calculate loss."""
        preds = self.model(
            x=batch["spectra"],
            p=batch["precursors"],
            y=batch["peptides"],
            x_mask=batch["spectra_mask"],
            y_mask=batch["peptides_mask"],
        )

        preds = preds[:, :-1].reshape(-1, preds.shape[-1])

        loss = self.loss_fn(preds, batch["peptides"].flatten())

        return loss, {"loss": loss}

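    # Validation batches are decoded with the greedy decoder and compared against the ground
    # truth; targets are decoded with ``reverse=True`` so their orientation matches the
    # model's output sequences.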
    def get_predictions(
        self, batch: Any
    ) -> tuple[list[str] | list[list[str]], list[str] | list[list[str]]]:
        """Get the predictions for a batch."""
        # Greedy decoding
        batch_predictions = self.decoder.decode(
            spectra=batch["spectra"],
            precursors=batch["precursors"],
            beam_size=self.config.get("n_beams", 1),
            max_length=self.config.get("max_length", 40),
        )

        targets = [self.residue_set.decode(i, reverse=True) for i in batch["peptides"]]

        return batch_predictions["predictions"], targets

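
# A minimal programmatic sketch (an assumption, not part of this module's API): compose the
# packaged "instanovo" config with Hydra and run the trainer directly, bypassing the CLI
# entry point below.
#
#   from hydra import compose, initialize_config_dir
#
#   with initialize_config_dir(config_dir=str(CONFIG_PATH), version_base=None):
#       cfg = compose(config_name="instanovo")
#   TransformerTrainer(cfg).train()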

@hydra.main(config_path=str(CONFIG_PATH), version_base=None, config_name="instanovo")
def main(config: DictConfig) -> None:
    """Train the model."""
    logger.info("Initializing training.")
    trainer = TransformerTrainer(config)
    trainer.train()


if __name__ == "__main__":
    main()
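
# Example invocation (a sketch, assuming the package is installed; Hydra resolves the
# "instanovo" config from instanovo/configs and accepts dotted overrides for keys read in
# this module):
#
#   python -m instanovo.transformer.train learning_rate=5e-4 model.dropout=0.1 \
#       model_save_folder_path=./checkpoints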