Coverage for instanovo/diffusion/train.py: 53%

85 statements  

coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

from __future__ import annotations

import os
import shutil
from pathlib import Path
from typing import Any

import hydra
import torch
import torch.nn as nn
from omegaconf import DictConfig, OmegaConf

from instanovo.__init__ import console
from instanovo.common import AccelerateDeNovoTrainer, DataProcessor
from instanovo.diffusion.data import DiffusionDataProcessor
from instanovo.diffusion.multinomial_diffusion import (
    DiffusionLoss,
    InstaNovoPlus,
    MassSpectrumTransFusion,
    cosine_beta_schedule,
)
from instanovo.inference import Decoder
from instanovo.inference.diffusion import DiffusionDecoder
from instanovo.utils.colorlogging import ColorLog
from instanovo.utils.s3 import S3FileHandler

logger = ColorLog(console, __name__).logger

CONFIG_PATH = Path(__file__).parent.parent / "configs"


class DiffusionTrainer(AccelerateDeNovoTrainer):
    """Trainer for the InstaNovo+ multinomial diffusion model."""

    def __init__(self, config: DictConfig) -> None:
        super().__init__(config)

        # self.model is created by the parent trainer's __init__ (via
        # setup_model below), so it can safely be referenced here.
        self.loss_fn = DiffusionLoss(model=self.model)

    def setup_model(self) -> nn.Module:
        """Set up the model."""
        config = self.config.get("model", {})
        # The learned transition model: a transformer conditioned on the
        # mass spectrum.
        transition_model = MassSpectrumTransFusion(
            cfg=config,
            max_transcript_len=config["max_length"],
        )
        # Fixed noising schedule over the configured number of time steps.
        diffusion_schedule = cosine_beta_schedule(timesteps=config["time_steps"])
        model = InstaNovoPlus(
            config=config,
            transition_model=transition_model,
            diffusion_schedule=diffusion_schedule,
            residue_set=self.residue_set,
        )
        return model
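    # For reference: given the name, cosine_beta_schedule presumably follows
    # the cosine schedule of Nichol & Dhariwal (2021). This is an assumption
    # based on the function name, not verified against
    # instanovo.diffusion.multinomial_diffusion:
    #   alpha_bar(t) = cos(((t / T + s) / (1 + s)) * pi / 2) ** 2
    #   beta_t = 1 - alpha_bar(t) / alpha_bar(t - 1)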

    def setup_optimizer(self) -> torch.optim.Optimizer:
        """Set up the optimizer."""
        return torch.optim.Adam(
            self.model.parameters(),
            lr=float(self.config["learning_rate"]),
            weight_decay=float(self.config.get("weight_decay", 0.0)),
        )

    def setup_decoder(self) -> Decoder:
        """Set up the decoder."""
        # TODO: Make DiffusionDecoder conform to the Decoder interface
        return DiffusionDecoder(model=self.model)  # type: ignore

    def setup_data_processors(self) -> tuple[DataProcessor, DataProcessor]:
        """Set up the train and validation data processors."""
        # The train and validation processors are configured identically.
        processor_kwargs: dict[str, Any] = dict(
            n_peaks=self.config.model.get("n_peaks", 200),
            min_mz=self.config.model.get("min_mz", 50.0),
            max_mz=self.config.model.get("max_mz", 2500.0),
            min_intensity=self.config.model.get("min_intensity", 0.01),
            remove_precursor_tol=self.config.model.get("remove_precursor_tol", 2.0),
            return_str=False,
            # The diffusion model consumes fixed-length peptide targets that
            # are padded rather than reversed or EOS-terminated.
            reverse_peptide=False,
            add_eos=False,
            use_spectrum_utils=False,
            peptide_pad_length=self.config.model.get("max_length", 40),
            peptide_pad_value=self.residue_set.PAD_INDEX,
        )
        train_processor = DiffusionDataProcessor(self.residue_set, **processor_kwargs)
        valid_processor = DiffusionDataProcessor(self.residue_set, **processor_kwargs)

        return train_processor, valid_processor

    def add_checkpoint_state(self) -> dict[str, Any]:
        """Return any extra state to include in the checkpoint."""
        return {}

    def save_model(self, is_best_checkpoint: bool = False) -> None:
        """Save the model."""
        # Only the main process writes checkpoints.
        if not self.accelerator.is_main_process:
            return

        checkpoint_dir = self.config.get("model_save_folder_path", "./checkpoints")
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Either keep a uniquely named checkpoint per interval, or overwrite
        # a single rolling "model_latest.ckpt".
        if self.config.get("keep_model_every_interval", False):
            model_path = os.path.join(checkpoint_dir, f"model_epoch_{self.epoch:02d}_step_{self.global_step + 1}.ckpt")
        else:
            model_path = os.path.join(checkpoint_dir, "model_latest.ckpt")
            if Path(model_path).is_file():
                Path(model_path).unlink()

        unwrapped_model = self.accelerator.unwrap_model(self.model)

        checkpoint_state = {
            "state_dict": unwrapped_model.state_dict(),
            # The model appears to store the schedule in log space, hence
            # the exp before serialization.
            "diffusion_schedule": torch.exp(unwrapped_model.diffusion_schedule).tolist(),
            "config": OmegaConf.to_container(self.config.model),
            "residues": self.residue_set.residue_masses,
            "epoch": self.epoch,
            "global_step": self.global_step + 1,
        }
        checkpoint_state.update(self.add_checkpoint_state())

        torch.save(checkpoint_state, model_path)
        logger.info(f"Saved model to {model_path}")

        if S3FileHandler._aichor_enabled():
            self.s3.upload(model_path, S3FileHandler.convert_to_s3_output(model_path))

        if is_best_checkpoint:
            best_model_path = os.path.join(checkpoint_dir, "model_best.ckpt")
            if Path(best_model_path).is_file():
                Path(best_model_path).unlink()

            shutil.copy(model_path, best_model_path)

            if S3FileHandler._aichor_enabled():
                self.s3.upload(best_model_path, S3FileHandler.convert_to_s3_output(best_model_path))

        logger.info(f"Saved checkpoint to {model_path}")
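    # The checkpoint is a plain dict, so it can be inspected without
    # instantiating the model. A minimal sketch, using the paths and keys
    # written by save_model above:
    #   ckpt = torch.load("checkpoints/model_latest.ckpt", map_location="cpu")
    #   print(ckpt["epoch"], ckpt["global_step"], ckpt["config"])
    #   model.load_state_dict(ckpt["state_dict"])  # any matching InstaNovoPlus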

    def update_vocab(self, model_state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Update the vocabulary of the model."""
        # Only the vocabulary-sized layers need resizing: the output head
        # and the residue embedding.
        return self._update_vocab(  # type: ignore
            model_state,
            target_layers=[
                "transition_model.head.1.weight",
                "transition_model.head.1.bias",
                "transition_model.char_embedding.weight",
            ],
            resolution=self.config.get("residue_conflict_resolution", "delete"),
        )

    def update_model_state(self, model_state: dict[str, torch.Tensor], model_config: DictConfig) -> dict[str, torch.Tensor]:
        """Update the model state."""
        if model_config.get("time_steps", 200) != self.config.model.get("time_steps", 200):
            logger.warning("Time steps do not match. Updating model state.")
            # Drop the schedule buffers so they are rebuilt for the new
            # number of time steps instead of being loaded at the old size.
            for param in [
                "diffusion_schedule",
                "diffusion_schedule_complement",
                "cumulative_schedule",
                "cumulative_schedule_complement",
            ]:
                if param in model_state:
                    del model_state[param]

        return model_state
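    # Expected batch layout, inferred from forward() and get_predictions()
    # below (exact shapes are assumptions, not taken from this file):
    #   batch["spectra"]:       (batch, n_peaks, 2) floats, m/z-intensity pairs
    #   batch["spectra_mask"]:  (batch, n_peaks) padding mask
    #   batch["precursors"]:    per-spectrum precursor features
    #   batch["peptides"]:      (batch, max_length) residue indices
    #   batch["peptides_mask"]: (batch, max_length) padding mask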

    def forward(self, batch: Any) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Forward pass for the model to calculate loss."""
        loss = self.loss_fn(
            batch["peptides"],
            spectra=batch["spectra"],
            spectra_padding_mask=batch["spectra_mask"],
            precursors=batch["precursors"],
            x_padding_mask=batch["peptides_mask"],
        )

        return loss, {"loss": loss}

    def get_predictions(self, batch: Any) -> tuple[list[str] | list[list[str]], list[str] | list[list[str]]]:
        """Get the predictions for a batch."""
        # Greedy decoding
        batch_predictions = self.decoder.decode(
            spectra=batch["spectra"],
            spectra_padding_mask=batch["spectra_mask"],
            precursors=batch["precursors"],
        )

        targets = [self.residue_set.decode(seq, reverse=False) for seq in batch["peptides"]]

        return batch_predictions["predictions"], targets


@hydra.main(config_path=str(CONFIG_PATH), version_base=None, config_name="instanovo")
def main(config: DictConfig) -> None:
    """Train the model."""
    logger.info("Initializing training.")
    trainer = DiffusionTrainer(config)
    trainer.train()


if __name__ == "__main__":
    main()
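
# Usage sketch: the @hydra.main entry point above is normally driven from the
# command line, e.g. `python -m instanovo.diffusion.train learning_rate=1e-4`,
# where any config key can be overridden. An equivalent programmatic route
# (a minimal sketch; the override key is illustrative) composes the config
# with Hydra's Python API:
#
#   from hydra import compose, initialize_config_dir
#
#   with initialize_config_dir(config_dir=str(CONFIG_PATH), version_base=None):
#       config = compose(config_name="instanovo", overrides=["learning_rate=1e-4"])
#   DiffusionTrainer(config).train()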