Trainer
trainer
logger = ColorLog(console, __name__).logger
module-attribute
AccelerateDeNovoTrainer(config: DictConfig)
Trainer class that uses the Accelerate library.
run_id: str
property
Get the run ID.
| RETURNS | DESCRIPTION |
|---|---|
| str | The run ID. |
s3: S3FileHandler
property
Get the S3 file handler.
| RETURNS | DESCRIPTION |
|---|---|
| S3FileHandler | The S3 file handler. |
global_step: int
property
Get the current global training step.
This represents the total number of training steps across all epochs.
| RETURNS | DESCRIPTION |
|---|---|
| int | The current global step number. |
epoch: int
property
Get the current training epoch.
This represents the current epoch number in the training process.
| RETURNS | DESCRIPTION |
|---|---|
| int | The current epoch number. |
training_state: TrainingState
property
Get the training state.
config = config
instance-attribute
enable_verbose_logging = self.config.get('enable_verbose_logging', True)
instance-attribute
accelerator = self.setup_accelerator()
instance-attribute
residue_set = ResidueSet(residue_masses=(self.config.residues.get('residues')), residue_remapping=(self.config.dataset.get('residue_remapping', None)))
instance-attribute
model = self.setup_model()
instance-attribute
optimizer = self.setup_optimizer()
instance-attribute
lr_scheduler = self.setup_scheduler()
instance-attribute
decoder = self.setup_decoder()
instance-attribute
metrics = self.setup_metrics()
instance-attribute
running_loss = None
instance-attribute
total_steps = self.config.get('training_steps', 2500000)
instance-attribute
finetune_scheduler: FinetuneScheduler | None = FinetuneScheduler(self.model.state_dict(), self.config.get('finetune'))
instance-attribute
steps_per_validation = self.config.get('validation_interval', 100000)
instance-attribute
steps_per_checkpoint = self.config.get('checkpoint_interval', 100000)
instance-attribute
last_validation_metric = None
instance-attribute
best_checkpoint_metric = None
instance-attribute
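Several of the attributes above are read with `config.get(key, default)`, so a missing key falls back to the listed default. A minimal sketch of that fallback behavior, assuming a plain `dict` as a stand-in for the `DictConfig`; the helper name `resolve_training_settings` is hypothetical and not part of the trainer:

```python
from typing import Any, Mapping


def resolve_training_settings(config: Mapping[str, Any]) -> dict:
    """Resolve top-level settings, applying the defaults listed above.

    Hypothetical helper: it only mirrors the `config.get(...)` calls shown
    in the attribute list and is not the trainer's own code.
    """
    return {
        "enable_verbose_logging": config.get("enable_verbose_logging", True),
        "total_steps": config.get("training_steps", 2_500_000),
        "steps_per_validation": config.get("validation_interval", 100_000),
        "steps_per_checkpoint": config.get("checkpoint_interval", 100_000),
    }


# Only `training_steps` is overridden; every other key uses its default.
settings = resolve_training_settings({"training_steps": 500_000})
```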
setup_model() -> nn.Module
abstractmethod
Set up the model.
setup_optimizer() -> torch.optim.Optimizer
abstractmethod
Set up the optimizer.
setup_decoder() -> Decoder
abstractmethod
Set up the decoder.
setup_data_processors() -> tuple[DataProcessor, DataProcessor]
abstractmethod
Set up the two data processors returned as a tuple.
save_model(is_best_checkpoint: bool = False) -> None
abstractmethod
Save the model.
forward(batch: Any) -> tuple[torch.Tensor, dict[str, torch.Tensor]]
abstractmethod
Forward pass for the model to calculate loss.
get_predictions(batch: Any) -> tuple[list[str] | list[list[str]], list[str] | list[list[str]]]
abstractmethod
Get the predictions for a batch.
convert_interval_to_steps(interval: float | int, steps_per_epoch: int) -> int
staticmethod
Convert an interval to steps.
| PARAMETER | DESCRIPTION |
|---|---|
| interval | The interval to convert. TYPE: float \| int |
| steps_per_epoch | The number of steps per epoch. TYPE: int |

| RETURNS | DESCRIPTION |
|---|---|
| int | The number of steps. |
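The reference does not say how a `float` interval differs from an `int` one. The sketch below assumes one plausible convention: a `float` is a fraction of an epoch and an `int` is already an absolute step count. Both the convention and the function name `interval_to_steps` are assumptions made for illustration, not the documented behavior:

```python
from typing import Union


def interval_to_steps(interval: Union[float, int], steps_per_epoch: int) -> int:
    """Convert an interval to a step count.

    Assumed convention (not confirmed by the reference):
    - float: multiple of an epoch, scaled by `steps_per_epoch`
    - int: absolute number of steps, returned unchanged
    """
    if isinstance(interval, float):
        return int(interval * steps_per_epoch)
    return interval


half_epoch = interval_to_steps(0.5, steps_per_epoch=1000)
fixed_steps = interval_to_steps(2000, steps_per_epoch=1000)
```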
log_if_verbose(message: str, level: str = 'info') -> None
Log a message if verbose logging is enabled.
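A minimal sketch of gated logging in this style, assuming the `level` string maps onto the standard `logging` levels and that unknown names fall back to `info`. The standalone function, its `enabled` parameter, and the logger name are hypothetical stand-ins for the method and the `enable_verbose_logging` attribute:

```python
import logging

logger = logging.getLogger("trainer_sketch")  # hypothetical logger name

# Assumed mapping from level names to standard logging levels.
LEVELS = {
    "debug": logging.DEBUG,
    "info": logging.INFO,
    "warning": logging.WARNING,
    "error": logging.ERROR,
}


def log_if_verbose(message: str, level: str = "info", *, enabled: bool = True) -> None:
    """Emit `message` at `level` only when verbose logging is enabled."""
    if enabled:
        # Unknown level names fall back to INFO in this sketch.
        logger.log(LEVELS.get(level, logging.INFO), message)
```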
setup_metrics() -> Metrics
Set up the metrics.
setup_accelerator() -> Accelerator
Set up the accelerator.
build_dataloaders(train_dataset: Dataset, valid_dataset: Dataset) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]
Build the training and validation dataloaders.
setup_scheduler() -> torch.optim.lr_scheduler.LRScheduler
Set up the learning rate scheduler.
| RETURNS | DESCRIPTION |
|---|---|
| LRScheduler | The learning rate scheduler. TYPE: torch.optim.lr_scheduler.LRScheduler |
setup_neptune() -> None
Set up Neptune logging.
setup_tensorboard() -> None
Set up TensorBoard logging.
load_datasets() -> tuple[Dataset, Dataset, int, int]
Load the training and validation datasets.
| RETURNS | DESCRIPTION |
|---|---|
| tuple[Dataset, Dataset, int, int] | The training and validation datasets, followed by the two integer values declared in the signature. |
print_sample_batch() -> None
Print a sample batch of the training data.
save_accelerator_state(is_best_checkpoint: bool = False) -> None
Save the accelerator state.
check_if_best_checkpoint() -> bool
Check if the last validation metric is the best metric.
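The comparison presumably involves `last_validation_metric` and `best_checkpoint_metric`, both of which start as `None`. A minimal sketch of such a check, assuming a missing best value means the first validated checkpoint wins and that the direction of improvement is configurable. The function name and the `higher_is_better` flag are hypothetical, and the reference does not say which direction the trainer uses:

```python
from typing import Optional


def is_best_metric(
    last: Optional[float],
    best: Optional[float],
    higher_is_better: bool = True,  # assumption: direction is not documented
) -> bool:
    """Return True when `last` improves on `best`."""
    if last is None:
        # No validation has run yet, so there is nothing to compare.
        return False
    if best is None:
        # First validated checkpoint is the best one seen so far.
        return True
    return last > best if higher_is_better else last < best
```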
load_accelerator_state() -> None
Load the accelerator state.
load_model_state() -> None
Load the model state.
update_model_state(model_state: dict[str, torch.Tensor], model_config: DictConfig) -> dict[str, torch.Tensor]
Update the model state.
update_vocab(model_state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]
Update the vocabulary of the model.
train() -> None
Train the model.
prepare_batch(batch: Iterable[Any]) -> Any
Prepare a batch for training.
Manually move tensors to accelerator.device since we do not prepare our dataloaders with the accelerator.
| PARAMETER | DESCRIPTION |
|---|---|
| batch | The batch to prepare. TYPE: Iterable[Any] |

| RETURNS | DESCRIPTION |
|---|---|
| Any | The prepared batch. |
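Because the dataloaders are not prepared by the accelerator, each tensor has to be moved explicitly. A minimal sketch of that idea, assuming a batch is a flat iterable whose tensor elements expose `.to(device)` while other elements (for example lists of strings) pass through unchanged. The function name `move_batch_to_device` and the duck-typed check are illustrative assumptions, not the trainer's implementation:

```python
from typing import Any, Iterable


def move_batch_to_device(batch: Iterable[Any], device: Any) -> tuple:
    """Move every element that exposes `.to(device)`; keep the rest as-is.

    Duck typing keeps this sketch free of a hard PyTorch dependency;
    `torch.Tensor` satisfies the `.to(device)` interface.
    """
    return tuple(
        item.to(device) if hasattr(item, "to") else item
        for item in batch
    )
```

With PyTorch and Accelerate, `device` would typically be `accelerator.device`, and each tensor element would come back as a tensor on that device.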
train_epoch() -> None
Train the model for one epoch.
validate_epoch(num_sanity_steps: int | None = None, calculate_metrics: bool = True) -> None
Validate for one epoch.