# instanovo/common/predictor.py

from __future__ import annotations

import os
import sys
from abc import ABCMeta, abstractmethod
from collections import Counter, defaultdict
from datetime import timedelta
from typing import Any, Tuple

import numpy as np
import pandas as pd
import polars as pl
import torch
import torch.nn as nn
from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration, InitProcessGroupKwargs
from datasets import Dataset, Value
from datasets.utils.logging import disable_progress_bar
from dotenv import load_dotenv
from omegaconf import DictConfig, OmegaConf

from instanovo.__init__ import console, set_rank
from instanovo.common.dataset import DataProcessor
from instanovo.common.utils import Timer
from instanovo.constants import ANNOTATED_COLUMN, ANNOTATION_ERROR, PREDICTION_COLUMNS, MSColumns
from instanovo.inference import Decoder
from instanovo.utils.colorlogging import ColorLog
from instanovo.utils.data_handler import SpectrumDataFrame
from instanovo.utils.device_handler import validate_and_configure_device
from instanovo.utils.metrics import Metrics
from instanovo.utils.s3 import S3FileHandler

load_dotenv()

# Automatic rank logger
logger = ColorLog(console, __name__).logger


class AccelerateDeNovoPredictor(metaclass=ABCMeta):
    """Predictor class that uses the Accelerate library."""

    @property
    def s3(self) -> S3FileHandler:
        """Get the S3 file handler.

        Returns:
            S3FileHandler: The S3 file handler
        """
        return self._s3

    def __init__(
        self,
        config: DictConfig,
    ) -> None:
        self.config = config

        # Hide progress bar from HF datasets
        disable_progress_bar()

        self.targets: list | None = None
        self.output_path = self.config.get("output_path", None)
        self.pred_df: pd.DataFrame | None = None
        self.results_dict: dict | None = None

        self.prediction_tokenised_col = self.config.get("prediction_tokenised_col", "predictions_tokenised")
        self.prediction_col = self.config.get("prediction_col", "predictions")
        self.log_probs_col = self.config.get("log_probs_col", "log_probs")
        self.token_log_probs_col = self.config.get("token_log_probs_col", "token_log_probs")

        # Encoder output config
        self.save_encoder_outputs = config.get("save_encoder_outputs", False)
        self.encoder_output_path = config.get("encoder_output_path", None)
        self.encoder_output_reduction = config.get("encoder_output_reduction", "mean")
        if self.save_encoder_outputs and self.encoder_output_path is None:
            raise ValueError(
                "Expected 'encoder_output_path' but found None. "
                "Please specify it in the config file or with the CLI flag encoder-output-path=path/to/encoder_outputs.parquet",
            )
        if self.save_encoder_outputs and self.encoder_output_reduction not in ["mean", "max", "sum", "full"]:
            raise ValueError(
                f"Invalid encoder output reduction: {self.encoder_output_reduction}. Please choose from 'mean', 'max', 'sum', or 'full'."
            )
        if self.encoder_output_reduction == "full":
            raise NotImplementedError("Full encoder output reduction is not yet implemented.")

        self.accelerator = self.setup_accelerator()

        if self.accelerator.is_main_process:
            logger.info(f"Config:\n{OmegaConf.to_yaml(self.config)}")

        # Whether to check metrics
        self.denovo = self.config.get("denovo", False)
        self._group_output: dict[str, str] | None = None
        self._group_mapping: dict[str, str] | None = None

        self._s3: S3FileHandler = S3FileHandler()

        logger.info("Loading model...")
        self.model, self.model_config = self.load_model()

        self.model = self.model.to(self.accelerator.device)
        self.model = self.model.eval()
        logger.info("Model loaded.")

        self.residue_set = self.model.residue_set
        self.residue_set.update_remapping(self.config.get("residue_remapping", {}))
        logger.info(f"Vocab: {self.residue_set.index_to_residue}")

        logger.info("Loading dataset...")
        self.test_dataset = self.load_dataset()

        logger.info(f"Data loaded: {len(self.test_dataset):,} test samples.")

        self.test_dataloader = self.build_dataloader(self.test_dataset)
        logger.info("Data loader built.")

        # Print sample batch
        self.print_sample_batch()

        logger.info("Initializing decoder.")
        self.decoder = self.setup_decoder()

        logger.info("Initializing metrics.")
        self.metrics = self.setup_metrics()

        # Prepare accelerator
        self.model, self.test_dataloader = self.accelerator.prepare(self.model, self.test_dataloader)

        self.running_loss = None
        self.steps_per_inference = len(self.test_dataloader)

        logger.info(f"Total batches: {self.steps_per_inference:,d}")

        # Final sync after setup
        self.accelerator.wait_for_everyone()

    @abstractmethod
    def load_model(self) -> Tuple[nn.Module, DictConfig]:
        """Load the model."""
        ...

    @abstractmethod
    def setup_decoder(self) -> Decoder:
        """Set up the decoder."""
        ...

    @abstractmethod
    def setup_data_processor(self) -> DataProcessor:
        """Set up the data processor."""
        ...

    @abstractmethod
    def get_predictions(self, batch: Any) -> dict[str, Any]:
        """Get the predictions for a batch."""
        ...
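    # A concrete predictor implements the four abstract hooks above. A minimal,
    # hypothetical skeleton (MyModel, MyDecoder, and MyDataProcessor are
    # placeholders, not part of this module):
    #
    #     class MyPredictor(AccelerateDeNovoPredictor):
    #         def load_model(self) -> Tuple[nn.Module, DictConfig]:
    #             model = MyModel.load(self.config["model_path"])
    #             return model, model.config
    #
    #         def setup_decoder(self) -> Decoder:
    #             return MyDecoder(self.model)
    #
    #         def setup_data_processor(self) -> DataProcessor:
    #             return MyDataProcessor(self.residue_set)
    #
    #         def get_predictions(self, batch: Any) -> dict[str, Any]:
    #             # Must return at least the keys in PREDICTION_COLUMNS; may also
    #             # include "encoder_output" when save_encoder_outputs is enabled.
    #             ...
    #
    #     pred_df = MyPredictor(config).predict()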

    def postprocess_dataset(self, dataset: Dataset) -> Dataset:
        """Postprocess the dataset."""
        return dataset

    def load_dataset(self) -> Dataset:
        """Load the test dataset.

        Returns:
            Dataset:
                The test dataset
        """
        data_path = self.config.get("data_path", None)
        if OmegaConf.is_list(data_path):
            # If validation data is a list, we assume the data is grouped
            # Each list item should include a result_name, input_path, and output_path
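            # e.g. a hypothetical config shape:
            #   data_path:
            #     - result_name: pool_a
            #       input_path: raw/pool_a/*.mgf
            #       output_path: results/pool_a_predictions.csv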

            _new_data_paths = []
            self._group_mapping = {}  # map file paths to group name
            self._group_output = {}  # map group name to output path (for saving predictions)

            for group in data_path:
                path = group.get("input_path")
                name = group.get("result_name")
                for fp in SpectrumDataFrame._convert_file_paths(path):  # e.g. expands list of globs
                    self._group_mapping[fp] = name
                _new_data_paths.append(path)
                self._group_output[name] = group.get("output_path")
            self.group_data_paths = data_path
            data_path = _new_data_paths

        logger.info(f"Loading data from {data_path}")
        try:
            sdf = SpectrumDataFrame.load(
                data_path,
                lazy=False,
                is_annotated=not self.denovo,
                column_mapping=self.config.get("column_map", None),
                shuffle=False,
                add_spectrum_id=True,
                add_source_file_column=True,
            )
        except ValueError as e:
            # More descriptive error message in predict mode.
            if str(e) == ANNOTATION_ERROR:
                raise ValueError(
                    "The sequence column is missing annotations. If you are trying to run de novo prediction, add the `denovo=True` flag."
                ) from e
            else:
                raise

        dataset = sdf.to_dataset(in_memory=True)

        subset = self.config.get("subset", 1.0)
        if not 0 < subset <= 1:
            raise ValueError(
                f"Invalid subset value: {subset}. Must be a float greater than 0 and less than or equal to 1."  # noqa: E501
            )

        original_size = len(dataset)
        max_charge = self.config.get("max_charge", 10)
        model_max_charge = self.model_config.get("max_charge", 10)
        if max_charge > model_max_charge:
            logger.warning(f"Inference has been configured with max_charge={max_charge}, but model has max_charge={model_max_charge}.")
            logger.warning(f"Overriding max_charge config with model value: {model_max_charge}.")
            max_charge = model_max_charge

        precursor_charge_col = MSColumns.PRECURSOR_CHARGE.value
        dataset = dataset.filter(lambda row: (row[precursor_charge_col] <= max_charge) and (row[precursor_charge_col] > 0))

        # Warn directly after the charge filter so the count is not conflated
        # with rows dropped by the residue filter below.
        if len(dataset) < original_size:
            logger.warning(
                f"Found {original_size - len(dataset):,d} rows with charge > {max_charge} or <= 0. "
                "This could mean the charge column is missing or contains invalid values. "
                "These rows will be skipped."
            )

        # Filter invalid sequences
        if not self.denovo:
            supported_residues = set(self.residue_set.vocab)
            supported_residues.update(set(self.residue_set.residue_remapping.keys()))
            data_residues = sdf.get_vocabulary(self.residue_set.tokenize)
            if len(data_residues - supported_residues) > 0:
                logger.warning(
                    f"Found {len(data_residues - supported_residues):,d} unsupported residues! "
                    "These rows will be dropped in evaluation mode. Please adjust the metrics "
                    "calculations accordingly."
                )
                logger.warning(f"New residues found: \n{data_residues - supported_residues}")
                logger.warning(f"Residues supported: \n{supported_residues}")
                logger.warning("Please check residue remapping if a different convention has been used.")
                dataset = dataset.filter(
                    lambda row: all(residue in supported_residues for residue in set(self.residue_set.tokenize(row[ANNOTATED_COLUMN])))
                )

        if subset < 1.0:
            dataset = dataset.train_test_split(test_size=subset, seed=42)["test"]

        if len(dataset) == 0:
            logger.warning("No data found, exiting.")
            sys.exit()

        # Optional dataset postprocessing
        dataset = self.postprocess_dataset(dataset)

        # Used to group validation outputs
        if self._group_mapping is not None:
            logger.info("Computing groups.")
            groups = [self._group_mapping.get(row.get("source_file"), "no_group") for row in dataset]
            dataset = dataset.add_column("group", groups, feature=Value("string"))

            if self.accelerator.is_main_process:
                logger.info("Sequences per group:")
                group_counts = Counter(groups)
                for group, count in group_counts.items():
                    logger.info(f" - {group}: {count:,d}")

            self.using_groups = True
        else:
            dataset = dataset.add_column("group", ["no_group"] * len(dataset), feature=Value("string"))
            self.using_groups = False

        # Force add a unique prediction_id column
        # This will be used to order predictions
        dataset = dataset.add_column("prediction_id", np.arange(len(dataset)), feature=Value("int32"))

        return dataset

    def print_sample_batch(self) -> None:
        """Print a sample batch of the test data."""
        if self.accelerator.is_main_process:
            sample_batch = next(iter(self.test_dataloader))
            logger.info("Sample batch:")
            for key, value in sample_batch.items():
                if isinstance(value, torch.Tensor):
                    value_shape = value.shape
                    value_type = value.dtype
                else:
                    value_shape = len(value)
                    value_type = type(value)

                logger.info(f" - {key}: {value_type}, {value_shape}")

    def setup_metrics(self) -> Metrics:
        """Set up the metrics."""
        return Metrics(self.residue_set, self.config.get("max_isotope_error", 1))

    def setup_accelerator(self) -> Accelerator:
        """Set up the accelerator."""
        # TODO: How do we specify device without accelerator?
        timeout = timedelta(seconds=self.config.get("timeout", 3600))
        validate_and_configure_device(self.config)
        accelerator = Accelerator(
            cpu=self.config.get("force_cpu", False),
            mixed_precision="fp16" if torch.cuda.is_available() and not self.config.get("force_cpu", False) else "no",
            dataloader_config=DataLoaderConfiguration(split_batches=False),
            kwargs_handlers=[InitProcessGroupKwargs(timeout=timeout)],
        )
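        # Note: mixed precision above enables fp16 autocast only when CUDA is
        # available and force_cpu is not set; CPU and MPS runs use full precision.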

        device = accelerator.device  # Important, this forces ranks to choose a device.

        if accelerator.num_processes > 1:
            set_rank(accelerator.local_process_index)

        if accelerator.is_main_process:
            logger.info(f"Python version: {sys.version}")
            logger.info(f"Torch version: {torch.__version__}")
            logger.info(f"CUDA version: {torch.version.cuda}")
            logger.info(f"Predicting with {accelerator.num_processes} devices")
            logger.info(f"Per-device batch size: {self.config['batch_size']}")

        logger.info(f"Using device: {device}")

        return accelerator

    def build_dataloader(self, test_dataset: Dataset) -> torch.utils.data.DataLoader:
        """Set up the test dataloader."""
        test_processor = self.setup_data_processor()
        test_processor.add_metadata_columns(["prediction_id", "group"])

        test_dataset = test_processor.process_dataset(test_dataset)

        pin_memory = self.config.get("pin_memory", False)
        if self.accelerator.device == torch.device("cpu") or self.config.get("mps", False):
            pin_memory = False

        test_dataloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=self.config["batch_size"],
            collate_fn=test_processor.collate_fn,
            num_workers=self.config.get("num_workers", 8),
            pin_memory=pin_memory,
            prefetch_factor=self.config.get("prefetch_factor", 2),
            drop_last=False,
        )
        return test_dataloader

    def predict(self) -> pd.DataFrame | None:
        """Predict the test dataset.

        Returns:
            pd.DataFrame | None: The predictions dataframe on the main process, None on other ranks.
        """
        all_predictions: dict[str, list] = defaultdict(list)
        all_encoder_outputs: list[np.ndarray] = []
        test_step = 0

        logger.info("Predicting...")
        inference_timer = Timer(self.steps_per_inference)
        print_batch_size = True
        for i, batch in enumerate(self.test_dataloader):
            if print_batch_size:
                logger.info(f"Batch {i} size: {batch['spectra'].shape[0]}")
                print_batch_size = False

            # Implementation specific
            with torch.no_grad(), self.accelerator.autocast():
                batch_predictions = self.get_predictions(batch)

            # Pass through prediction_id and group columns
            # prediction_id is automatically cast to tensor
            batch_predictions["prediction_id"] = [x.item() for x in batch["prediction_id"]]
            batch_predictions["group"] = batch["group"]

            # Some outputs are required from get_predictions
            for k in PREDICTION_COLUMNS:
                if k not in batch_predictions:
                    raise ValueError(f"Prediction column {k} not found in batch predictions.")
                all_predictions[k].extend(batch_predictions[k])

            if "encoder_output" in batch_predictions:
                # Always ensure it is removed with pop even if we do not save it.
                encoder_output = batch_predictions.pop("encoder_output")
                if self.save_encoder_outputs:
                    all_encoder_outputs.extend(encoder_output)

            # Additional prediction info
            if self.config.get("save_all_predictions", False):
                for k, v in batch_predictions.items():
                    if k in PREDICTION_COLUMNS:
                        continue
                    all_predictions[k].extend(v)

            test_step += 1
            inference_timer.step()

            if (i + 1) % self.config.get("log_interval", 50) == 0 or (i + 1) == self.steps_per_inference:
                logger.info(
                    f"[Batch {i + 1:05d}/{self.steps_per_inference:05d}] "
                    f"[{inference_timer.get_time_str()}/{inference_timer.get_eta_str()}] "  # noqa: E501
                    f"{inference_timer.get_step_time_rate_str()}"
                )

        logger.info(f"Time taken for {self.config.get('data_path', None)} is {inference_timer.get_delta():.1f} seconds")

        logger.info("Prediction complete.")
        self.accelerator.wait_for_everyone()

        if self.accelerator.num_processes > 1:
            logger.info("Gathering predictions from all processes...")

            # Gather predictions from all processes so every rank holds the full set
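            # (object gathering may pad or duplicate trailing samples across
            # ranks; predictions_to_df drops the duplicates by prediction_id)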

            for key, value in all_predictions.items():
                all_predictions[key] = self.accelerator.gather_for_metrics(value, use_gather_object=True)

            if self.save_encoder_outputs:
                all_encoder_outputs = self.accelerator.gather_for_metrics(all_encoder_outputs, use_gather_object=True)

        if self.accelerator.is_main_process:
            pred_df = self.predictions_to_df(all_predictions)
            pred_df = self.postprocess_predictions(pred_df)

            results_dict = None
            if not self.denovo:
                logger.info("Calculating metrics...")

                results_dict = self.calculate_metrics(pred_df)

            self.save_predictions(pred_df, results_dict)

            if self.save_encoder_outputs:
                self.save_encoder_outputs_to_parquet(pred_df["spectrum_id"].tolist(), all_encoder_outputs)
        else:
            pred_df = None

        return pred_df

    def _tokens_to_string(self, tokens: list[str] | None) -> str:
        """Convert a list of tokens to a ProForma-compliant string."""

        if tokens is None:
            return ""
        peptide = ""
        if len(tokens) > 1 and not tokens[0][0].isalpha():
            # Assume n-terminal
            peptide = tokens[0] + "-"
            tokens = tokens[1:]
        return peptide + "".join(tokens)

    def predictions_to_df(self, predictions: dict[str, list]) -> pd.DataFrame:
        """Convert the predictions to a pandas DataFrame.

        Args:
            predictions: The predictions dictionary

        Returns:
            pd.DataFrame: The predictions dataframe
        """
        index_cols = self.config.get("index_columns", ["precursor_mz", "precursor_charge"])
        index_cols = [x for x in index_cols if x in self.test_dataset.column_names]
        index_cols.append("prediction_id")
        index_df = self.test_dataset.to_pandas()[index_cols]

        pred_df = pd.DataFrame(predictions)
        # Drop duplicates caused by padding multiple processes
        pred_df = pred_df.drop_duplicates(subset=["prediction_id"], keep="first")

        pred_df = index_df.merge(pred_df, on="prediction_id", how="left")

        # Some column processing
        pred_df["predictions_tokenised"] = pred_df["predictions"].map(lambda x: ", ".join(x))
        pred_df["predictions"] = pred_df["predictions"].map(self._tokens_to_string)
        pred_df["targets"] = pred_df["targets"].map(self._tokens_to_string)

475 pred_df["delta_mass_ppm"] = pred_df.apply( 

476 lambda row: np.min(np.abs(self.metrics.matches_precursor(row["predictions_tokenised"], row["precursor_mz"], row["precursor_charge"])[1])), 

477 axis=1, 

478 ) 

479 

480 pred_df = pred_df.rename( 

481 columns={ 

482 "predictions": self.prediction_col, 

483 "predictions_tokenised": self.prediction_tokenised_col, 

484 "prediction_log_probability": self.log_probs_col, 

485 "prediction_token_log_probabilities": self.token_log_probs_col, 

486 } 

487 ) 

488 

489 if self.denovo: 

490 pred_df.drop(columns=["targets"], inplace=True) 

491 

492 return pred_df 

493 

494 def postprocess_predictions(self, pred_df: pd.DataFrame) -> pd.DataFrame: 

495 """Postprocess the predictions. 

496 

497 Optionally, this can be used to modify the predictions, eg. ensembling. 

498 By default, this does nothing. 

499 

500 Args: 

501 pred_df: The predictions dataframe 

502 

503 Returns: 

504 pd.DataFrame: The postprocessed predictions dataframe 

505 """ 

506 return pred_df 

507 

508 def calculate_metrics( 

509 self, 

510 pred_df: pd.DataFrame, 

511 ) -> dict[str, Any] | None: 

512 """Calculate the metrics. 

513 

514 Args: 

515 pred_df: The predictions dataframe 

516 

517 Returns: 

518 dict[str, Any] | None: The results dictionary containing the metrics 

519 """ 

520 predictions = pred_df[self.prediction_tokenised_col].copy() 

521 targets = pred_df["targets"] 

522 log_probs = pred_df[self.log_probs_col] 

523 delta_mass_ppm = pred_df["delta_mass_ppm"] 

524 

525 aa_prec, aa_recall, pep_recall, pep_prec = self.metrics.compute_precision_recall(targets, predictions) 

526 aa_er = self.metrics.compute_aa_er(targets, predictions) 

527 auc = self.metrics.calc_auc( 

528 targets, 

529 predictions, 

530 np.exp(log_probs), 

531 ) 

532 

533 logger.info("Performance:") 

534 logger.info(f" aa_er {aa_er:.5f}") 

535 logger.info(f" aa_prec {aa_prec:.5f}") 

536 logger.info(f" aa_recall {aa_recall:.5f}") 

537 logger.info(f" pep_prec {pep_prec:.5f}") 

538 logger.info(f" pep_recall {pep_recall:.5f}") 

539 logger.info(f" auc {auc:.5f}") 

540 

541 fdr = self.config.get("filter_fdr_threshold", None) 

542 if fdr: 

543 _, threshold = self.metrics.find_recall_at_fdr( 

544 targets, 

545 predictions, 

546 np.exp(log_probs), 

547 fdr=fdr, 

548 ) 

549 aa_prec, aa_recall, pep_recall, pep_prec = self.metrics.compute_precision_recall( 

550 targets, 

551 predictions, 

552 np.exp(log_probs), 

553 threshold=threshold, 

554 ) 

555 logger.info(f"Performance at {fdr * 100:.1f}% FDR:") 

556 logger.info(f" aa_prec {aa_prec:.5f}") 

557 logger.info(f" aa_recall {aa_recall:.5f}") 

558 logger.info(f" pep_prec {pep_prec:.5f}") 

559 logger.info(f" pep_recall {pep_recall:.5f}") 

560 logger.info(f" confidence {threshold:.5f}") 

561 

562 filter_precursor_ppm = self.config.get("filter_precursor_ppm", None) 

563 if filter_precursor_ppm and delta_mass_ppm is not None: 

564 idx = delta_mass_ppm < filter_precursor_ppm # type: ignore 

565 logger.info(f"Performance with filtering at {filter_precursor_ppm} ppm delta mass:") 

566 if np.sum(idx) > 0: 

567 filtered_preds = pd.Series(predictions) 

568 filtered_preds[~idx] = "" 

569 aa_prec, aa_recall, pep_recall, pep_prec = self.metrics.compute_precision_recall(targets, filtered_preds) 

570 logger.info(f" aa_prec {aa_prec:.5f}") 

571 logger.info(f" aa_recall {aa_recall:.5f}") 

572 logger.info(f" pep_prec {pep_prec:.5f}") 

573 logger.info(f" pep_recall {pep_recall:.5f}") 

574 logger.info(f"Rows filtered: {len(predictions) - np.sum(idx)} ({(len(predictions) - np.sum(idx)) / len(predictions) * 100:.2f}%)") 

575 if np.sum(idx) < 1000: 

576 logger.info(f"Metrics calculated on a small number of samples ({np.sum(idx)}), interpret with care!") 

577 else: 

578 logger.info("No predictions met criteria, skipping metrics.") 

579 

580 model_confidence_no_pred = self.config.get("filter_confidence", None) 

581 if model_confidence_no_pred: 

582 idx = np.exp(log_probs) > model_confidence_no_pred 

583 logger.info(f"Performance with filtering confidence < {model_confidence_no_pred}") 

584 if np.sum(idx) > 0: 

585 filtered_preds = pd.Series(predictions) 

586 filtered_preds[~idx] = "" 

587 aa_prec, aa_recall, pep_recall, pep_prec = self.metrics.compute_precision_recall(targets, filtered_preds) 

588 logger.info(f" aa_prec {aa_prec:.5f}") 

589 logger.info(f" aa_recall {aa_recall:.5f}") 

590 logger.info(f" pep_prec {pep_prec:.5f}") 

591 logger.info(f" pep_recall {pep_recall:.5f}") 

592 logger.info(f"Rows filtered: {len(predictions) - np.sum(idx)} ({(len(predictions) - np.sum(idx)) / len(predictions) * 100:.2f}%)") 

593 if np.sum(idx) < 1000: 

594 logger.info(f"Metrics calculated on a small number of samples ({np.sum(idx)}), interpret with care!") 

595 else: 

596 logger.info("No predictions met criteria, skipping metrics.") 

597 

598 # Evaluate individual result files 

599 if self.using_groups and not self.denovo: 

600 logger.info("Evaluating individual result files.") 

601 # TODO Handle better with pred_df 

602 _preds = pd.Series(predictions) 

603 _targs = pd.Series(targets) 

604 _probs = pd.Series(log_probs) 

605 

606 # TODO Make this more generic 

607 results_dict: dict[str, Any] = { 

608 "run_name": self.config.get("run_name"), 

609 "instanovo_model": self.config.get("instanovo_model"), 

610 "num_beams": self.config.get("num_beams", 1), 

611 "use_knapsack": self.config.get("use_knapsack", False), 

612 } 

613 for group in pred_df["group"].unique(): 

614 if group == "no_group": 

615 continue 

616 idx = pred_df["group"] == group 

617 _group_preds = _preds[idx].reset_index(drop=True) 

618 _group_targs = _targs[idx].reset_index(drop=True) 

619 _group_probs = _probs[idx].reset_index(drop=True) 

620 aa_prec, aa_recall, pep_recall, pep_prec = self.metrics.compute_precision_recall(_group_targs, _group_preds) 

621 aa_er = self.metrics.compute_aa_er(_group_targs, _group_preds) 

622 auc = self.metrics.calc_auc(_group_targs, _group_preds, _group_probs) 

623 

624 results_dict.update( 

625 { 

626 f"{group}_aa_prec": [aa_prec], 

627 f"{group}_aa_recall": [aa_recall], 

628 f"{group}_pep_recall": [pep_recall], 

629 f"{group}_pep_prec": [pep_prec], 

630 f"{group}_aa_er": [aa_er], 

631 f"{group}_auc": [auc], 

632 } 

633 ) 

634 

635 fdr = self.config.get("filter_fdr_threshold", None) 

636 if fdr: 

637 _, threshold = self.metrics.find_recall_at_fdr(_group_targs, _group_preds, np.exp(_group_probs), fdr=fdr) 

638 _, _, pep_recall_at_fdr, _ = self.metrics.compute_precision_recall( 

639 _group_targs, 

640 _group_preds, 

641 np.exp(_group_probs), 

642 threshold=threshold, 

643 ) 

644 

645 results_dict.update( 

646 { 

647 f"{group}_pep_recall_at_{fdr:.3f}_fdr": [pep_recall_at_fdr], 

648 } 

649 ) 

650 return results_dict 

651 return None 

652 

653 def save_predictions(self, pred_df: pd.DataFrame, results_dict: dict[str, list] | None = None) -> None: 

654 """Save the predictions to a file. 

655 

656 Args: 

657 pred_df: The predictions dataframe 

658 results_dict: The results dictionary containing the metrics 

659 """ 

660 # Save metrics to a file 

661 if self.using_groups and not self.denovo and results_dict is not None: 

662 result_path = self.config.get("result_file_path") 

663 logger.info(f"Saving metrics to {result_path}.") 

664 local_path = self.s3.get_local_path(result_path, missing_ok=True) 

665 if local_path is not None and os.path.exists(local_path): 

666 results_df = pd.read_csv(local_path) 

667 results_df = pd.concat([results_df, pd.DataFrame(results_dict)], ignore_index=True, join="outer") 

668 else: 

669 results_df = pd.DataFrame(results_dict) 

670 

671 self.s3.upload_to_s3_wrapper(results_df.to_csv, result_path, index=False) 

672 

673 # Save individual result files per group 

674 if self.using_groups and self._group_output is not None and pred_df is not None: 

675 logger.info("Saving individual result files per group.") 

676 for group in pred_df["group"].unique(): 

677 idx = pred_df["group"] == group 

678 if self._group_output.get(group) is not None: 

679 self.s3.upload_to_s3_wrapper(pred_df[idx].to_csv, self._group_output[group], index=False) 

680 

681 # Save output 

682 if self.output_path is not None and pred_df is not None: 

683 logger.info(f"Saving predictions to {self.output_path}.") 

684 self.s3.upload_to_s3_wrapper(pred_df.to_csv, self.output_path, index=False) 

685 logger.info(f"Predictions saved to {self.output_path}") 

686 

687 # Upload to Aichor 

688 if S3FileHandler._aichor_enabled() and not self.output_path.startswith("s3://"): 

689 self.s3.upload(self.output_path, S3FileHandler.convert_to_s3_output(self.output_path)) 

690 

691 def save_encoder_outputs_to_parquet(self, spectrum_ids: list[str], encoder_outputs: list[np.ndarray]) -> None: 

692 """Save the encoder outputs to a file. 

693 

694 Args: 

695 encoder_outputs: The encoder outputs 

696 spectrum_ids: The spectrum ids 

697 """ 

698 if len(encoder_outputs) == 0: 

699 logger.warning(f"No encoder outputs were returned by the decoder {type(self.decoder)}.") 

700 logger.warning("Skipping encoder output saving.") 

701 return 

702 

703 if self.encoder_output_path is not None: 

704 encoder_outputs_fp32 = np.stack(encoder_outputs).astype(np.float32) 

705 encoder_output_df = pl.DataFrame( 

706 {"spectrum_id": spectrum_ids, **{f"spectrum_encoding_{i}": encoder_outputs_fp32[:, i] for i in range(encoder_outputs_fp32.shape[1])}} 

707 ) 

708 

709 logger.info(f"Saving encoder outputs to {self.encoder_output_path}.") 

710 self.s3.upload_to_s3_wrapper(encoder_output_df.write_parquet, self.encoder_output_path)
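
# A saved encoder-output parquet can be read back with polars, e.g. (path is
# hypothetical):
#
#     import polars as pl
#
#     df = pl.read_parquet("path/to/encoder_outputs.parquet")
#     # columns: spectrum_id, spectrum_encoding_0 ... spectrum_encoding_{d-1}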