Coverage for instanovo/utils/metrics.py: 97%

124 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

1from __future__ import annotations 

2 

3import bisect 

4 

5import jiwer 

6import numpy as np 

7 

8from instanovo.constants import CARBON_MASS_DELTA 

9from instanovo.utils.residues import ResidueSet 

10 

11 

class Metrics:
    """Peptide metrics class.

    Scores predicted peptides against ground truth: precursor-mass agreement,
    amino-acid error rate, and peptide/amino-acid level precision, recall,
    PR-AUC and recall at a fixed FDR.
    """

    def __init__(
        self,
        residue_set: ResidueSet,
        isotope_error_range: list[int] | int | None = None,
        cum_mass_threshold: float = 0.5,
        ind_mass_threshold: float = 0.1,
    ) -> None:
        """Initialise the metrics helper.

        Args:
            residue_set: Residue vocabulary used to tokenize peptide strings and
                look up residue / sequence masses.
            isotope_error_range: Inclusive ``[first, last]`` isotope offsets tried
                by :meth:`matches_precursor`. ``None`` means ``[0, 1]``; an int
                ``k`` means ``[0, k]``.
            cum_mass_threshold: Maximum prefix (cumulative) mass difference for two
                residues to be aligned in :meth:`_novor_match` (same units as
                ``residue_set.get_mass``).
            ind_mass_threshold: Maximum individual residue mass difference for an
                aligned pair to count as a matching amino acid.
        """
        self.residue_set = residue_set
        if isotope_error_range is None:
            self.isotope_error_range = [0, 1]
        elif isinstance(isotope_error_range, int):
            self.isotope_error_range = [0, isotope_error_range]
        else:
            # Copy so later mutation of the caller's list cannot affect this object.
            self.isotope_error_range = list(isotope_error_range)
        self.cum_mass_threshold = cum_mass_threshold
        self.ind_mass_threshold = ind_mass_threshold

    def matches_precursor(
        self,
        seq: str | list[str] | None,
        prec_mz: float,
        prec_charge: int,
        prec_tol: float = 50,
    ) -> tuple[bool, list[float]]:
        """Check if a sequence matches the precursor mass within some tolerance.

        Args:
            seq: Peptide as a string or list of residue tokens; ``None`` never matches.
            prec_mz: Measured precursor m/z.
            prec_charge: Precursor charge.
            prec_tol: Tolerance in ppm, compared against each isotope-corrected error.

        Returns:
            ``(matched, errors)``: ``errors`` holds the ppm error for every isotope
            offset in ``isotope_error_range`` (inclusive), and ``matched`` is True
            when any ``abs(error) < prec_tol``.
        """
        if seq is None:
            return False, []
        seq_mz = self._mass(seq, charge=prec_charge)
        delta_mass_ppm = [
            self._calc_mass_error(seq_mz, prec_mz, prec_charge, isotope)
            for isotope in range(
                self.isotope_error_range[0],
                self.isotope_error_range[1] + 1,
            )
        ]
        return any(abs(d) < prec_tol for d in delta_mass_ppm), delta_mass_ppm

    def compute_aa_er(
        self,
        peptides_truth: list[str] | list[list[str]],
        peptides_predicted: list[str] | list[list[str]],
    ) -> float:
        """Compute amino-acid level error-rate.

        Each peptide is tokenized, joined with spaces so every residue is one
        "word", and scored with ``jiwer.wer``. ``I`` is rewritten to ``L`` on both
        sides so the isobaric pair is not counted as an error.
        """
        # Ensure amino acids are separated
        peptides_truth = self._split_sequences(peptides_truth)
        peptides_predicted = self._split_sequences(peptides_predicted)

        return float(
            jiwer.wer(
                [" ".join(x).replace("I", "L") for x in peptides_truth],
                [" ".join(x).replace("I", "L") for x in peptides_predicted],
            )
        )

    # Adapted from https://github.com/Noble-Lab/casanovo/blob/main/casanovo/denovo/evaluate.py
    def compute_precision_recall(
        self,
        targets: list[str] | list[list[str]],
        predictions: list[str] | list[list[str]],
        confidence: list[float] | None = None,
        threshold: float | None = None,
    ) -> tuple[float, float, float, float]:
        """Calculate precision and recall at peptide- and AA-level.

        A prediction is counted only if its confidence is ``>= threshold`` and it
        is non-empty; NaN confidences never pass. If either ``confidence`` or
        ``threshold`` is ``None``, every non-empty prediction is counted.

        Args:
            targets: Target peptides.
            predictions: Model predicted peptides (indexed in step with ``targets``).
            confidence: Optional model confidence.
            threshold: Optional confidence threshold.

        Returns:
            ``(aa_precision, aa_recall, pep_recall, pep_precision)`` (note the order).
            Precisions are 1.0 when nothing is predicted; recalls are 0.0 when
            there is nothing to recall (empty input).
        """
        targets = self._split_sequences(targets)
        predictions = self._split_sequences(predictions)

        if confidence is None or threshold is None:
            threshold = 0
            confidence = np.ones(len(predictions))  # type: ignore

        n_targ_aa, n_pred_aa, n_match_aa = 0, 0, 0
        n_pred_pep, n_match_pep = 0, 0

        for i in range(len(targets)):
            # Sequences are already tokenized by _split_sequences above.
            targ = targets[i]
            pred = predictions[i]
            conf = confidence[i]  # type: ignore

            # Legacy for old regex, may be removed
            if len(pred) > 0 and pred[0] == "":
                pred = []

            n_targ_aa += len(targ)
            if conf >= threshold and len(pred) > 0:
                n_pred_aa += len(pred)
                n_pred_pep += 1

                n_match = self._novor_match(targ, pred)
                n_match_aa += n_match

                # A peptide matches only if every residue is matched and lengths agree.
                if len(pred) == len(targ) and len(targ) == n_match:
                    n_match_pep += 1

        # Guard the denominators so empty inputs do not raise ZeroDivisionError.
        pep_recall = n_match_pep / len(targets) if len(targets) > 0 else 0.0
        aa_recall = n_match_aa / n_targ_aa if n_targ_aa > 0 else 0.0

        if n_pred_pep == 0:
            pep_precision = 1.0
            aa_prec = 1.0
        else:
            pep_precision = n_match_pep / n_pred_pep
            # n_pred_aa > 0 whenever n_pred_pep > 0 (only non-empty preds counted).
            aa_prec = n_match_aa / n_pred_aa

        return aa_prec, aa_recall, pep_recall, pep_precision

    def calc_auc(
        self,
        targs: list[str] | list[list[str]],
        preds: list[str] | list[list[str]],
        conf: list[float],
    ) -> float:
        """Calculate the peptide-level AUC.

        Integrates the peptide precision-recall curve from :meth:`_get_pr_curve`
        with the trapezoidal rule (rectangle under the lower precision plus half
        the gap to the higher one).
        """
        x, y = self._get_pr_curve(targs, preds, conf)
        # The curve is sampled from low to high threshold; reverse so recall ascends.
        recall, precision = np.array(x)[::-1], np.array(y)[::-1]

        width = recall[1:] - recall[:-1]
        height = np.minimum(precision[1:], precision[:-1])
        top = np.maximum(precision[1:], precision[:-1])
        side = top - height
        return float((width * height).sum() + 0.5 * (side * width).sum())

    def find_recall_at_fdr(
        self,
        targs: list[str] | list[list[str]],
        preds: list[str] | list[list[str]],
        conf: list[float],
        fdr: float = 0.05,
    ) -> tuple[float, float]:
        """Get model recall and threshold for specified FDR.

        Predictions are ranked by descending confidence. The selected cut-off is
        the deepest rank whose running peptide precision is ``>= 1 - fdr``. This
        is the q-value convention, so a temporary dip in precision does not hide
        a deeper qualifying cut-off.

        Returns:
            ``(recall, confidence_threshold)`` at that cut-off, or ``(0.0, 1.0)``
            when no cut-off reaches the target precision.
        """
        conf_arr = np.array(conf)
        order = conf_arr.argsort()[::-1]
        matches = np.array(self._get_peptide_matches(targs, preds))
        matches = matches[order]
        conf_arr = conf_arr[order]

        csum = np.cumsum(matches)
        precision = csum / (np.arange(len(matches)) + 1)
        recall = csum / len(matches)

        target_precision = 1 - fdr
        # If precision never reaches 1 - fdr: recall = 0, threshold = 1
        if not np.any(precision >= target_precision):
            return 0.0, 1.0

        # Best precision achievable at or beyond each rank, read from the end so
        # that it is non-decreasing: this is what bisect needs (ascending order).
        best_from_end = np.maximum.accumulate(precision[::-1]).tolist()
        # Number of trailing ranks that can never reach the target precision.
        n_failing_tail = bisect.bisect_left(best_from_end, target_precision)
        idx = len(precision) - n_failing_tail - 1
        return float(recall[idx]), float(conf_arr[idx])

    def _split_sequences(self, seq: list[str] | list[list[str]]) -> list[list[str]]:
        """Tokenize every string peptide in ``seq``; token lists pass through."""
        return [self.residue_set.tokenize(x) if isinstance(x, str) else x for x in seq]

    def _split_peptide(self, peptide: str | list[str]) -> list[str]:
        """Tokenize a single string peptide; token lists pass through unchanged."""
        if not isinstance(peptide, str):
            return peptide
        return self.residue_set.tokenize(peptide)  # type: ignore

    def _get_pr_curve(
        self,
        targs: list[str] | list[list[str]],
        preds: list[str] | list[list[str]],
        conf: list[float],
        N: int = 20,  # noqa: N803
    ) -> tuple[list[float], list[float]]:
        """Sample the peptide precision-recall curve at ``N + 1`` thresholds.

        Thresholds are taken at evenly spaced quantiles of the non-NaN
        confidences, plus the maximum confidence. Accepts any 1-D sequence for
        ``conf`` (list, numpy array, pandas Series); indexing is positional.

        Returns:
            ``(recalls, precisions)`` ordered from the lowest to the highest
            threshold; both empty when no finite confidence is available.
        """
        # Positional float array: argsort indices are positions, and NaN is
        # detected with numpy instead of the pandas-only ``.isna()``.
        conf_arr = np.asarray(conf, dtype=float)
        t_idx = np.argsort(conf_arr)
        t_idx = t_idx[~np.isnan(conf_arr[t_idx])]
        if t_idx.size == 0:
            return [], []

        quantile_pos = (t_idx.shape[0] * np.arange(N) / N).astype(int)
        t_idx = list(t_idx[quantile_pos]) + [t_idx[-1]]

        conf_list = conf_arr.tolist()
        x, y = [], []
        for t in conf_arr[t_idx]:
            _, _, recall, precision = self.compute_precision_recall(targs, preds, conf_list, float(t))
            x.append(recall)
            y.append(precision)
        return x, y

    def _mass(self, seq: str | list[str], charge: int | None = None) -> float:
        """Calculate a peptide's mass or m/z (delegates to ``residue_set.get_sequence_mass``)."""
        seq = self._split_peptide(seq)
        return self.residue_set.get_sequence_mass(seq, charge)  # type: ignore

    def _calc_mass_error(self, mz_theoretical: float, mz_measured: float, charge: int, isotope: int = 0) -> float:
        """Calculate the mass error between theoretical and actual mz in ppm.

        The measured m/z is first shifted down by ``isotope * CARBON_MASS_DELTA /
        charge`` (presumably the 13C-12C spacing; confirm in instanovo.constants)
        to correct for an isotope peak being picked as the precursor.
        """
        return float((mz_theoretical - (mz_measured - isotope * CARBON_MASS_DELTA / charge)) / mz_measured * 10**6)

    # Adapted from https://github.com/Noble-Lab/casanovo/blob/main/casanovo/denovo/evaluate.py
    def _novor_match(
        self,
        a: list[str],
        b: list[str],
    ) -> int:
        """Number of AA matches with novor method.

        Walks both sequences in step with their prefix masses. Residues are aligned
        when the prefix masses agree within ``cum_mass_threshold``. An aligned pair
        counts as a match when the residue masses agree within
        ``ind_mass_threshold``. Otherwise the sequence with the smaller prefix mass
        advances.
        """
        n = 0

        mass_a: list[float] = [self.residue_set.get_mass(x) for x in a]
        mass_b: list[float] = [self.residue_set.get_mass(x) for x in b]
        cum_mass_a = np.cumsum(mass_a)
        cum_mass_b = np.cumsum(mass_b)

        i, j = 0, 0
        while i < len(a) and j < len(b):
            if abs(cum_mass_a[i] - cum_mass_b[j]) < self.cum_mass_threshold:
                n += int(abs(mass_a[i] - mass_b[j]) < self.ind_mass_threshold)
                i += 1
                j += 1
            elif cum_mass_b[j] > cum_mass_a[i]:
                i += 1
            else:
                j += 1
        return n

    def _get_peptide_matches(
        self,
        targets: list[str] | list[list[str]],
        predictions: list[str] | list[list[str]],
    ) -> list[bool]:
        """Return, per pair, whether the prediction fully matches the target peptide."""
        matches: list[bool] = []
        for i in range(len(targets)):
            targ = self._split_peptide(targets[i])
            pred = self._split_peptide(predictions[i])
            # Legacy for old regex: a leading empty token means an empty prediction.
            if len(pred) > 0 and pred[0] == "":
                pred = []
            n_match = self._novor_match(targ, pred)
            matches.append(len(pred) == len(targ) and len(targ) == n_match)
        return matches