Predictor

predictor

logger = ColorLog(console, __name__).logger module-attribute

AccelerateDeNovoPredictor(config: DictConfig)

Predictor class that uses the Accelerate library.

s3: S3FileHandler property

Get the S3 file handler.

RETURNS DESCRIPTION
S3FileHandler

The S3 file handler

TYPE: S3FileHandler

config = config instance-attribute

targets: list | None = None instance-attribute

output_path = self.config.get('output_path', None) instance-attribute

pred_df: pd.DataFrame | None = None instance-attribute

results_dict: dict | None = None instance-attribute

prediction_tokenised_col = self.config.get('prediction_tokenised_col', 'predictions_tokenised') instance-attribute

prediction_col = self.config.get('prediction_col', 'predictions') instance-attribute

log_probs_col = self.config.get('log_probs_col', 'log_probs') instance-attribute

token_log_probs_col = self.config.get('token_log_probs_col', 'token_log_probs') instance-attribute

save_encoder_outputs = config.get('save_encoder_outputs', False) instance-attribute

encoder_output_path = config.get('encoder_output_path', None) instance-attribute

encoder_output_reduction = config.get('encoder_output_reduction', 'mean') instance-attribute

accelerator = self.setup_accelerator() instance-attribute

denovo = self.config.get('denovo', False) instance-attribute

model = self.model.eval() instance-attribute

residue_set = self.model.residue_set instance-attribute

test_dataset = self.load_dataset() instance-attribute

test_dataloader = self.build_dataloader(self.test_dataset) instance-attribute

decoder = self.setup_decoder() instance-attribute

metrics = self.setup_metrics() instance-attribute

running_loss = None instance-attribute

steps_per_inference = len(self.test_dataloader) instance-attribute
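
Several of the attributes above are read from the configuration with `config.get(key, default)`. The sketch below shows how those defaults resolve when only some keys are supplied. It uses a plain `dict` as a stand-in for the `DictConfig` (an assumption made to keep the example dependency-free), and the values in `config` are hypothetical.

```python
# Stand-in for the DictConfig passed to the predictor. A plain dict is used
# here only because it offers the same .get(key, default) lookup.
config = {
    "output_path": "predictions.csv",  # hypothetical value
    "save_encoder_outputs": True,      # hypothetical value
}

# Keys and defaults copied from the attribute list above.
DEFAULTS = {
    "output_path": None,
    "prediction_tokenised_col": "predictions_tokenised",
    "prediction_col": "predictions",
    "log_probs_col": "log_probs",
    "token_log_probs_col": "token_log_probs",
    "save_encoder_outputs": False,
    "encoder_output_path": None,
    "encoder_output_reduction": "mean",
    "denovo": False,
}

# Each attribute takes the configured value if present, else its default.
resolved = {key: config.get(key, default) for key, default in DEFAULTS.items()}

for key, value in resolved.items():
    print(f"{key} = {value!r}")
```

Keys missing from the configuration fall back to the listed defaults, so a minimal configuration only needs to set the values that differ.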

load_model() -> Tuple[nn.Module, DictConfig] abstractmethod

Load the model.

setup_decoder() -> Decoder abstractmethod

Set up the decoder.

setup_data_processor() -> DataProcessor abstractmethod

Set up the data processor.

get_predictions(batch: Any) -> dict[str, Any] abstractmethod

Get the predictions for a batch.

postprocess_dataset(dataset: Dataset) -> Dataset

Postprocess the dataset.

load_dataset() -> Dataset

Load the test dataset.

RETURNS DESCRIPTION
Dataset

The test dataset

TYPE: Dataset

print_sample_batch() -> None

Print a sample batch from the test dataset.

setup_metrics() -> Metrics

Set up the metrics.

setup_accelerator() -> Accelerator

Set up the accelerator.

build_dataloader(test_dataset: Dataset) -> torch.utils.data.DataLoader

Build the dataloader for the test dataset.

predict() -> pd.DataFrame

Run prediction on the test dataset.

predictions_to_df(predictions: dict[str, list]) -> pd.DataFrame

Convert the predictions to a pandas DataFrame.

PARAMETER DESCRIPTION
predictions

The predictions dictionary

TYPE: dict[str, list]

RETURNS DESCRIPTION
DataFrame

The predictions dataframe
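
As a rough illustration of the conversion, a `dict[str, list]` whose lists all have one entry per spectrum maps directly onto a DataFrame with one column per key. The sketch below assumes the default column names listed earlier on this page. The values are made up, and the library's own method may add or reorder columns.

```python
import pandas as pd

# Hypothetical predictions collected over the test set. Every list holds one
# entry per spectrum, so all lists have the same length.
predictions = {
    "predictions_tokenised": [["P", "E", "P"], ["T", "I", "D", "E"]],
    "predictions": ["PEP", "TIDE"],
    "log_probs": [-0.12, -1.50],
}

# Each key becomes a column and each list position becomes a row.
pred_df = pd.DataFrame(predictions)
print(pred_df)
```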

postprocess_predictions(pred_df: pd.DataFrame) -> pd.DataFrame

Postprocess the predictions.

Override this to modify the predictions, e.g. for ensembling. By default, it does nothing.

PARAMETER DESCRIPTION
pred_df

The predictions dataframe

TYPE: DataFrame

RETURNS DESCRIPTION
DataFrame

The postprocessed predictions dataframe
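
A minimal sketch of what a custom postprocessing step could look like, written as a standalone function so that it runs without the library installed. In a subclass, the same body would go in an overridden `postprocess_predictions` method. The `min_log_prob` threshold and the blanking rule are hypothetical, and the column names assume the defaults listed earlier on this page.

```python
import pandas as pd


def postprocess_predictions(pred_df: pd.DataFrame, min_log_prob: float = -1.0) -> pd.DataFrame:
    """Blank out predictions whose sequence log-probability is below a threshold."""
    out = pred_df.copy()  # leave the caller's dataframe untouched
    low_confidence = out["log_probs"] < min_log_prob
    out.loc[low_confidence, "predictions"] = ""
    return out


# Hypothetical input with the default column names.
pred_df = pd.DataFrame({"predictions": ["PEP", "TIDE"], "log_probs": [-0.12, -1.50]})
processed = postprocess_predictions(pred_df)
print(processed)
```

Returning a modified copy keeps the row count and column names intact, so later steps that expect the same dataframe layout are not affected.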

calculate_metrics(pred_df: pd.DataFrame) -> dict[str, Any] | None

Calculate the metrics.

PARAMETER DESCRIPTION
pred_df

The predictions dataframe

TYPE: DataFrame

RETURNS DESCRIPTION
dict[str, Any] | None

The results dictionary containing the metrics

save_predictions(pred_df: pd.DataFrame, results_dict: dict[str, list] | None = None) -> None

Save the predictions to a file.

PARAMETER DESCRIPTION
pred_df

The predictions dataframe

TYPE: DataFrame

results_dict

The results dictionary containing the metrics

TYPE: dict[str, list] | None DEFAULT: None

save_encoder_outputs_to_parquet(spectrum_ids: list[str], encoder_outputs: list[np.ndarray]) -> None

Save the encoder outputs to a file.

PARAMETER DESCRIPTION
encoder_outputs

The encoder outputs

TYPE: list[ndarray]

spectrum_ids

The spectrum ids

TYPE: list[str]
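
The sketch below illustrates one way per-spectrum encoder outputs could be reduced and paired with their spectrum ids before being written out. It is not the library's implementation. The helper `reduce_encoder_outputs`, the assumed `(sequence_length, hidden_dim)` array shape, and the output column names are all hypothetical, and only the `'mean'` reduction (the default of `encoder_output_reduction`) is handled.

```python
import numpy as np
import pandas as pd


def reduce_encoder_outputs(encoder_outputs: list[np.ndarray], reduction: str = "mean") -> list[np.ndarray]:
    """Collapse each (sequence_length, hidden_dim) array to one hidden_dim vector."""
    if reduction != "mean":
        # Other reduction modes are not listed on this page, so this sketch
        # does not try to guess them.
        raise ValueError(f"Unsupported reduction in this sketch: {reduction!r}")
    return [output.mean(axis=0) for output in encoder_outputs]


# Hypothetical inputs: two spectra with different sequence lengths.
spectrum_ids = ["spec_0", "spec_1"]
encoder_outputs = [np.ones((3, 4)), np.zeros((5, 4))]

reduced = reduce_encoder_outputs(encoder_outputs)

# One row per spectrum; vectors are stored as plain lists of floats.
frame = pd.DataFrame(
    {
        "spectrum_id": spectrum_ids,
        "encoder_output": [vector.tolist() for vector in reduced],
    }
)
print(frame)

# Writing to disk needs a parquet engine such as pyarrow, for example:
# frame.to_parquet("encoder_outputs.parquet")
```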