Coverage for instanovo/diffusion/train.py: 53%
85 statements
coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
from __future__ import annotations

import os
import shutil
from pathlib import Path
from typing import Any

import hydra
import torch
import torch.nn as nn
from omegaconf import DictConfig, OmegaConf

from instanovo.__init__ import console
from instanovo.common import AccelerateDeNovoTrainer, DataProcessor
from instanovo.diffusion.data import DiffusionDataProcessor
from instanovo.diffusion.multinomial_diffusion import (
    DiffusionLoss,
    InstaNovoPlus,
    MassSpectrumTransFusion,
    cosine_beta_schedule,
)
from instanovo.inference import Decoder
from instanovo.inference.diffusion import DiffusionDecoder
from instanovo.utils.colorlogging import ColorLog
from instanovo.utils.s3 import S3FileHandler

logger = ColorLog(console, __name__).logger

CONFIG_PATH = Path(__file__).parent.parent / "configs"


class DiffusionTrainer(AccelerateDeNovoTrainer):
    """Trainer for the InstaNovoPlus diffusion model."""

    def __init__(self, config: DictConfig) -> None:
        super().__init__(config)

        self.loss_fn = DiffusionLoss(model=self.model)

    def setup_model(self) -> nn.Module:
        """Setup the model."""
        config = self.config.get("model", {})
        transition_model = MassSpectrumTransFusion(
            cfg=config,
            max_transcript_len=config["max_length"],
        )
        diffusion_schedule = cosine_beta_schedule(timesteps=config["time_steps"])
        model = InstaNovoPlus(
            config=config,
            transition_model=transition_model,
            diffusion_schedule=diffusion_schedule,
            residue_set=self.residue_set,
        )
        return model

    def setup_optimizer(self) -> torch.optim.Optimizer:
        """Setup the optimizer."""
        return torch.optim.Adam(
            self.model.parameters(),
            lr=float(self.config["learning_rate"]),
            weight_decay=float(self.config.get("weight_decay", 0.0)),
        )

    def setup_decoder(self) -> Decoder:
        """Setup the decoder."""
        # TODO: Make DiffusionDecoder conform to Decoder interface
        return DiffusionDecoder(model=self.model)  # type: ignore

    def setup_data_processors(self) -> tuple[DataProcessor, DataProcessor]:
        """Setup the datasets."""
        train_processor = DiffusionDataProcessor(
            self.residue_set,
            n_peaks=self.config.model.get("n_peaks", 200),
            min_mz=self.config.model.get("min_mz", 50.0),
            max_mz=self.config.model.get("max_mz", 2500.0),
            min_intensity=self.config.model.get("min_intensity", 0.01),
            remove_precursor_tol=self.config.model.get("remove_precursor_tol", 2.0),
            return_str=False,
            reverse_peptide=False,
            add_eos=False,
            use_spectrum_utils=False,
            peptide_pad_length=self.config.model.get("max_length", 40),
            peptide_pad_value=self.residue_set.PAD_INDEX,
        )

        valid_processor = DiffusionDataProcessor(
            self.residue_set,
            n_peaks=self.config.model.get("n_peaks", 200),
            min_mz=self.config.model.get("min_mz", 50.0),
            max_mz=self.config.model.get("max_mz", 2500.0),
            min_intensity=self.config.model.get("min_intensity", 0.01),
            remove_precursor_tol=self.config.model.get("remove_precursor_tol", 2.0),
            return_str=False,
            reverse_peptide=False,
            add_eos=False,
            use_spectrum_utils=False,
            peptide_pad_length=self.config.model.get("max_length", 40),
            peptide_pad_value=self.residue_set.PAD_INDEX,
        )

        return train_processor, valid_processor

    def add_checkpoint_state(self) -> dict[str, Any]:
        """Add checkpoint state."""
        return {}

    def save_model(self, is_best_checkpoint: bool = False) -> None:
        """Save the model."""
        if not self.accelerator.is_main_process:
            return

        checkpoint_dir = self.config.get("model_save_folder_path", "./checkpoints")
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Save model
        if self.config.get("keep_model_every_interval", False):
            model_path = os.path.join(checkpoint_dir, f"model_epoch_{self.epoch:02d}_step_{self.global_step + 1}.ckpt")
        else:
            model_path = os.path.join(checkpoint_dir, "model_latest.ckpt")
            if Path(model_path).exists() and Path(model_path).is_file():
                Path(model_path).unlink()

        unwrapped_model = self.accelerator.unwrap_model(self.model)

        checkpoint_state = {
            "state_dict": unwrapped_model.state_dict(),
            "diffusion_schedule": torch.exp(unwrapped_model.diffusion_schedule).tolist(),
            "config": OmegaConf.to_container(self.config.model),
            "residues": self.residue_set.residue_masses,
            "epoch": self.epoch,
            "global_step": self.global_step + 1,
        }
        checkpoint_state.update(self.add_checkpoint_state())

        torch.save(checkpoint_state, model_path)
        logger.info(f"Saved model to {model_path}")

        if S3FileHandler._aichor_enabled():
            self.s3.upload(model_path, S3FileHandler.convert_to_s3_output(model_path))

        if is_best_checkpoint and self.accelerator.is_main_process:
            best_model_path = os.path.join(checkpoint_dir, "model_best.ckpt")
            if Path(best_model_path).exists() and Path(best_model_path).is_file():
                Path(best_model_path).unlink()

            shutil.copy(model_path, best_model_path)

            if S3FileHandler._aichor_enabled():
                self.s3.upload(best_model_path, S3FileHandler.convert_to_s3_output(best_model_path))

        logger.info(f"Saved checkpoint to {model_path}")

    def update_vocab(self, model_state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Update the vocabulary of the model."""
        return self._update_vocab(  # type: ignore
            model_state,
            target_layers=[
                "transition_model.head.1.weight",
                "transition_model.head.1.bias",
                "transition_model.char_embedding.weight",
            ],
            resolution=self.config.get("residue_conflict_resolution", "delete"),
        )

    def update_model_state(self, model_state: dict[str, torch.Tensor], model_config: DictConfig) -> dict[str, torch.Tensor]:
        """Update the model state."""
        if model_config.get("time_steps", 200) != self.config.model.get("time_steps", 200):
            logger.warning("Time steps do not match. Updating model state.")
            for param in [
                "diffusion_schedule",
                "diffusion_schedule_complement",
                "cumulative_schedule",
                "cumulative_schedule_complement",
            ]:
                if param in model_state:
                    del model_state[param]

        return model_state

    def forward(self, batch: Any) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Forward pass for the model to calculate loss."""
        loss = self.loss_fn(
            batch["peptides"],
            spectra=batch["spectra"],
            spectra_padding_mask=batch["spectra_mask"],
            precursors=batch["precursors"],
            x_padding_mask=batch["peptides_mask"],
        )

        return loss, {"loss": loss}

    def get_predictions(self, batch: Any) -> tuple[list[str] | list[list[str]], list[str] | list[list[str]]]:
        """Get the predictions for a batch."""
        # Greedy decoding
        batch_predictions = self.decoder.decode(
            spectra=batch["spectra"],
            spectra_padding_mask=batch["spectra_mask"],
            precursors=batch["precursors"],
        )

        targets = [self.residue_set.decode(seq, reverse=False) for seq in batch["peptides"]]

        return batch_predictions["predictions"], targets


@hydra.main(config_path=str(CONFIG_PATH), version_base=None, config_name="instanovo")
def main(config: DictConfig) -> None:
    """Train the model."""
    logger.info("Initializing training.")
    trainer = DiffusionTrainer(config)
    trainer.train()


if __name__ == "__main__":
    main()
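
Usage sketch (not part of train.py): the snippet below composes the same "instanovo" config programmatically and launches training without going through the Hydra CLI entry point. CONFIG_PATH and DiffusionTrainer come from this module; the "learning_rate" override is only an illustration (the key is read in setup_optimizer above), and the full set of valid overrides depends on configs/instanovo.yaml, which is not shown here.

    from hydra import compose, initialize_config_dir

    from instanovo.diffusion.train import CONFIG_PATH, DiffusionTrainer

    # Compose the "instanovo" config from the package's configs directory,
    # overriding a value the trainer reads at runtime.
    with initialize_config_dir(config_dir=str(CONFIG_PATH), version_base=None):
        config = compose(config_name="instanovo", overrides=["learning_rate=5e-4"])

    # Equivalent to invoking this module's main() entry point.
    DiffusionTrainer(config).train()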