Coverage for instanovo/common/predictor.py: 89% (393 statements)
from __future__ import annotations

import os
import sys
from abc import ABCMeta, abstractmethod
from collections import Counter, defaultdict
from datetime import timedelta
from typing import Any, Tuple

import numpy as np
import pandas as pd
import polars as pl
import torch
import torch.nn as nn
from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration, InitProcessGroupKwargs
from datasets import Dataset, Value
from datasets.utils.logging import disable_progress_bar
from dotenv import load_dotenv
from omegaconf import DictConfig, OmegaConf

from instanovo.__init__ import console, set_rank
from instanovo.common.dataset import DataProcessor
from instanovo.common.utils import Timer
from instanovo.constants import ANNOTATED_COLUMN, ANNOTATION_ERROR, PREDICTION_COLUMNS, MSColumns
from instanovo.inference import Decoder
from instanovo.utils.colorlogging import ColorLog
from instanovo.utils.data_handler import SpectrumDataFrame
from instanovo.utils.device_handler import validate_and_configure_device
from instanovo.utils.metrics import Metrics
from instanovo.utils.s3 import S3FileHandler

load_dotenv()

# Automatic rank logger
logger = ColorLog(console, __name__).logger


class AccelerateDeNovoPredictor(metaclass=ABCMeta):
    """Predictor class that uses the Accelerate library."""

    @property
    def s3(self) -> S3FileHandler:
        """Get the S3 file handler.

        Returns:
            S3FileHandler: The S3 file handler
        """
        return self._s3

    def __init__(
        self,
        config: DictConfig,
    ) -> None:
        self.config = config

        # Hide the progress bar from HF datasets
        disable_progress_bar()

        self.targets: list | None = None
        self.output_path = self.config.get("output_path", None)
        self.pred_df: pd.DataFrame | None = None
        self.results_dict: dict | None = None

        self.prediction_tokenised_col = self.config.get("prediction_tokenised_col", "predictions_tokenised")
        self.prediction_col = self.config.get("prediction_col", "predictions")
        self.log_probs_col = self.config.get("log_probs_col", "log_probs")
        self.token_log_probs_col = self.config.get("token_log_probs_col", "token_log_probs")

        # Encoder output config
        self.save_encoder_outputs = config.get("save_encoder_outputs", False)
        self.encoder_output_path = config.get("encoder_output_path", None)
        self.encoder_output_reduction = config.get("encoder_output_reduction", "mean")
        if self.save_encoder_outputs and self.encoder_output_path is None:
            raise ValueError(
                "Expected 'encoder_output_path' but found None. "
                "Please specify it in the config file or with the CLI flag encoder-output-path=path/to/encoder_outputs.parquet",
            )
        if self.save_encoder_outputs and self.encoder_output_reduction not in ["mean", "max", "sum", "full"]:
            raise ValueError(
                f"Invalid encoder output reduction: {self.encoder_output_reduction}. Please choose from 'mean', 'max', 'sum', or 'full'."
            )
        if self.encoder_output_reduction == "full":
            raise NotImplementedError("Full encoder output reduction is not yet implemented.")

        self.accelerator = self.setup_accelerator()

        if self.accelerator.is_main_process:
            logger.info(f"Config:\n{OmegaConf.to_yaml(self.config)}")

        # Whether to check metrics
        self.denovo = self.config.get("denovo", False)
        self._group_output: dict[str, str] | None = None
        self._group_mapping: dict[str, str] | None = None

        self._s3: S3FileHandler = S3FileHandler()

        logger.info("Loading model...")
        self.model, self.model_config = self.load_model()

        self.model = self.model.to(self.accelerator.device)
        self.model = self.model.eval()
        logger.info("Model loaded.")

        self.residue_set = self.model.residue_set
        self.residue_set.update_remapping(self.config.get("residue_remapping", {}))
        logger.info(f"Vocab: {self.residue_set.index_to_residue}")

        logger.info("Loading dataset...")
        self.test_dataset = self.load_dataset()

        logger.info(f"Data loaded: {len(self.test_dataset):,} test samples.")

        self.test_dataloader = self.build_dataloader(self.test_dataset)
        logger.info("Data loader built.")

        # Print a sample batch
        self.print_sample_batch()

        logger.info("Initializing decoder.")
        self.decoder = self.setup_decoder()

        logger.info("Initializing metrics.")
        self.metrics = self.setup_metrics()

        # Prepare accelerator
        self.model, self.test_dataloader = self.accelerator.prepare(self.model, self.test_dataloader)

        self.running_loss = None
        self.steps_per_inference = len(self.test_dataloader)

        logger.info(f"Total batches: {self.steps_per_inference:,d}")

        # Final sync after setup
        self.accelerator.wait_for_everyone()

    @abstractmethod
    def load_model(self) -> Tuple[nn.Module, DictConfig]:
        """Load the model."""
        ...

    @abstractmethod
    def setup_decoder(self) -> Decoder:
        """Set up the decoder."""
        ...

    @abstractmethod
    def setup_data_processor(self) -> DataProcessor:
        """Set up the data processor."""
        ...

    @abstractmethod
    def get_predictions(self, batch: Any) -> dict[str, Any]:
        """Get the predictions for a batch."""
        ...
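
    # A minimal sketch of a concrete subclass (hypothetical names, kept as a
    # comment so it stays illustrative rather than part of this module):
    #
    #     class MyPredictor(AccelerateDeNovoPredictor):
    #         def load_model(self) -> Tuple[nn.Module, DictConfig]:
    #             model, model_config = ...  # e.g. restore from a checkpoint
    #             return model, model_config
    #
    #         def setup_decoder(self) -> Decoder:
    #             return ...  # e.g. a beam-search decoder wrapping self.model
    #
    #         def setup_data_processor(self) -> DataProcessor:
    #             return ...  # spectrum preprocessing and collation
    #
    #         def get_predictions(self, batch: Any) -> dict[str, Any]:
    #             # Must return every key listed in PREDICTION_COLUMNS;
    #             # `predict` raises a ValueError for any missing key.
    #             return {...}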
    def postprocess_dataset(self, dataset: Dataset) -> Dataset:
        """Postprocess the dataset."""
        return dataset

    def load_dataset(self) -> Dataset:
        """Load the test dataset.

        Returns:
            Dataset:
                The test dataset
        """
        data_path = self.config.get("data_path", None)
        if OmegaConf.is_list(data_path):
            # If validation data is a list, we assume the data is grouped.
            # Each list item should include a result_name, input_path, and output_path.

            _new_data_paths = []
            self._group_mapping = {}  # maps file paths to group name
            self._group_output = {}  # maps group name to output path (for saving predictions)

            for group in data_path:
                path = group.get("input_path")
                name = group.get("result_name")
                for fp in SpectrumDataFrame._convert_file_paths(path):  # e.g. expands a list of globs
                    self._group_mapping[fp] = name
                _new_data_paths.append(path)
                self._group_output[name] = group.get("output_path")
            self.group_data_paths = data_path
            data_path = _new_data_paths

        logger.info(f"Loading data from {data_path}")
        try:
            sdf = SpectrumDataFrame.load(
                data_path,
                lazy=False,
                is_annotated=not self.denovo,
                column_mapping=self.config.get("column_map", None),
                shuffle=False,
                add_spectrum_id=True,
                add_source_file_column=True,
            )
        except ValueError as e:
            # More descriptive error message in predict mode.
            if str(e) == ANNOTATION_ERROR:
                raise ValueError(
                    "The sequence column is missing annotations. Are you trying to run de novo prediction? Add the `denovo=True` flag."
                ) from e
            else:
                raise

        dataset = sdf.to_dataset(in_memory=True)

        subset = self.config.get("subset", 1.0)
        if not 0 < subset <= 1:
            raise ValueError(
                f"Invalid subset value: {subset}. Must be a float greater than 0 and less than or equal to 1."  # noqa: E501
            )

        original_size = len(dataset)
        max_charge = self.config.get("max_charge", 10)
        model_max_charge = self.model_config.get("max_charge", 10)
        if max_charge > model_max_charge:
            logger.warning(f"Inference has been configured with max_charge={max_charge}, but model has max_charge={model_max_charge}.")
            logger.warning(f"Overwriting max_charge config to model value: {model_max_charge}.")
            max_charge = model_max_charge

        precursor_charge_col = MSColumns.PRECURSOR_CHARGE.value
        dataset = dataset.filter(lambda row: (row[precursor_charge_col] <= max_charge) and (row[precursor_charge_col] > 0))

        # Filter invalid sequences
        if not self.denovo:
            supported_residues = set(self.residue_set.vocab)
            supported_residues.update(set(self.residue_set.residue_remapping.keys()))
            data_residues = sdf.get_vocabulary(self.residue_set.tokenize)
            if len(data_residues - supported_residues) > 0:
                logger.warning(
                    f"Found {len(data_residues - supported_residues):,d} unsupported residues! "
                    "These rows will be dropped in evaluation mode. Please adjust the metrics "
                    "calculations accordingly."
                )
                logger.warning(f"New residues found: \n{data_residues - supported_residues}")
                logger.warning(f"Residues supported: \n{supported_residues}")
                logger.warning("Please check residue remapping if a different convention has been used.")
            dataset = dataset.filter(
                lambda row: all(residue in supported_residues for residue in set(self.residue_set.tokenize(row[ANNOTATED_COLUMN])))
            )

        if len(dataset) < original_size:
            logger.warning(
                f"Found {original_size - len(dataset):,d} rows with charge > {max_charge}, charge <= 0, "
                "or (in evaluation mode) unsupported residues. A large count could mean the charge "
                "column is missing or contains invalid values. These rows will be skipped."
            )

        if subset < 1.0:
            dataset = dataset.train_test_split(test_size=subset, seed=42)["test"]

        if len(dataset) == 0:
            logger.warning("No data found, exiting.")
            sys.exit()

        # Optional dataset postprocessing
        dataset = self.postprocess_dataset(dataset)

        # Used to group validation outputs
        if self._group_mapping is not None:
            logger.info("Computing groups.")
            groups = [self._group_mapping.get(row.get("source_file"), "no_group") for row in dataset]
            dataset = dataset.add_column("group", groups, feature=Value("string"))

            if self.accelerator.is_main_process:
                logger.info("Sequences per group:")
                group_counts = Counter(groups)
                for group, count in group_counts.items():
                    logger.info(f"  - {group}: {count:,d}")

            self.using_groups = True
        else:
            dataset = dataset.add_column("group", ["no_group"] * len(dataset), feature=Value("string"))
            self.using_groups = False

        # Force-add a unique prediction_id column.
        # This will be used to order predictions.
        dataset = dataset.add_column("prediction_id", np.arange(len(dataset)), feature=Value("int32"))

        return dataset
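
    # A sketch of the grouped `data_path` config this method expects when it
    # is a list (YAML-style; group names and paths are placeholders):
    #
    #     data_path:
    #       - result_name: yeast
    #         input_path: /data/yeast/*.mgf
    #         output_path: /results/yeast_predictions.csv
    #       - result_name: human
    #         input_path: /data/human/*.mgf
    #         output_path: /results/human_predictions.csv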
    def print_sample_batch(self) -> None:
        """Print a sample batch of the test data."""
        if self.accelerator.is_main_process:
            sample_batch = next(iter(self.test_dataloader))
            logger.info("Sample batch:")
            for key, value in sample_batch.items():
                if isinstance(value, torch.Tensor):
                    value_shape = value.shape
                    value_type = value.dtype
                else:
                    value_shape = len(value)
                    value_type = type(value)

                logger.info(f"  - {key}: {value_type}, {value_shape}")

    def setup_metrics(self) -> Metrics:
        """Set up the metrics."""
        return Metrics(self.residue_set, self.config.get("max_isotope_error", 1))

    def setup_accelerator(self) -> Accelerator:
        """Set up the accelerator."""
        # TODO: How do we specify a device without the accelerator?
        timeout = timedelta(seconds=self.config.get("timeout", 3600))
        validate_and_configure_device(self.config)
        accelerator = Accelerator(
            cpu=self.config.get("force_cpu", False),
            mixed_precision="fp16" if torch.cuda.is_available() and not self.config.get("force_cpu", False) else "no",
            dataloader_config=DataLoaderConfiguration(split_batches=False),
            kwargs_handlers=[InitProcessGroupKwargs(timeout=timeout)],
        )

        device = accelerator.device  # Important: this forces each rank to choose a device.

        if accelerator.num_processes > 1:
            set_rank(accelerator.local_process_index)

        if accelerator.is_main_process:
            logger.info(f"Python version: {sys.version}")
            logger.info(f"Torch version: {torch.__version__}")
            logger.info(f"CUDA version: {torch.version.cuda}")
            logger.info(f"Predicting with {accelerator.num_processes} devices")
            logger.info(f"Per-device batch size: {self.config['batch_size']}")

        logger.info(f"Using device: {device}")

        return accelerator
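
    # Multi-device runs are normally started through Hugging Face Accelerate's
    # launcher; an illustrative invocation (script name and flags depend on
    # your setup, and are not defined in this module):
    #
    #     accelerate launch --num_processes 4 predict_script.py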
    def build_dataloader(self, test_dataset: Dataset) -> torch.utils.data.DataLoader:
        """Set up the test dataloader."""
        test_processor = self.setup_data_processor()
        test_processor.add_metadata_columns(["prediction_id", "group"])

        test_dataset = test_processor.process_dataset(test_dataset)

        pin_memory = self.config.get("pin_memory", False)
        if self.accelerator.device == torch.device("cpu") or self.config.get("mps", False):
            pin_memory = False

        test_dataloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=self.config["batch_size"],
            collate_fn=test_processor.collate_fn,
            num_workers=self.config.get("num_workers", 8),
            pin_memory=pin_memory,
            prefetch_factor=self.config.get("prefetch_factor", 2),
            drop_last=False,
        )
        return test_dataloader

    def predict(self) -> pd.DataFrame | None:
        """Predict on the test dataset.

        Returns the predictions dataframe on the main process and None on all
        other processes.
        """
        all_predictions: dict[str, list] = defaultdict(list)
        all_encoder_outputs: list[np.ndarray] = []
        test_step = 0

        logger.info("Predicting...")
        inference_timer = Timer(self.steps_per_inference)
        print_batch_size = True
        for i, batch in enumerate(self.test_dataloader):
            if print_batch_size:
                logger.info(f"Batch {i} shape: {batch['spectra'].shape[0]}")
                print_batch_size = False

            # Implementation specific
            with torch.no_grad(), self.accelerator.autocast():
                batch_predictions = self.get_predictions(batch)

            # Pass through the prediction_id and group columns;
            # prediction_id is automatically cast to a tensor.
            batch_predictions["prediction_id"] = [x.item() for x in batch["prediction_id"]]
            batch_predictions["group"] = batch["group"]

            # Some outputs are required from get_predictions
            for k in PREDICTION_COLUMNS:
                if k not in batch_predictions:
                    raise ValueError(f"Prediction column {k} not found in batch predictions.")
                all_predictions[k].extend(batch_predictions[k])

            if "encoder_output" in batch_predictions:
                # Always remove it with pop, even if we do not save it.
                encoder_output = batch_predictions.pop("encoder_output")
                if self.save_encoder_outputs:
                    all_encoder_outputs.extend(encoder_output)

            # Additional prediction info
            if self.config.get("save_all_predictions", False):
                for k, v in batch_predictions.items():
                    if k in PREDICTION_COLUMNS:
                        continue
                    all_predictions[k].extend(v)

            test_step += 1
            inference_timer.step()

            if (i + 1) % self.config.get("log_interval", 50) == 0 or (i + 1) == self.steps_per_inference:
                logger.info(
                    f"[Batch {i + 1:05d}/{self.steps_per_inference:05d}] "
                    f"[{inference_timer.get_time_str()}/{inference_timer.get_eta_str()}] "  # noqa: E501
                    f"{inference_timer.get_step_time_rate_str()}"
                )

        logger.info(f"Time taken for {self.config.get('data_path', None)} is {inference_timer.get_delta():.1f} seconds")

        logger.info("Prediction complete.")
        self.accelerator.wait_for_everyone()

        if self.accelerator.num_processes > 1:
            logger.info("Gathering predictions from all processes...")

            # Broadcast all predictions to all processes
            for key, value in all_predictions.items():
                all_predictions[key] = self.accelerator.gather_for_metrics(value, use_gather_object=True)

            if self.save_encoder_outputs:
                all_encoder_outputs = self.accelerator.gather_for_metrics(all_encoder_outputs, use_gather_object=True)

        if self.accelerator.is_main_process:
            pred_df = self.predictions_to_df(all_predictions)
            pred_df = self.postprocess_predictions(pred_df)

            results_dict = None
            if not self.denovo:
                logger.info("Calculating metrics...")

                results_dict = self.calculate_metrics(pred_df)

            self.save_predictions(pred_df, results_dict)

            if self.save_encoder_outputs:
                self.save_encoder_outputs_to_parquet(pred_df["spectrum_id"].tolist(), all_encoder_outputs)
        else:
            pred_df = None

        return pred_df
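
    # Note on the multi-process gather above: with uneven shards, the final
    # batch may be padded, so the gathered lists can contain duplicate rows;
    # `predictions_to_df` drops these again using the unique `prediction_id`
    # column added in `load_dataset`.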
    def _tokens_to_string(self, tokens: list[str] | None) -> str:
        """Convert a list of tokens to a ProForma compliant string."""
        if tokens is None:
            return ""
        peptide = ""
        if len(tokens) > 1 and not tokens[0][0].isalpha():
            # Assume n-terminal
            peptide = tokens[0] + "-"
            tokens = tokens[1:]
        return peptide + "".join(tokens)
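
    # Illustrative behaviour (token values are made-up examples):
    #
    #     _tokens_to_string(["P", "E", "P"])                -> "PEP"
    #     _tokens_to_string(["[UNIMOD:1]", "P", "E", "P"])  -> "[UNIMOD:1]-PEP"
    #     _tokens_to_string(None)                           -> ""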
    def predictions_to_df(self, predictions: dict[str, list]) -> pd.DataFrame:
        """Convert the predictions to a pandas DataFrame.

        Args:
            predictions: The predictions dictionary

        Returns:
            pd.DataFrame: The predictions dataframe
        """
        index_cols = self.config.get("index_columns", ["precursor_mz", "precursor_charge"])
        index_cols = [x for x in index_cols if x in self.test_dataset.column_names]
        index_cols.append("prediction_id")
        index_df = self.test_dataset.to_pandas()[index_cols]

        pred_df = pd.DataFrame(predictions)
        # Drop duplicates caused by padding across multiple processes
        pred_df = pred_df.drop_duplicates(subset=["prediction_id"], keep="first")

        pred_df = index_df.merge(pred_df, on="prediction_id", how="left")

        # Some column processing
        pred_df["predictions_tokenised"] = pred_df["predictions"].map(lambda x: ", ".join(x))
        pred_df["predictions"] = pred_df["predictions"].map(self._tokens_to_string)
        pred_df["targets"] = pred_df["targets"].map(self._tokens_to_string)
        pred_df["delta_mass_ppm"] = pred_df.apply(
            lambda row: np.min(np.abs(self.metrics.matches_precursor(row["predictions_tokenised"], row["precursor_mz"], row["precursor_charge"])[1])),
            axis=1,
        )

        pred_df = pred_df.rename(
            columns={
                "predictions": self.prediction_col,
                "predictions_tokenised": self.prediction_tokenised_col,
                "prediction_log_probability": self.log_probs_col,
                "prediction_token_log_probabilities": self.token_log_probs_col,
            }
        )

        if self.denovo:
            pred_df.drop(columns=["targets"], inplace=True)

        return pred_df

    def postprocess_predictions(self, pred_df: pd.DataFrame) -> pd.DataFrame:
        """Postprocess the predictions.

        Optionally, this can be used to modify the predictions, e.g. ensembling.
        By default, this does nothing.

        Args:
            pred_df: The predictions dataframe

        Returns:
            pd.DataFrame: The postprocessed predictions dataframe
        """
        return pred_df
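
    # A hedged override sketch: blank out low-confidence predictions in a
    # subclass (the 0.5 threshold is illustrative, not a project default):
    #
    #     class FilteredPredictor(MyPredictor):
    #         def postprocess_predictions(self, pred_df: pd.DataFrame) -> pd.DataFrame:
    #             low_conf = np.exp(pred_df[self.log_probs_col]) < 0.5
    #             pred_df.loc[low_conf, self.prediction_col] = ""
    #             return pred_df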
    def calculate_metrics(
        self,
        pred_df: pd.DataFrame,
    ) -> dict[str, Any] | None:
        """Calculate the metrics.

        Args:
            pred_df: The predictions dataframe

        Returns:
            dict[str, Any] | None: The results dictionary containing the metrics
        """
        predictions = pred_df[self.prediction_tokenised_col].copy()
        targets = pred_df["targets"]
        log_probs = pred_df[self.log_probs_col]
        delta_mass_ppm = pred_df["delta_mass_ppm"]

        aa_prec, aa_recall, pep_recall, pep_prec = self.metrics.compute_precision_recall(targets, predictions)
        aa_er = self.metrics.compute_aa_er(targets, predictions)
        auc = self.metrics.calc_auc(
            targets,
            predictions,
            np.exp(log_probs),
        )

        logger.info("Performance:")
        logger.info(f"  aa_er      {aa_er:.5f}")
        logger.info(f"  aa_prec    {aa_prec:.5f}")
        logger.info(f"  aa_recall  {aa_recall:.5f}")
        logger.info(f"  pep_prec   {pep_prec:.5f}")
        logger.info(f"  pep_recall {pep_recall:.5f}")
        logger.info(f"  auc        {auc:.5f}")

        fdr = self.config.get("filter_fdr_threshold", None)
        if fdr:
            _, threshold = self.metrics.find_recall_at_fdr(
                targets,
                predictions,
                np.exp(log_probs),
                fdr=fdr,
            )
            aa_prec, aa_recall, pep_recall, pep_prec = self.metrics.compute_precision_recall(
                targets,
                predictions,
                np.exp(log_probs),
                threshold=threshold,
            )
            logger.info(f"Performance at {fdr * 100:.1f}% FDR:")
            logger.info(f"  aa_prec    {aa_prec:.5f}")
            logger.info(f"  aa_recall  {aa_recall:.5f}")
            logger.info(f"  pep_prec   {pep_prec:.5f}")
            logger.info(f"  pep_recall {pep_recall:.5f}")
            logger.info(f"  confidence {threshold:.5f}")

        filter_precursor_ppm = self.config.get("filter_precursor_ppm", None)
        if filter_precursor_ppm and delta_mass_ppm is not None:
            idx = delta_mass_ppm < filter_precursor_ppm  # type: ignore
            logger.info(f"Performance with filtering at {filter_precursor_ppm} ppm delta mass:")
            if np.sum(idx) > 0:
                filtered_preds = pd.Series(predictions)
                filtered_preds[~idx] = ""
                aa_prec, aa_recall, pep_recall, pep_prec = self.metrics.compute_precision_recall(targets, filtered_preds)
                logger.info(f"  aa_prec    {aa_prec:.5f}")
                logger.info(f"  aa_recall  {aa_recall:.5f}")
                logger.info(f"  pep_prec   {pep_prec:.5f}")
                logger.info(f"  pep_recall {pep_recall:.5f}")
                logger.info(f"Rows filtered: {len(predictions) - np.sum(idx)} ({(len(predictions) - np.sum(idx)) / len(predictions) * 100:.2f}%)")
                if np.sum(idx) < 1000:
                    logger.info(f"Metrics calculated on a small number of samples ({np.sum(idx)}), interpret with care!")
            else:
                logger.info("No predictions met criteria, skipping metrics.")

        model_confidence_no_pred = self.config.get("filter_confidence", None)
        if model_confidence_no_pred:
            idx = np.exp(log_probs) > model_confidence_no_pred
            logger.info(f"Performance with filtering confidence < {model_confidence_no_pred}")
            if np.sum(idx) > 0:
                filtered_preds = pd.Series(predictions)
                filtered_preds[~idx] = ""
                aa_prec, aa_recall, pep_recall, pep_prec = self.metrics.compute_precision_recall(targets, filtered_preds)
                logger.info(f"  aa_prec    {aa_prec:.5f}")
                logger.info(f"  aa_recall  {aa_recall:.5f}")
                logger.info(f"  pep_prec   {pep_prec:.5f}")
                logger.info(f"  pep_recall {pep_recall:.5f}")
                logger.info(f"Rows filtered: {len(predictions) - np.sum(idx)} ({(len(predictions) - np.sum(idx)) / len(predictions) * 100:.2f}%)")
                if np.sum(idx) < 1000:
                    logger.info(f"Metrics calculated on a small number of samples ({np.sum(idx)}), interpret with care!")
            else:
                logger.info("No predictions met criteria, skipping metrics.")

        # Evaluate individual result files
        if self.using_groups and not self.denovo:
            logger.info("Evaluating individual result files.")
            # TODO: Handle better with pred_df
            _preds = pd.Series(predictions)
            _targs = pd.Series(targets)
            _probs = pd.Series(log_probs)

            # TODO: Make this more generic
            results_dict: dict[str, Any] = {
                "run_name": self.config.get("run_name"),
                "instanovo_model": self.config.get("instanovo_model"),
                "num_beams": self.config.get("num_beams", 1),
                "use_knapsack": self.config.get("use_knapsack", False),
            }
            for group in pred_df["group"].unique():
                if group == "no_group":
                    continue
                idx = pred_df["group"] == group
                _group_preds = _preds[idx].reset_index(drop=True)
                _group_targs = _targs[idx].reset_index(drop=True)
                _group_probs = _probs[idx].reset_index(drop=True)
                aa_prec, aa_recall, pep_recall, pep_prec = self.metrics.compute_precision_recall(_group_targs, _group_preds)
                aa_er = self.metrics.compute_aa_er(_group_targs, _group_preds)
                auc = self.metrics.calc_auc(_group_targs, _group_preds, _group_probs)

                results_dict.update(
                    {
                        f"{group}_aa_prec": [aa_prec],
                        f"{group}_aa_recall": [aa_recall],
                        f"{group}_pep_recall": [pep_recall],
                        f"{group}_pep_prec": [pep_prec],
                        f"{group}_aa_er": [aa_er],
                        f"{group}_auc": [auc],
                    }
                )

                fdr = self.config.get("filter_fdr_threshold", None)
                if fdr:
                    _, threshold = self.metrics.find_recall_at_fdr(_group_targs, _group_preds, np.exp(_group_probs), fdr=fdr)
                    _, _, pep_recall_at_fdr, _ = self.metrics.compute_precision_recall(
                        _group_targs,
                        _group_preds,
                        np.exp(_group_probs),
                        threshold=threshold,
                    )

                    results_dict.update(
                        {
                            f"{group}_pep_recall_at_{fdr:.3f}_fdr": [pep_recall_at_fdr],
                        }
                    )
            return results_dict
        return None
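
    # The per-group results are flat, CSV-friendly keys built from the
    # f-strings above, e.g. for a group named "yeast" with
    # filter_fdr_threshold=0.05 (values elided):
    #
    #     {"yeast_aa_prec": [...], "yeast_pep_recall": [...],
    #      "yeast_auc": [...], "yeast_pep_recall_at_0.050_fdr": [...]}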
    def save_predictions(self, pred_df: pd.DataFrame, results_dict: dict[str, list] | None = None) -> None:
        """Save the predictions to a file.

        Args:
            pred_df: The predictions dataframe
            results_dict: The results dictionary containing the metrics
        """
        # Save metrics to a file
        if self.using_groups and not self.denovo and results_dict is not None:
            result_path = self.config.get("result_file_path")
            logger.info(f"Saving metrics to {result_path}.")
            local_path = self.s3.get_local_path(result_path, missing_ok=True)
            if local_path is not None and os.path.exists(local_path):
                results_df = pd.read_csv(local_path)
                results_df = pd.concat([results_df, pd.DataFrame(results_dict)], ignore_index=True, join="outer")
            else:
                results_df = pd.DataFrame(results_dict)

            self.s3.upload_to_s3_wrapper(results_df.to_csv, result_path, index=False)

        # Save individual result files per group
        if self.using_groups and self._group_output is not None and pred_df is not None:
            logger.info("Saving individual result files per group.")
            for group in pred_df["group"].unique():
                idx = pred_df["group"] == group
                if self._group_output.get(group) is not None:
                    self.s3.upload_to_s3_wrapper(pred_df[idx].to_csv, self._group_output[group], index=False)

        # Save output
        if self.output_path is not None and pred_df is not None:
            logger.info(f"Saving predictions to {self.output_path}.")
            self.s3.upload_to_s3_wrapper(pred_df.to_csv, self.output_path, index=False)
            logger.info(f"Predictions saved to {self.output_path}")

            # Upload to Aichor
            if S3FileHandler._aichor_enabled() and not self.output_path.startswith("s3://"):
                self.s3.upload(self.output_path, S3FileHandler.convert_to_s3_output(self.output_path))

    def save_encoder_outputs_to_parquet(self, spectrum_ids: list[str], encoder_outputs: list[np.ndarray]) -> None:
        """Save the encoder outputs to a Parquet file.

        Args:
            spectrum_ids: The spectrum ids
            encoder_outputs: The encoder outputs
        """
        if len(encoder_outputs) == 0:
            logger.warning(f"No encoder outputs were returned by the decoder {type(self.decoder)}.")
            logger.warning("Skipping encoder output saving.")
            return

        if self.encoder_output_path is not None:
            encoder_outputs_fp32 = np.stack(encoder_outputs).astype(np.float32)
            encoder_output_df = pl.DataFrame(
                {"spectrum_id": spectrum_ids, **{f"spectrum_encoding_{i}": encoder_outputs_fp32[:, i] for i in range(encoder_outputs_fp32.shape[1])}}
            )

            logger.info(f"Saving encoder outputs to {self.encoder_output_path}.")
            self.s3.upload_to_s3_wrapper(encoder_output_df.write_parquet, self.encoder_output_path)
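

# Hypothetical end-to-end usage (assumes a concrete subclass such as the
# `MyPredictor` sketched above; the config keys mirror those read in this
# module, and all paths are placeholders):
#
#     from omegaconf import OmegaConf
#
#     config = OmegaConf.create(
#         {
#             "data_path": "/data/spectra/*.mgf",
#             "output_path": "/results/predictions.csv",
#             "batch_size": 64,
#             "denovo": True,  # no annotations required; metrics are skipped
#         }
#     )
#     predictor = MyPredictor(config)
#     pred_df = predictor.predict()  # returns None on non-main ranks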