Coverage for instanovo/transformer/predict.py: 90%

67 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

1import os 

2from pathlib import Path 

3from typing import Any, Tuple 

4 

5import torch 

6import torch.nn as nn 

7from omegaconf import DictConfig 

8 

9from instanovo.__init__ import console 

10from instanovo.common import AccelerateDeNovoPredictor, DataProcessor 

11from instanovo.constants import MASS_SCALE, MAX_MASS 

12from instanovo.inference import ( 

13 BeamSearchDecoder, 

14 Decoder, 

15 GreedyDecoder, 

16 Knapsack, 

17 KnapsackBeamSearchDecoder, 

18) 

19from instanovo.transformer.data import TransformerDataProcessor 

20from instanovo.transformer.model import InstaNovo 

21from instanovo.utils.colorlogging import ColorLog 

22 

# Module-level logger, built by ColorLog from the package console and this module's name.
logger = ColorLog(console, __name__).logger

# "configs" directory located in the grandparent directory of this file.
CONFIG_PATH = Path(__file__).parents[1] / "configs"

26 

27 

class TransformerPredictor(AccelerateDeNovoPredictor):
    """Predictor for the InstaNovo model.

    Implements model loading, data-processor setup, decoder setup and
    per-batch prediction on top of ``AccelerateDeNovoPredictor``.
    """

    def __init__(
        self,
        config: DictConfig,
    ) -> None:
        """Store the beam settings, then run the base-class initialisation.

        Args:
            config: Prediction configuration. ``num_beams`` (default ``1``)
                and ``save_beams`` (default ``False``) are read here; the
                whole config is forwarded to the parent constructor.
        """
        # Assigned before ``super().__init__`` -- presumably the base
        # initialiser calls the setup methods that read these attributes;
        # TODO confirm against AccelerateDeNovoPredictor.
        self.num_beams = config.get("num_beams", 1)
        # Change this to a generic save all outputs function
        self.save_beams = config.get("save_beams", False)
        super().__init__(config)

    def load_model(self) -> Tuple[nn.Module, DictConfig]:
        """Load the InstaNovo model selected by the config.

        The ``instanovo_model`` config entry is used, falling back to the
        first entry of ``InstaNovo.get_pretrained()``. A value listed by
        ``InstaNovo.get_pretrained()`` is loaded with
        ``InstaNovo.from_pretrained``; any other value is resolved to a local
        path through ``self.s3.get_local_path`` and loaded with
        ``InstaNovo.load``.

        Returns:
            The loaded model together with its configuration.

        Raises:
            ValueError: If ``self.s3.get_local_path`` returns ``None`` for
                the requested model.
        """
        # Query the pretrained ids once and reuse them for both the default
        # and the membership test.
        pretrained_models = InstaNovo.get_pretrained()
        model_path = self.config.get("instanovo_model", pretrained_models[0])

        # When ``mps`` is enabled the peak embedding dtype is forced to
        # float32 (presumably because the MPS backend lacks float64 support
        # -- TODO confirm). Shared by both loading paths below.
        override_config = (
            {"peak_embedding_dtype": "float32"} if self.config.get("mps", False) else None
        )

        logger.info(f"Loading InstaNovo model {model_path}")
        if model_path in pretrained_models:
            # Using a pretrained model from models.json
            model, model_config = InstaNovo.from_pretrained(
                model_path, override_config=override_config
            )
        else:
            requested_model = model_path
            model_path = self.s3.get_local_path(requested_model)
            # Explicit check instead of ``assert`` so that the validation is
            # not stripped when Python runs with ``-O``.
            if model_path is None:
                raise ValueError(
                    f"Could not resolve a local path for InstaNovo model {requested_model!r}"
                )
            model, model_config = InstaNovo.load(model_path, override_config=override_config)

        return model, model_config

    def setup_data_processor(self) -> DataProcessor:
        """Build the ``TransformerDataProcessor`` used for this run.

        Peak-filtering settings are read from ``self.model_config`` with the
        fallbacks shown below. The processor is annotated whenever
        ``self.denovo`` is false and carries the ``group`` metadata column.

        Returns:
            The configured data processor.
        """
        processor = TransformerDataProcessor(
            self.residue_set,
            n_peaks=self.model_config.get("n_peaks", 200),
            min_mz=self.model_config.get("min_mz", 50.0),
            max_mz=self.model_config.get("max_mz", 2500.0),
            min_intensity=self.model_config.get("min_intensity", 0.01),
            remove_precursor_tol=self.model_config.get("remove_precursor_tol", 2.0),
            return_str=False,
            use_spectrum_utils=False,
            annotated=not self.denovo,
            metadata_columns=["group"],
        )

        return processor

    def setup_decoder(self) -> Decoder:
        """Create the decoder selected by the config.

        Selection order:

        1. ``use_knapsack`` set: ``KnapsackBeamSearchDecoder``. The knapsack
           is loaded from ``knapsack_path`` when that path exists; otherwise
           it is generated and, if a path was given, saved there.
        2. ``num_beams > 1``: ``BeamSearchDecoder``.
        3. Otherwise: ``GreedyDecoder``.

        All decoders receive ``float_dtype``, which is ``torch.float32`` when
        ``force_fp32`` is set and ``torch.float64`` otherwise.

        Returns:
            The configured decoder.
        """
        float_dtype = torch.float32 if self.config.get("force_fp32", False) else torch.float64
        decoder: Decoder
        if self.config.get("use_knapsack", False):
            logger.info(f"Using Knapsack Beam Search with {self.num_beams} beam(s)")
            knapsack_path = self.config.get("knapsack_path", None)
            if knapsack_path is None or not os.path.exists(knapsack_path):
                logger.info("Knapsack path missing or not specified, generating...")
                knapsack = _setup_knapsack(self.model, self.config.get("max_isotope_error", 1))
                decoder = KnapsackBeamSearchDecoder(self.model, knapsack, float_dtype=float_dtype)
                # Persist the freshly generated knapsack only when a
                # destination was configured.
                if knapsack_path is not None:
                    logger.info(f"Saving knapsack to {knapsack_path}")
                    knapsack.save(knapsack_path)
            else:
                logger.info("Knapsack path found. Loading...")
                decoder = KnapsackBeamSearchDecoder.from_file(
                    self.model, knapsack_path, float_dtype=float_dtype
                )
        elif self.num_beams > 1:
            logger.info(f"Using Beam Search with {self.num_beams} beam(s)")
            decoder = BeamSearchDecoder(self.model, float_dtype=float_dtype)
        else:
            logger.info(f"Using Greedy Search with {self.num_beams} beam(s)")
            decoder = GreedyDecoder(
                model=self.model,
                suppressed_residues=self.config.get("suppressed_residues", None),
                disable_terminal_residues_anywhere=self.config.get(
                    "disable_terminal_residues_anywhere", True
                ),
                float_dtype=float_dtype,
            )
        return decoder

    def get_predictions(self, batch: Any) -> dict[str, Any]:
        """Get the predictions for a batch.

        Args:
            batch: Mapping providing ``spectra`` and ``precursors`` and,
                optionally, ``peptides``.

        Returns:
            The dictionary returned by ``self.decoder.decode`` with an added
            ``targets`` entry: the ``peptides`` sequences decoded through
            ``self.residue_set`` (with ``reverse=True``) when present,
            otherwise one ``None`` per spectrum in the batch.
        """
        batch_size = batch["spectra"].size(0)

        batch_predictions: dict[str, Any] = self.decoder.decode(
            spectra=batch["spectra"],
            precursors=batch["precursors"],
            beam_size=self.num_beams,
            max_length=self.config.get("max_length", 40),
            # Beams are only requested when more than one beam is searched.
            return_beam=self.num_beams > 1,
            return_encoder_output=self.save_encoder_outputs,
            encoder_output_reduction=self.encoder_output_reduction,
        )

        if "peptides" in batch:
            targets = [self.residue_set.decode(seq, reverse=True) for seq in batch["peptides"]]
        else:
            targets = [None] * batch_size

        batch_predictions["targets"] = targets

        return batch_predictions

128 

129 

def _setup_knapsack(model: InstaNovo, max_isotope: int = 2) -> Knapsack:
    """Construct a ``Knapsack`` from the residue set of ``model``.

    Args:
        model: Model whose ``residue_set`` supplies the residue masses and
            the residue-to-index mapping.
        max_isotope: Forwarded to ``Knapsack.construct_knapsack`` as
            ``max_isotope``.

    Returns:
        The knapsack built with ``MAX_MASS`` and ``MASS_SCALE``.
    """
    residue_set = model.residue_set
    masses = dict(residue_set.residue_masses.copy())

    # The first three entries of the residue-to-index mapping are treated as
    # special residues and given a mass of zero.
    special_residues = list(residue_set.residue_to_index.keys())[:3]
    masses.update(dict.fromkeys(special_residues, 0))

    return Knapsack.construct_knapsack(
        residue_masses=masses,
        residue_indices=residue_set.residue_to_index,
        max_mass=MAX_MASS,
        mass_scale=MASS_SCALE,
        max_isotope=max_isotope,
    )

141 )