Predictor
predictor
logger = ColorLog(console, __name__).logger
module-attribute
AccelerateDeNovoPredictor(config: DictConfig)
Predictor class that uses the Accelerate library.
s3: S3FileHandler
property
Get the S3 file handler.
| RETURNS | DESCRIPTION |
|---|---|
| S3FileHandler | The S3 file handler. |
config = config
instance-attribute
targets: list | None = None
instance-attribute
output_path = self.config.get('output_path', None)
instance-attribute
pred_df: pd.DataFrame | None = None
instance-attribute
results_dict: dict | None = None
instance-attribute
prediction_tokenised_col = self.config.get('prediction_tokenised_col', 'predictions_tokenised')
instance-attribute
prediction_col = self.config.get('prediction_col', 'predictions')
instance-attribute
log_probs_col = self.config.get('log_probs_col', 'log_probs')
instance-attribute
token_log_probs_col = self.config.get('token_log_probs_col', 'token_log_probs')
instance-attribute
save_encoder_outputs = config.get('save_encoder_outputs', False)
instance-attribute
encoder_output_path = config.get('encoder_output_path', None)
instance-attribute
encoder_output_reduction = config.get('encoder_output_reduction', 'mean')
instance-attribute
accelerator = self.setup_accelerator()
instance-attribute
denovo = self.config.get('denovo', False)
instance-attribute
model = self.model.eval()
instance-attribute
residue_set = self.model.residue_set
instance-attribute
test_dataset = self.load_dataset()
instance-attribute
test_dataloader = self.build_dataloader(self.test_dataset)
instance-attribute
decoder = self.setup_decoder()
instance-attribute
metrics = self.setup_metrics()
instance-attribute
running_loss = None
instance-attribute
steps_per_inference = len(self.test_dataloader)
instance-attribute
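Most of the attributes above are read from the config with a fallback default. A minimal sketch of that lookup pattern, assuming a plain `dict` as a stand-in for the `DictConfig` (the example values such as `"predictions.csv"` and `"peptide"` are hypothetical):

```python
# A plain dict stands in for the DictConfig here; values are illustrative only.
config = {
    "output_path": "predictions.csv",
    "denovo": True,
    "prediction_col": "peptide",
}

# Keys present in the config override the documented defaults.
output_path = config.get("output_path", None)
prediction_col = config.get("prediction_col", "predictions")
denovo = config.get("denovo", False)

# Keys absent from the config fall back to the documented defaults.
log_probs_col = config.get("log_probs_col", "log_probs")
token_log_probs_col = config.get("token_log_probs_col", "token_log_probs")
save_encoder_outputs = config.get("save_encoder_outputs", False)
encoder_output_reduction = config.get("encoder_output_reduction", "mean")

print(prediction_col, log_probs_col, encoder_output_reduction)
```

Only the keys that differ from the defaults need to appear in the config.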
load_model() -> Tuple[nn.Module, DictConfig]
abstractmethod
Load the model.
setup_decoder() -> Decoder
abstractmethod
Set up the decoder.
setup_data_processor() -> DataProcessor
abstractmethod
Set up the data processor.
get_predictions(batch: Any) -> dict[str, Any]
abstractmethod
Get the predictions for a batch.
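The four abstract methods above must be implemented by a concrete subclass. The following is a rough sketch of that contract only: `SketchPredictor` and `ToyPredictor` are hypothetical stand-ins that mirror the method names, not the real `AccelerateDeNovoPredictor`, and the return values are placeholders.

```python
from abc import ABC, abstractmethod
from typing import Any


class SketchPredictor(ABC):
    """Hypothetical stand-in that mirrors only the four abstract hooks."""

    @abstractmethod
    def load_model(self) -> tuple[Any, dict]: ...

    @abstractmethod
    def setup_decoder(self) -> Any: ...

    @abstractmethod
    def setup_data_processor(self) -> Any: ...

    @abstractmethod
    def get_predictions(self, batch: Any) -> dict[str, Any]: ...


class ToyPredictor(SketchPredictor):
    """Toy subclass: every hook returns a placeholder value."""

    def load_model(self) -> tuple[Any, dict]:
        return object(), {}

    def setup_decoder(self) -> Any:
        return object()

    def setup_data_processor(self) -> Any:
        return object()

    def get_predictions(self, batch: Any) -> dict[str, Any]:
        # Return one placeholder prediction per item in the batch.
        return {"predictions": ["" for _ in batch]}


# The base class cannot be instantiated while any hook is still abstract.
try:
    SketchPredictor()
    base_is_instantiable = True
except TypeError:
    base_is_instantiable = False

toy = ToyPredictor()
result = toy.get_predictions(["spectrum_a", "spectrum_b"])
```

A real subclass would return an actual model, decoder, and data processor from these hooks.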
postprocess_dataset(dataset: Dataset) -> Dataset
Postprocess the dataset.
load_dataset() -> Dataset
Load the test dataset.
| RETURNS | DESCRIPTION |
|---|---|
| Dataset | The test dataset. |
print_sample_batch() -> None
Print a sample batch of the training data.
setup_metrics() -> Metrics
Set up the metrics.
setup_accelerator() -> Accelerator
Set up the accelerator.
build_dataloader(test_dataset: Dataset) -> torch.utils.data.DataLoader
Build the dataloader for the test dataset.
predict() -> pd.DataFrame
Predict the test dataset.
predictions_to_df(predictions: dict[str, list]) -> pd.DataFrame
Convert the predictions to a pandas DataFrame.
| PARAMETER | DESCRIPTION |
|---|---|
| predictions | The predictions dictionary. TYPE: dict[str, list] |
| RETURNS | DESCRIPTION |
|---|---|
| pd.DataFrame | The predictions dataframe. |
postprocess_predictions(pred_df: pd.DataFrame) -> pd.DataFrame
Postprocess the predictions.
This hook can optionally be overridden to modify the predictions, e.g. for ensembling. By default, it does nothing.
| PARAMETER | DESCRIPTION |
|---|---|
| pred_df | The predictions dataframe. TYPE: pd.DataFrame |
| RETURNS | DESCRIPTION |
|---|---|
| pd.DataFrame | The postprocessed predictions dataframe. |
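As a rough illustration of how `postprocess_predictions` can serve as an override hook, the sketch below uses a list of row dictionaries as a stand-in for the predictions dataframe. The class names and the `spectrum_id` column are hypothetical. The selection rule (keep the highest-scoring candidate per spectrum) is one simple step an ensemble might use, not the library's own method.

```python
from typing import Any

# A list of row dicts stands in for the predictions dataframe in this sketch.
Row = dict[str, Any]


class BaseSketch:
    def postprocess_predictions(self, rows: list[Row]) -> list[Row]:
        # Default behaviour: hand the predictions back unchanged.
        return rows


class BestPerSpectrum(BaseSketch):
    def postprocess_predictions(self, rows: list[Row]) -> list[Row]:
        # Keep, for each spectrum, the candidate with the highest log probability.
        best: dict[str, Row] = {}
        for row in rows:
            key = row["spectrum_id"]
            if key not in best or row["log_probs"] > best[key]["log_probs"]:
                best[key] = row
        return list(best.values())


rows = [
    {"spectrum_id": "s1", "predictions": "PEPTIDE", "log_probs": -0.4},
    {"spectrum_id": "s1", "predictions": "PEPTLDE", "log_probs": -1.2},
    {"spectrum_id": "s2", "predictions": "ACDK", "log_probs": -0.9},
]

unchanged = BaseSketch().postprocess_predictions(rows)
merged = BestPerSpectrum().postprocess_predictions(rows)
```

With a real dataframe the same idea would be expressed with dataframe operations inside the overridden method.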
calculate_metrics(pred_df: pd.DataFrame) -> dict[str, Any] | None
Calculate the metrics.
| PARAMETER | DESCRIPTION |
|---|---|
| pred_df | The predictions dataframe. TYPE: pd.DataFrame |
| RETURNS | DESCRIPTION |
|---|---|
| dict[str, Any] \| None | The results dictionary containing the metrics. |
save_predictions(pred_df: pd.DataFrame, results_dict: dict[str, list] | None = None) -> None
Save the predictions to a file.
| PARAMETER | DESCRIPTION |
|---|---|
| pred_df | The predictions dataframe. TYPE: pd.DataFrame |
| results_dict | The results dictionary containing the metrics. TYPE: dict[str, list] \| None DEFAULT: None |
save_encoder_outputs_to_parquet(spectrum_ids: list[str], encoder_outputs: list[np.ndarray]) -> None
Save the encoder outputs to a Parquet file.
| PARAMETER | DESCRIPTION |
|---|---|
| spectrum_ids | The spectrum ids. TYPE: list[str] |
| encoder_outputs | The encoder outputs. TYPE: list[np.ndarray] |