Coverage for instanovo/transformer/train.py: 53%

77 statements  

coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

from __future__ import annotations

import os
import shutil
from pathlib import Path
from typing import Any

import hydra
import torch
import torch.nn as nn
from omegaconf import DictConfig, OmegaConf

from instanovo.__init__ import console
from instanovo.common import AccelerateDeNovoTrainer, DataProcessor
from instanovo.inference import Decoder, GreedyDecoder
from instanovo.transformer.data import TransformerDataProcessor
from instanovo.transformer.model import InstaNovo
from instanovo.utils.colorlogging import ColorLog
from instanovo.utils.s3 import S3FileHandler

logger = ColorLog(console, __name__).logger

CONFIG_PATH = Path(__file__).parent.parent / "configs"

class TransformerTrainer(AccelerateDeNovoTrainer):
    """Trainer for the InstaNovo model."""

    def __init__(self, config: DictConfig) -> None:
        super().__init__(config)

        self.loss_fn = nn.CrossEntropyLoss(ignore_index=0)

    def setup_model(self) -> nn.Module:
        """Set up the model."""
        config = self.config.get("model", {})
        model = InstaNovo(
            residue_set=self.residue_set,
            dim_model=config["dim_model"],
            n_head=config["n_head"],
            dim_feedforward=config["dim_feedforward"],
            encoder_layers=config.get("encoder_layers", config.get("n_layers", 9)),
            decoder_layers=config.get("decoder_layers", config.get("n_layers", 9)),
            dropout=config["dropout"],
            max_charge=config["max_charge"],
            use_flash_attention=config.get("use_flash_attention", False),
            conv_peak_encoder=config.get("conv_peak_encoder", False),
            peak_embedding_dtype="float32" if self.config.get("mps", False) else "float64",
        )
        return model
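    # Illustrative sketch of the config.model keys consumed above. dim_model, n_head,
    # dim_feedforward, dropout and max_charge are required; the remaining keys fall back
    # to the in-code defaults (encoder/decoder layers default to n_layers, then 9).
    # Example values below are assumptions, not the shipped configuration:
    #   dim_model: 768, n_head: 16, dim_feedforward: 1024, n_layers: 9, dropout: 0.1,
    #   max_charge: 10, use_flash_attention: false, conv_peak_encoder: false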

    def update_vocab(self, model_state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Update the vocabulary of the model."""
        return self._update_vocab(  # type: ignore
            model_state,
            target_layers=["head.weight", "head.bias", "aa_embed.weight"],
            resolution=self.config.get("residue_conflict_resolution", "delete"),
        )

    def setup_optimizer(self) -> torch.optim.Optimizer:
        """Set up the optimizer."""
        return torch.optim.Adam(
            self.model.parameters(),
            lr=float(self.config["learning_rate"]),
            weight_decay=float(self.config.get("weight_decay", 0.0)),
        )

    def setup_decoder(self) -> Decoder:
        """Set up the decoder."""
        return GreedyDecoder(
            model=self.model,
            float_dtype=torch.float32 if self.config.get("mps", False) else torch.float64,
        )
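    # Note on precision: decoding defaults to float64; float32 is used when the "mps"
    # config flag is set, since the Apple MPS backend does not support float64 tensors.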

    def setup_data_processors(self) -> tuple[DataProcessor, DataProcessor]:
        """Set up the training and validation data processors."""
        train_processor = TransformerDataProcessor(
            self.residue_set,
            n_peaks=self.config.model.get("n_peaks", 200),
            min_mz=self.config.model.get("min_mz", 50.0),
            max_mz=self.config.model.get("max_mz", 2500.0),
            min_intensity=self.config.model.get("min_intensity", 0.01),
            remove_precursor_tol=self.config.model.get("remove_precursor_tol", 2.0),
            return_str=False,
            use_spectrum_utils=False,
        )

        valid_processor = TransformerDataProcessor(
            self.residue_set,
            n_peaks=self.config.model.get("n_peaks", 200),
            min_mz=self.config.model.get("min_mz", 50.0),
            max_mz=self.config.model.get("max_mz", 2500.0),
            min_intensity=self.config.model.get("min_intensity", 0.01),
            remove_precursor_tol=self.config.model.get("remove_precursor_tol", 2.0),
            return_str=False,
            use_spectrum_utils=False,
        )

        return train_processor, valid_processor
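    # The training and validation processors share identical settings; the values above
    # (n_peaks=200, m/z range 50-2500, min_intensity=0.01, remove_precursor_tol=2.0) are
    # only fallbacks and are normally supplied via config.model.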

    def add_checkpoint_state(self) -> dict[str, Any]:
        """Add checkpoint state."""
        return {}

    def save_model(self, is_best_checkpoint: bool = False) -> None:
        """Save the model."""
        if not self.accelerator.is_main_process:
            return

        checkpoint_dir = self.config.get("model_save_folder_path", "./checkpoints")
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Save model
        if self.config.get("keep_model_every_interval", False):
            model_path = os.path.join(
                checkpoint_dir, f"model_epoch_{self.epoch:02d}_step_{self.global_step + 1}.ckpt"
            )
        else:
            model_path = os.path.join(checkpoint_dir, "model_latest.ckpt")
            if Path(model_path).exists() and Path(model_path).is_file():
                Path(model_path).unlink()

        unwrapped_model = self.accelerator.unwrap_model(self.model)

        checkpoint_state = {
            "state_dict": unwrapped_model.state_dict(),
            "config": OmegaConf.to_container(self.config.model),
            "residues": self.residue_set.residue_masses,
            "epoch": self.epoch,
            "global_step": self.global_step + 1,
        }
        checkpoint_state.update(self.add_checkpoint_state())

        torch.save(checkpoint_state, model_path)
        logger.info(f"Saved model to {model_path}")

        if S3FileHandler._aichor_enabled():
            self.s3.upload(model_path, S3FileHandler.convert_to_s3_output(model_path))

        if is_best_checkpoint and self.accelerator.is_main_process:
            best_model_path = os.path.join(checkpoint_dir, "model_best.ckpt")
            if Path(best_model_path).exists() and Path(best_model_path).is_file():
                Path(best_model_path).unlink()

            shutil.copy(model_path, best_model_path)

            if S3FileHandler._aichor_enabled():
                self.s3.upload(
                    best_model_path, S3FileHandler.convert_to_s3_output(best_model_path)
                )

        logger.info(f"Saved checkpoint to {model_path}")
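    # Checkpoints written above are plain torch files; a minimal reloading sketch
    # (the path is illustrative):
    #   ckpt = torch.load("checkpoints/model_best.ckpt", map_location="cpu")
    #   ckpt["state_dict"]  # model weights
    #   ckpt["config"]      # model config as a plain container
    #   ckpt["residues"]    # residue masses used at training time
    #   ckpt["epoch"], ckpt["global_step"]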

    def forward(self, batch: Any) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Forward pass for the model to calculate loss."""
        preds = self.model(
            x=batch["spectra"],
            p=batch["precursors"],
            y=batch["peptides"],
            x_mask=batch["spectra_mask"],
            y_mask=batch["peptides_mask"],
        )

        preds = preds[:, :-1].reshape(-1, preds.shape[-1])

        loss = self.loss_fn(preds, batch["peptides"].flatten())

        return loss, {"loss": loss}
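    # Shape sketch for the loss above (assuming the model emits one extra position
    # relative to the target tokens, e.g. from a prepended start token):
    #   preds:  [batch, len + 1, vocab] --[:, :-1]--> [batch, len, vocab] --reshape--> [batch * len, vocab]
    #   target: batch["peptides"] of shape [batch, len] --flatten--> [batch * len]
    # ignore_index=0 in CrossEntropyLoss skips padding positions in the target.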

    def get_predictions(self, batch: Any) -> tuple[list[str] | list[list[str]], list[str] | list[list[str]]]:
        """Get the predictions for a batch."""
        # Greedy decoding
        batch_predictions = self.decoder.decode(
            spectra=batch["spectra"],
            precursors=batch["precursors"],
            beam_size=self.config.get("n_beams", 1),
            max_length=self.config.get("max_length", 40),
        )

        targets = [self.residue_set.decode(i, reverse=True) for i in batch["peptides"]]

        return batch_predictions["predictions"], targets
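    # Note: targets are decoded with reverse=True, which assumes the tokenised peptides
    # are stored in inverted order relative to the reported sequence, matching the order
    # in which the decoder predicts residues.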

@hydra.main(config_path=str(CONFIG_PATH), version_base=None, config_name="instanovo")
def main(config: DictConfig) -> None:
    """Train the model."""
    logger.info("Initializing training.")
    trainer = TransformerTrainer(config)
    trainer.train()


if __name__ == "__main__":
    main()
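A minimal sketch of launching training; the override keys and values below are illustrative, and "instanovo" is the config named in the @hydra.main decorator above.

    # Command line (Hydra overrides applied in place):
    #   python -m instanovo.transformer.train learning_rate=5e-4 model.dropout=0.1
    # Programmatic equivalent, composing the same config:
    from hydra import compose, initialize_config_dir

    from instanovo.transformer.train import CONFIG_PATH, TransformerTrainer

    with initialize_config_dir(config_dir=str(CONFIG_PATH), version_base=None):
        config = compose(config_name="instanovo", overrides=["learning_rate=5e-4"])
    TransformerTrainer(config).train()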