
Train

train

logger = ColorLog(console, __name__).logger module-attribute

CONFIG_PATH = Path(__file__).parent.parent / 'configs' module-attribute
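`CONFIG_PATH` goes up two directory levels from the module file and then appends `configs`. The sketch below applies the same expression to a hypothetical module path (`pkg/transformer/train.py` is illustrative, not the package's actual layout). It uses `PurePosixPath` so the example runs without touching the filesystem.

```python
from pathlib import PurePosixPath

# Hypothetical location of the train module; the real path depends on the install.
module_file = PurePosixPath("pkg/transformer/train.py")

# Same expression as CONFIG_PATH: .parent is pkg/transformer, .parent.parent is pkg.
config_path = module_file.parent.parent / "configs"
print(config_path)  # pkg/configs
```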

TransformerTrainer(config: DictConfig)

Bases: AccelerateDeNovoTrainer

Trainer for the InstaNovo model.

loss_fn = nn.CrossEntropyLoss(ignore_index=0) instance-attribute
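With `ignore_index=0`, every target position equal to 0 contributes nothing to the loss. Under PyTorch's default `reduction="mean"`, the result is averaged over the remaining positions only. The most likely reading is that index 0 is the padding token, but that is an assumption. The pure-Python sketch below reproduces these semantics for illustration. `cross_entropy_ignore` is a hypothetical helper, not part of the codebase.

```python
import math


def cross_entropy_ignore(
    logits: list[list[float]], targets: list[int], ignore_index: int = 0
) -> float:
    """Mean cross-entropy over positions whose target is not ignore_index.

    Hypothetical helper mirroring nn.CrossEntropyLoss(ignore_index=0) with the
    default mean reduction; returns NaN if every position is ignored.
    """
    total = 0.0
    count = 0
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue  # e.g. padding positions are skipped entirely
        # log-sum-exp with max subtraction for numerical stability
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[target]  # -log softmax(row)[target]
        count += 1
    return total / count if count else float("nan")
```

Changing the logits at an ignored position leaves the loss unchanged, because that position is never visited.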

setup_model() -> nn.Module

Set up the model.

update_vocab(model_state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]
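The reference does not say how the vocabulary is updated. One common approach, when a checkpoint's token set differs from the current one, is to remap the rows of vocabulary-sized weights (for example, embeddings or the output head) by token. Rows for shared tokens are kept, and rows for new tokens are freshly initialised. The sketch below shows only that idea. The function name, its arguments, the token names, and the zero initialisation are all hypothetical, and plain lists stand in for tensors.

```python
def remap_vocab_rows(
    old_rows: list[list[float]],
    old_vocab: list[str],
    new_vocab: list[str],
    init_value: float = 0.0,
) -> list[list[float]]:
    """Hypothetical sketch: reorder rows of a vocabulary-sized weight matrix.

    Tokens present in both vocabularies keep their learned row; tokens that
    appear only in new_vocab get a row filled with init_value.
    """
    width = len(old_rows[0]) if old_rows else 0
    old_index = {token: i for i, token in enumerate(old_vocab)}
    new_rows = []
    for token in new_vocab:
        if token in old_index:
            new_rows.append(list(old_rows[old_index[token]]))  # copy learned row
        else:
            new_rows.append([init_value] * width)  # token unseen in the checkpoint
    return new_rows
```

For example, `remap_vocab_rows([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], ["[PAD]", "A", "C"], ["[PAD]", "C", "M(+15.99)"])` keeps the `[PAD]` and `C` rows, drops `A`, and adds a zero row for the new token.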

Update the vocabulary of the model.

setup_optimizer() -> torch.optim.Optimizer

Set up the optimizer.

setup_decoder() -> Decoder

Set up the decoder.

setup_data_processors() -> tuple[DataProcessor, DataProcessor]

Set up the data processors for the datasets.

add_checkpoint_state() -> dict[str, Any]

Return additional state to include in the checkpoint.

save_model(is_best_checkpoint: bool = False) -> None

Save the model.

forward(batch: Any) -> tuple[torch.Tensor, dict[str, torch.Tensor]]

Run a forward pass of the model and calculate the loss. Returns the loss and a dictionary of tensors.

get_predictions(batch: Any) -> tuple[list[str] | list[list[str]], list[str] | list[list[str]]]

Get the predictions for a batch.
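According to its return type, `get_predictions` yields two sequences, each either a list of strings or a list of token lists. The reference does not say which element holds predictions and which holds targets. The sketch below shows one way to consume such a pair, computing exact-match peptide accuracy. Three things are assumptions here: that the pair is (predictions, targets) aligned by index, that token lists are equivalent to their residues joined with `""`, and the helper name `peptide_accuracy` itself.

```python
from __future__ import annotations


def peptide_accuracy(
    predictions: list[str] | list[list[str]],
    targets: list[str] | list[list[str]],
) -> float:
    """Hypothetical metric: fraction of predictions that exactly match their target.

    Accepts peptide strings or lists of residue tokens; token lists are
    joined with "" before comparison (an assumption about the token format).
    """

    def as_str(peptide: str | list[str]) -> str:
        return peptide if isinstance(peptide, str) else "".join(peptide)

    if len(predictions) != len(targets):
        raise ValueError("predictions and targets must have the same length")
    if not targets:
        return 0.0
    matches = sum(as_str(p) == as_str(t) for p, t in zip(predictions, targets))
    return matches / len(targets)
```

For example, `peptide_accuracy([["P", "E", "P"], ["A", "C"]], ["PEP", "AK"])` counts one match out of two.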

main(config: DictConfig) -> None

Train the model.