Trainer

trainer

logger = ColorLog(console, __name__).logger module-attribute

AccelerateDeNovoTrainer(config: DictConfig)

Trainer class that uses the Accelerate library.

run_id: str property

Get the run ID.

RETURNS DESCRIPTION
str

The run ID

TYPE: str

s3: S3FileHandler property

Get the S3 file handler.

RETURNS DESCRIPTION
S3FileHandler

The S3 file handler

TYPE: S3FileHandler

global_step: int property

Get the current global training step.

This represents the total number of training steps across all epochs.

RETURNS DESCRIPTION
int

The current global step number

TYPE: int

epoch: int property

Get the current training epoch.

This represents the current epoch number in the training process.

RETURNS DESCRIPTION
int

The current epoch number

TYPE: int

training_state: TrainingState property

Get the training state.

config = config instance-attribute

enable_verbose_logging = self.config.get('enable_verbose_logging', True) instance-attribute

accelerator = self.setup_accelerator() instance-attribute

residue_set = ResidueSet(residue_masses=(self.config.residues.get('residues')), residue_remapping=(self.config.dataset.get('residue_remapping', None))) instance-attribute

model = self.setup_model() instance-attribute

optimizer = self.setup_optimizer() instance-attribute

lr_scheduler = self.setup_scheduler() instance-attribute

decoder = self.setup_decoder() instance-attribute

metrics = self.setup_metrics() instance-attribute

running_loss = None instance-attribute

total_steps = self.config.get('training_steps', 2500000) instance-attribute

finetune_scheduler: FinetuneScheduler | None = FinetuneScheduler(self.model.state_dict(), self.config.get('finetune')) instance-attribute

steps_per_validation = self.config.get('validation_interval', 100000) instance-attribute

steps_per_checkpoint = self.config.get('checkpoint_interval', 100000) instance-attribute

last_validation_metric = None instance-attribute

best_checkpoint_metric = None instance-attribute

setup_model() -> nn.Module abstractmethod

Set up the model.

setup_optimizer() -> torch.optim.Optimizer abstractmethod

Set up the optimizer.

setup_decoder() -> Decoder abstractmethod

Set up the decoder.

setup_data_processors() -> tuple[DataProcessor, DataProcessor] abstractmethod

Set up the data processors. Two are returned, as the signature shows.

save_model(is_best_checkpoint: bool = False) -> None abstractmethod

Save the model.

forward(batch: Any) -> tuple[torch.Tensor, dict[str, torch.Tensor]] abstractmethod

Run a forward pass through the model and calculate the loss.

get_predictions(batch: Any) -> tuple[list[str] | list[list[str]], list[str] | list[list[str]]] abstractmethod

Get the predictions for a batch.
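The abstract methods above are the hooks a concrete trainer has to supply. The sketch below illustrates the pattern with a self-contained toy base class. `ToyTrainerBase` and `ToyTrainer` are hypothetical names, the hook order is assumed from the order of the instance attributes listed earlier, and the real `AccelerateDeNovoTrainer` takes a `DictConfig` and defines more hooks than shown here.

```python
from abc import ABC, abstractmethod
from typing import Any


class ToyTrainerBase(ABC):
    """Hypothetical stand-in for the base trainer, not the library class."""

    def __init__(self, config: dict[str, Any]) -> None:
        self.config = config
        # Assumed order: model first, so later hooks can refer to it.
        self.model = self.setup_model()
        self.optimizer = self.setup_optimizer()
        self.decoder = self.setup_decoder()

    @abstractmethod
    def setup_model(self) -> Any: ...

    @abstractmethod
    def setup_optimizer(self) -> Any: ...

    @abstractmethod
    def setup_decoder(self) -> Any: ...


class ToyTrainer(ToyTrainerBase):
    """Concrete subclass that fills in every abstract hook."""

    def setup_model(self) -> Any:
        return {"name": "toy-model", "dim": self.config.get("dim", 8)}

    def setup_optimizer(self) -> Any:
        # self.model already exists because the constructor built it first.
        return {"params_of": self.model["name"], "lr": self.config.get("lr", 1e-3)}

    def setup_decoder(self) -> Any:
        return {"decodes": self.model["name"]}


trainer = ToyTrainer({"dim": 16})
```

Instantiating `ToyTrainerBase` directly raises `TypeError`, which is how Python enforces that every abstract hook is implemented.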

convert_interval_to_steps(interval: float | int, steps_per_epoch: int) -> int staticmethod

Convert an interval to steps.

PARAMETER DESCRIPTION
interval

The interval to convert.

TYPE: float | int

steps_per_epoch

The number of steps per epoch.

TYPE: int

RETURNS DESCRIPTION
int

The number of steps.

TYPE: int
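The reference does not state the conversion rule. One plausible reading, offered purely as an assumption, is that a `float` interval is a fraction (or multiple) of an epoch, while an `int` is already a step count. The sketch below implements that reading in a hypothetical standalone function, `interval_to_steps_sketch`; the library's actual behaviour may differ.

```python
from __future__ import annotations


def interval_to_steps_sketch(interval: float | int, steps_per_epoch: int) -> int:
    """Hypothetical rule: floats are epoch fractions, ints are step counts."""
    if isinstance(interval, bool):
        # bool is a subclass of int, so reject it explicitly.
        raise TypeError("interval must be a number, not a bool")
    if isinstance(interval, int):
        # Assumed: an integer is already expressed in steps.
        return interval
    # Assumed: scale by the epoch length and never return fewer than one step.
    return max(1, int(interval * steps_per_epoch))
```

Under this assumed rule, `interval_to_steps_sketch(0.5, 1000)` gives 500 steps and `interval_to_steps_sketch(2000, 1000)` is returned unchanged.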

log_if_verbose(message: str, level: str = 'info') -> None

Log a message if verbose logging is enabled.

setup_metrics() -> Metrics

Set up the metrics.

setup_accelerator() -> Accelerator

Set up the accelerator.

build_dataloaders(train_dataset: Dataset, valid_dataset: Dataset) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]

Build the training and validation dataloaders.

setup_scheduler() -> torch.optim.lr_scheduler.LRScheduler

Set up the learning rate scheduler.

RETURNS DESCRIPTION
LRScheduler

The learning rate scheduler

TYPE: torch.optim.lr_scheduler.LRScheduler

setup_neptune() -> None

Set up Neptune.

setup_tensorboard() -> None

Set up TensorBoard.

load_datasets() -> tuple[Dataset, Dataset, int, int]

Load the training and validation datasets.

RETURNS DESCRIPTION
tuple[Dataset, Dataset, int, int]

The training and validation datasets, followed by two integer values

TYPE: tuple[Dataset, Dataset, int, int]

print_sample_batch() -> None

Print a sample batch of the training data.

save_accelerator_state(is_best_checkpoint: bool = False) -> None

Save the accelerator state.

check_if_best_checkpoint() -> bool

Check if the last validation metric is the best metric.
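The comparison presumably involves the `last_validation_metric` and `best_checkpoint_metric` attributes, both of which start as `None`. The sketch below shows one way such a check could handle those initial values. The function name and the `higher_is_better` flag are hypothetical, and the reference does not say whether a lower or a higher metric counts as better.

```python
from __future__ import annotations


def is_best_checkpoint_sketch(
    last_metric: float | None,
    best_metric: float | None,
    higher_is_better: bool = False,  # Assumed flag, not part of the documented API.
) -> bool:
    """Return True when the latest validation metric beats the stored best."""
    if last_metric is None:
        # No validation has run yet, so there is nothing to compare.
        return False
    if best_metric is None:
        # The first recorded metric is the best seen so far by default.
        return True
    if higher_is_better:
        return last_metric > best_metric
    return last_metric < best_metric
```

Treating `None` explicitly avoids a `TypeError` from comparing `None` with a float on the first validation pass.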

load_accelerator_state() -> None

Load the accelerator state.

load_model_state() -> None

Load the model state.

update_model_state(model_state: dict[str, torch.Tensor], model_config: DictConfig) -> dict[str, torch.Tensor]

Update the model state.

update_vocab(model_state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]

Update the vocabulary of the model.

train() -> None

Train the model.

prepare_batch(batch: Iterable[Any]) -> Any

Prepare a batch for training.

Tensors are moved to accelerator.device manually, because the dataloaders are not prepared with the accelerator.

PARAMETER DESCRIPTION
batch

The batch to prepare.

TYPE: Iterable[Any]

RETURNS DESCRIPTION
Any

The prepared batch

TYPE: Any
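Moving a batch to a device by hand can be sketched as follows. This is a minimal illustration, assuming the batch is a flat iterable whose tensor-like items expose `.to(device)`, as `torch.Tensor` does. `prepare_batch_sketch` and `FakeTensor` are hypothetical names used so the example runs without PyTorch installed; the library method may treat nested structures differently.

```python
from typing import Any, Iterable


def prepare_batch_sketch(batch: Iterable[Any], device: Any) -> list[Any]:
    """Move tensor-like items to `device`; pass everything else through."""
    prepared = []
    for item in batch:
        # Duck-typed check: anything with a callable `.to` is treated as a tensor.
        if callable(getattr(item, "to", None)):
            prepared.append(item.to(device))
        else:
            prepared.append(item)
    return prepared


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor, used only for this demo."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # Like torch.Tensor.to, return a new object on the target device.
        return FakeTensor(device)


# Mixed batch: one tensor-like item plus non-tensor metadata.
moved = prepare_batch_sketch([FakeTensor(), ["PEPTIDE"], 3], "cuda:0")
```

Non-tensor items such as peptide strings or integer lengths are returned unchanged, so only tensors pay the transfer cost.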

train_epoch() -> None

Train the model for one epoch.

validate_epoch(num_sanity_steps: int | None = None, calculate_metrics: bool = True) -> None

Validate for one epoch.