Train
train
logger = ColorLog(console, __name__).logger
module-attribute
CONFIG_PATH = Path(__file__).parent.parent / 'configs'
module-attribute
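As an illustration of how the `CONFIG_PATH` expression resolves, the sketch below applies the same `.parent.parent / 'configs'` chain to a made-up file location. The path `/opt/project/instanovo/transformer/train.py` is a hypothetical stand-in for `__file__`, not the real install location.

```python
from pathlib import PurePosixPath

# Hypothetical location of the training module (assumption for illustration only).
module_file = PurePosixPath("/opt/project/instanovo/transformer/train.py")

# Same expression as CONFIG_PATH, with __file__ replaced by the hypothetical path:
# the first .parent is the module's directory, the second is the package above it.
config_path = module_file.parent.parent / "configs"

print(config_path)  # /opt/project/instanovo/configs
```

The directory is resolved relative to the module file rather than the current working directory, so the result does not depend on where the script is launched from.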
TransformerTrainer(config: DictConfig)
Bases: AccelerateDeNovoTrainer
Trainer for the InstaNovo model.
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
instance-attribute
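With `ignore_index=0`, positions whose target is 0 do not contribute to the loss. The pure-Python sketch below mimics that averaging rule for illustration; it is not the PyTorch implementation, and `cross_entropy_ignore` is a hypothetical helper. Index 0 is often reserved for padding in sequence models, but this reference does not state what it represents here.

```python
import math

def cross_entropy_ignore(logits, targets, ignore_index=0):
    """Mean cross-entropy over positions whose target is not ignore_index."""
    total, count = 0.0, 0
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue  # ignored positions add nothing to the loss
        # log-sum-exp with the maximum subtracted for numerical stability
        m = max(row)
        log_norm = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_norm - row[target]  # negative log-probability of the target
        count += 1
    # With no scored positions there is nothing to average over.
    return total / count if count else float("nan")

logits = [[2.0, 0.5, 0.1], [0.3, 1.5, 0.2], [9.0, 0.0, 0.0]]
targets = [1, 2, 0]  # the last position has target 0 and is skipped

loss_all = cross_entropy_ignore(logits, targets)
loss_trimmed = cross_entropy_ignore(logits[:2], targets[:2])
assert math.isclose(loss_all, loss_trimmed)
```

The final assertion shows that dropping the ignored position entirely gives the same loss, which is the practical effect of ignoring an index.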
setup_model() -> nn.Module
Set up the model.
update_vocab(model_state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]
Update the vocabulary of the model.
setup_optimizer() -> torch.optim.Optimizer
Set up the optimizer.
setup_decoder() -> Decoder
Set up the decoder.
setup_data_processors() -> tuple[DataProcessor, DataProcessor]
Set up the data processors for the datasets.
add_checkpoint_state() -> dict[str, Any]
Return additional state to add to the checkpoint.
save_model(is_best_checkpoint: bool = False) -> None
Save the model.
forward(batch: Any) -> tuple[torch.Tensor, dict[str, torch.Tensor]]
Run a forward pass of the model and calculate the loss.
get_predictions(batch: Any) -> tuple[list[str] | list[list[str]], list[str] | list[list[str]]]
Get the predictions for a batch.
main(config: DictConfig) -> None
Train the model.