Coverage for instanovo/inference/diffusion.py: 74%

101 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

1from __future__ import annotations 

2 

3from typing import Any, Literal, Optional 

4 

5import torch 

6from jaxtyping import Bool, Float, Integer 

7from torch.distributions import Categorical 

8 

9from instanovo.constants import CARBON_MASS_DELTA, DIFFUSION_EVAL_STEPS, DIFFUSION_START_STEP, H2O_MASS, PrecursorDimension 

10from instanovo.diffusion.multinomial_diffusion import DiffusionLoss, InstaNovoPlus 

11from instanovo.inference.interfaces import Decoder 

12from instanovo.types import Peptide, PrecursorFeatures, Spectrum, SpectrumMask 

13 

14 

class DiffusionDecoder(Decoder):
    """Class for decoding from a diffusion model by forward sampling."""

    def __init__(self, model: InstaNovoPlus) -> None:
        """Wrap a diffusion model and cache the attributes used while decoding.

        Args:
            model:
                The diffusion model to sample from.
        """
        super().__init__(model=model)
        # Override base class type annotation - this is actually InstaNovoPlus, not just Decodable
        self.model: InstaNovoPlus = model

        self.time_steps = self.model.time_steps
        self.residue_set = self.model.residue_set
        self.loss_function = DiffusionLoss(model=self.model)

    def decode(
        self,
        spectra: Float[Spectrum, " batch"],
        precursors: Float[PrecursorFeatures, " batch"],
        spectra_padding_mask: Bool[SpectrumMask, " batch"],
        initial_sequence: Optional[Integer[Peptide, " batch"]] = None,
        start_step: int = DIFFUSION_START_STEP,
        eval_steps: tuple[int, ...] = DIFFUSION_EVAL_STEPS,
        beam_size: int = 1,
        mass_tolerance: float = 5e-5,
        max_isotope: int = 1,
        return_encoder_output: bool = False,
        encoder_output_reduction: Literal["mean", "max", "sum", "full"] = "mean",
        return_beam: bool = False,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Decoding predictions from a diffusion model by forward sampling.

        Every spectrum is replicated `beam_size` times, each replica is sampled
        independently through the reverse process, and the replicas of a spectrum
        are then ranked by (precursor mass match, log-probability).

        Args:
            spectra:
                A batch of spectra to be decoded.

            precursors:
                Precursor mass, charge and m/z for a batch of spectra.

            spectra_padding_mask:
                Padding mask for a batch of variable length spectra.

            initial_sequence:
                An initial sequence for the model to refine. If no initial sequence is
                provided (the value is None), will sample a random sequence from a uniform unigram
                model. Defaults to None.

            start_step:
                The step at which to insert the initial sequence and start refinement. If
                `initial_sequence` is not provided, this will be set to `time_steps - 1`.

            eval_steps:
                The steps at which to evaluate the loss and compute the log-probabilities.

            beam_size:
                Number of independent samples drawn per spectrum. Must be at least 1.
                Defaults to 1.

            mass_tolerance:
                Relative tolerance for the precursor mass check; the absolute window is
                `mass_tolerance * precursor_mass`. Defaults to 5e-5.

            max_isotope:
                Highest isotope offset tested in the precursor mass check. Offsets
                `0..max_isotope` (inclusive), each a multiple of `CARBON_MASS_DELTA`,
                are tried. Defaults to 1.

            return_encoder_output:
                Whether to return the encoder output.

            encoder_output_reduction:
                The reduction to apply to the encoder output.
                Valid values are "mean", "max", "sum", "full".
                Defaults to "mean".

            return_beam:
                Whether to additionally return every ranked sample under per-beam keys.
                Defaults to False.

            **kwargs:
                Accepted for interface compatibility; not used by this method.

        Returns:
            dict[str, Any]:
                The decoded peptides and their log-probabilities for a batch of spectra.
                Keys always present (one entry per spectrum, taken from the best sample):
                    - "predictions": list[list[str]]
                    - "meets_precursor": list of flags from the precursor mass check
                    - "prediction_log_probability": list[float]
                    - "prediction_token_log_probabilities": list of `None` placeholders
                Present when `return_encoder_output` is True:
                    - "encoder_output": list of per-row arrays
                Present when `return_beam` is True, for each `i` in `range(beam_size)`:
                    - "predictions_beam_{i}": list[list[str]]
                    - "prediction_log_probability_beam_{i}": list[float]
                    - "prediction_token_log_probabilities_beam_{i}": list of `None` placeholders

        Raises:
            ValueError: If `beam_size` is smaller than 1, if the encoder output cannot be
                read from the model cache, or if `encoder_output_reduction` is not a
                recognised value.
            NotImplementedError: If `encoder_output_reduction` is "full".
        """
        if beam_size < 1:
            # Without at least one sample per spectrum there is no "best beam" to return.
            raise ValueError(f"beam_size must be at least 1, got {beam_size}")

        device = spectra.device
        sequence_length = self.model.config.max_length
        batch_size, num_classes = spectra.size(0), len(self.residue_set)
        effective_batch_size = batch_size * beam_size

        # Rows belonging to the same spectrum are adjacent: row `idx` maps to spectrum `idx // beam_size`.
        spectra_expanded = spectra.repeat_interleave(beam_size, dim=0)
        spectra_padding_mask_expanded = spectra_padding_mask.repeat_interleave(beam_size, dim=0)
        precursors_expanded = precursors.repeat_interleave(beam_size, dim=0)

        if initial_sequence is None:
            # Sample uniformly
            initial_distribution = Categorical(torch.ones(effective_batch_size, sequence_length, num_classes) / num_classes)
            sample = initial_distribution.sample().to(device)
            start_step = self.time_steps - 1
        else:
            sample = initial_sequence.repeat_interleave(beam_size, dim=0)

        # All-False mask: no peptide position is treated as padding.
        peptide_mask = torch.zeros(effective_batch_size, sequence_length, dtype=torch.bool, device=device)

        # Sample through reverse process
        for t in range(start_step, -1, -1):
            times = torch.full((effective_batch_size,), t, dtype=torch.long, device=device)
            distribution = Categorical(
                logits=self.model.reverse_distribution(
                    x_t=sample,
                    time=times,
                    spectra=spectra_expanded,
                    spectra_padding_mask=spectra_padding_mask_expanded,
                    precursors=precursors_expanded,
                    x_padding_mask=peptide_mask,
                )
            )
            sample = distribution.sample()

        # Calculate log-probabilities as average loss across `eval_steps`
        losses = [
            self.loss_function._compute_loss(
                x_0=sample,
                t=torch.full((effective_batch_size,), t, dtype=torch.long, device=device),
                spectra=spectra_expanded,
                spectra_padding_mask=spectra_padding_mask_expanded,
                precursors=precursors_expanded,
                x_padding_mask=peptide_mask,
            )
            for t in eval_steps
        ]
        log_probs = (-torch.stack(losses).mean(dim=0).cpu()).tolist()
        sequences = self._extract_predictions(sample)

        # convert to batch_size, beam_size format
        completed_beams: list[list[dict[str, Any]]] = [[] for _ in range(batch_size)]

        for idx in range(effective_batch_size):
            batch_idx = idx // beam_size

            sequence = sequences[idx]
            sequence_mass = sum(self.model.residue_set.get_mass(residue) for residue in sequence)
            # Check precursor matching
            precursor_mass = precursors[batch_idx, PrecursorDimension.PRECURSOR_MASS.value]
            remaining_mass = precursor_mass - sequence_mass - H2O_MASS
            mass_target_delta = mass_tolerance * precursor_mass

            # Check if mass within range, including isotopes
            # NOTE(review): `precursor_mass` is indexed out of `precursors`, so this flag is presumably
            # a 0-dim bool tensor rather than a Python bool (and plain False if the loop never runs).
            remaining_meets_precursor = False
            for j in range(0, max_isotope + 1, 1):
                isotope = CARBON_MASS_DELTA * j
                within_window = abs(remaining_mass - isotope) < mass_target_delta
                remaining_meets_precursor = remaining_meets_precursor | within_window

            completed_beams[batch_idx].append(
                {
                    "predictions": sequence,
                    "meets_precursor": remaining_meets_precursor,
                    "prediction_log_probability": log_probs[idx],
                    "prediction_token_log_probabilities": log_probs[idx],
                }
            )

        # Sort beams by meets_precursor and prediction_log_probability
        # Index 0 is the best beam
        for beams in completed_beams:
            beams.sort(key=lambda x: (x["meets_precursor"], x["prediction_log_probability"]), reverse=True)

        result: dict[str, Any] = {
            "predictions": [],
            "meets_precursor": [],
            "prediction_log_probability": [],
            "prediction_token_log_probabilities": [None] * batch_size,
        }
        if return_beam:
            for i in range(beam_size):
                result[f"predictions_beam_{i}"] = []
                result[f"prediction_log_probability_beam_{i}"] = []
                # Plain list of placeholders, same shape as the non-beam key above.
                result[f"prediction_token_log_probabilities_beam_{i}"] = [None] * batch_size

        for beams in completed_beams:
            if return_beam:
                for beam_idx in range(beam_size):
                    result[f"predictions_beam_{beam_idx}"].append(beams[beam_idx]["predictions"])
                    result[f"prediction_log_probability_beam_{beam_idx}"].append(beams[beam_idx]["prediction_log_probability"])

            best = beams[0]
            result["predictions"].append(best["predictions"])
            result["meets_precursor"].append(best["meets_precursor"])
            result["prediction_log_probability"].append(best["prediction_log_probability"])

        if return_encoder_output:
            # Extract encoder output from cache
            encoder_output = self.model.transition_model.cache_cond_emb
            encoder_mask = self.model.transition_model.cache_cond_padding_mask

            if encoder_output is None or encoder_mask is None:
                raise ValueError("Could not extract encoder output from model to return as encoder output.")

            # NOTE(review): the model was called with the beam-expanded spectra, so the cache may hold
            # `batch_size * beam_size` rows when `beam_size > 1` - confirm against consumers.
            reduced = self._reduce_encoder_output(encoder_output, encoder_mask, encoder_output_reduction)
            result["encoder_output"] = list(reduced.numpy())

        return result

    @staticmethod
    def _reduce_encoder_output(encoder_output: torch.Tensor, encoder_mask: torch.Tensor, reduction: str) -> torch.Tensor:
        """Collapse dimension 1 of the cached encoder output, ignoring masked positions.

        Args:
            encoder_output:
                Cached encoder embeddings; dimension 1 is the one reduced and the last
                dimension is broadcast against the mask.
            encoder_mask:
                Cached mask that is truthy at positions to ignore.
            reduction:
                One of "mean", "max", "sum" or "full".

        Returns:
            The reduced encoder output as a float tensor on the CPU.

        Raises:
            NotImplementedError: If `reduction` is "full".
            ValueError: If `reduction` is not a recognised value.
        """
        encoder_output = encoder_output.float().cpu()
        # Invert the mask so that kept positions are 1.0 and ignored positions are 0.0.
        keep = (1 - encoder_mask.float()).cpu()
        encoder_output = encoder_output * keep.unsqueeze(-1)

        if reduction == "mean":
            # Clamp so a row with no kept positions does not divide by zero.
            count = keep.sum(dim=1).unsqueeze(-1).clamp(min=1)
            return encoder_output.sum(dim=1) / count
        if reduction == "max":
            # Exclude ignored positions via the mask itself, so that a genuine 0.0 activation
            # at a kept position still takes part in the maximum.
            encoder_output = encoder_output.masked_fill(keep.unsqueeze(-1) == 0, -float("inf"))
            return encoder_output.max(dim=1)[0]
        if reduction == "sum":
            return encoder_output.sum(dim=1)
        if reduction == "full":
            raise NotImplementedError("Full encoder output reduction is not yet implemented")
        raise ValueError(f"Invalid encoder output reduction: {reduction}")

    def _extract_predictions(self, sample: Integer[Peptide, " batch"]) -> list[list[str]]:
        """Convert sampled token indices into residue sequences.

        Each row is cut at its first EOS token (if any) and decoded with the residue set.

        Args:
            sample:
                Sampled token indices, one row per sequence.

        Returns:
            One decoded residue list per row of `sample`.
        """
        eos_index = self.residue_set.EOS_INDEX
        output = []
        # A single `tolist()` moves the whole batch to Python at once.
        for tokens in sample.tolist():
            if eos_index in tokens:
                tokens = tokens[: tokens.index(eos_index)]
            output.append(self.residue_set.decode(tokens, reverse=False))  # we do not reverse peptide for diffusion
        return output