Coverage for instanovo/common/trainer.py: 44%
643 statements
coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
1from __future__ import annotations
3import datetime
4import logging
5import math
6import os
7import shutil
8import sys
9import traceback
10from abc import ABCMeta, abstractmethod
11from collections import Counter
12from datetime import timedelta
13from pathlib import Path
14from typing import Any, Dict, Iterable, List
16import neptune
17import numpy as np
18import pandas as pd
19import polars as pl
20import torch
21import torch.nn as nn
22from accelerate import Accelerator
23from accelerate.utils import DataLoaderConfiguration, InitProcessGroupKwargs, broadcast_object_list
24from datasets import Dataset, Value
25from datasets.utils.logging import disable_progress_bar
26from dotenv import load_dotenv
27from neptune.integrations.python_logger import NeptuneHandler
28from omegaconf import DictConfig, OmegaConf
29from sklearn.model_selection import train_test_split
30from torch.utils.tensorboard import SummaryWriter
32from instanovo.__init__ import console, set_rank
33from instanovo.common.dataset import DataProcessor
34from instanovo.common.scheduler import CosineWarmupScheduler, FinetuneScheduler, WarmupScheduler
35from instanovo.common.utils import (
36 NeptuneSummaryWriter,
37 Timer,
38 TrainingState,
39 _get_filepath_mapping,
40 _set_author_neptune_api_token,
41)
42from instanovo.constants import ANNOTATED_COLUMN, ANNOTATION_ERROR, SHUFFLE_BUFFER_SIZE, MSColumns
43from instanovo.inference import Decoder
44from instanovo.utils.colorlogging import ColorLog
45from instanovo.utils.data_handler import SpectrumDataFrame
46from instanovo.utils.device_handler import validate_and_configure_device
47from instanovo.utils.metrics import Metrics
48from instanovo.utils.residues import ResidueSet
49from instanovo.utils.s3 import S3FileHandler
51load_dotenv()
53# Automatic rank logger
54logger = ColorLog(console, __name__).logger
57class AccelerateDeNovoTrainer(metaclass=ABCMeta):
58 """Trainer class that uses the Accelerate library."""
60 @property
61 def run_id(self) -> str:
62 """Get the run ID.
64 Returns:
65 str: The run ID
66 """
67 return str(self._run_id)
69 @property
70 def s3(self) -> S3FileHandler:
71 """Get the S3 file handler.
73 Returns:
74 S3FileHandler: The S3 file handler
75 """
76 return self._s3
78 @property
79 def global_step(self) -> int:
80 """Get the current global training step.
82 This represents the total number of training steps across all epochs.
84 Returns:
85 int: The current global step number
86 """
87 return int(self._training_state.global_step)
89 @property
90 def epoch(self) -> int:
91 """Get the current training epoch.
93 This represents the current epoch number in the training process.
95 Returns:
96 int: The current epoch number
97 """
98 return int(self._training_state.epoch)
100 @property
101 def training_state(self) -> TrainingState:
102 """Get the training state."""
103 return self._training_state
105 def __init__(
106 self,
107 config: DictConfig,
108 ) -> None:
109 self.config = config
110 self.enable_verbose_logging = self.config.get("enable_verbose_logging", True)
111 if not self.config.get("enable_verbose_accelerate", True):
112 logging.getLogger("accelerate").setLevel(logging.WARNING)
114 # Hide progress bar from HF datasets
115 disable_progress_bar()
117 # Training state
118 # Keeps track of the global step and epoch
119 # Used for accelerate training state checkpointing
120 self._training_state = TrainingState()
122 self._run_id = self.config.get("run_name", "instanovo") + datetime.datetime.now().strftime("_%y_%m_%d_%H_%M")
124 self.accelerator = self.setup_accelerator()
126 self.log_if_verbose("Verbose logging enabled")
128 if self.accelerator.is_main_process:
129 logger.info(f"Config:\n{OmegaConf.to_yaml(self.config)}")
131 self.residue_set = ResidueSet(
132 residue_masses=self.config.residues.get("residues"),
133 residue_remapping=self.config.dataset.get("residue_remapping", None),
134 )
135 logger.info(f"Vocab: {self.residue_set.index_to_residue}")
137 # Initialise S3 file handler
138 self._s3: S3FileHandler = S3FileHandler(verbose=self.config.get("enable_verbose_s3", True))
140 self.train_dataset, self.valid_dataset, train_size, valid_size = self.load_datasets()
142 logger.info(f"Data loaded from {train_size:,} training samples and {valid_size:,} validation samples (unfiltered values)")
144 self.train_dataloader, self.valid_dataloader = self.build_dataloaders(self.train_dataset, self.valid_dataset)
145 logger.info("Data loaders built")
147 # Print sample batch
148 self.print_sample_batch()
150 logger.info("Setting up model...")
151 self.model = self.setup_model()
153 if self.accelerator.is_main_process:
154 logger.info(f"Model has {sum(p.numel() for p in self.model.parameters()):,d} parameters")
156 self.optimizer = self.setup_optimizer()
157 self.lr_scheduler = self.setup_scheduler()
159 self.decoder = self.setup_decoder()
160 self.metrics = self.setup_metrics()
162 # Optionally load a model state for fine-tuning
163 # Note: will be overwritten by the accelerator state if resuming
164 if self.config.get("resume_checkpoint_path", None) is not None:
165 self.load_model_state() # TODO check for loading on mps
167 # Prepare for accelerated training
168 (
169 self.model,
170 self.optimizer,
171 self.lr_scheduler,
172 self.train_dataloader,
173 self.valid_dataloader,
174 ) = self.accelerator.prepare(
175 self.model,
176 self.optimizer,
177 self.lr_scheduler,
178 self.train_dataloader,
179 self.valid_dataloader,
180 )
181 # Make sure the training state is checkpointed
182 self.accelerator.register_for_checkpointing(self._training_state)
184 # Optionally load states if resuming a training run
185 if self.config.get("resume_accelerator_state", None):
186 # Resuming from an existing run
187 self.load_accelerator_state()
189 # Setup logging
190 self.setup_neptune()
191 self.setup_tensorboard()
192 self._add_commit_message_to_monitoring_platform()
193 self._add_config_summary_to_monitoring_platform()
195 # Training control variables
196 self.running_loss = None
198 self.total_steps = self.config.get("training_steps", 2_500_000)
200 # Setup finetuning scheduler
201 if self.config.get("finetune", None):
202 self.finetune_scheduler: FinetuneScheduler | None = FinetuneScheduler(
203 self.model.state_dict(),
204 self.config.get("finetune"),
205 )
206 else:
207 self.finetune_scheduler = None
209 self.steps_per_validation = self.config.get("validation_interval", 100_000)
210 self.steps_per_checkpoint = self.config.get("checkpoint_interval", 100_000)
212 # Print training control variables
213 if self.accelerator.is_main_process:
214 steps_per_epoch = train_size // self.config["train_batch_size"]
215 logger.info("Training setup complete.")
216 logger.info(f" - Steps per validation: {self.steps_per_validation:,d} ")
217 logger.info(f" - Steps per checkpoint: {self.steps_per_checkpoint:,d} ")
218 logger.info(f" - Total training steps: {self.total_steps:,d}")
219 logger.info("Estimating steps per epoch based on unfiltered training set size:")
220 logger.info(f" - Estimated steps per epoch: {steps_per_epoch:,d}")
221 logger.info(f" - Estimated total epochs: {self.total_steps / steps_per_epoch:.1f}")
223 if self.total_steps < steps_per_epoch:
224 logger.warning("Total steps is less than estimated steps per epoch, this may result in less than one epoch during training")
226 if self.global_step > 0:
227 logger.info(f"Training will resume from epoch {self.epoch}, global_step {self.global_step}")
229 self.last_validation_metric = None
230 self.best_checkpoint_metric = None
232 # Final sync after setup
233 self.accelerator.wait_for_everyone()
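# A hedged sketch of the epoch estimate logged above, assuming 1,000,000 unfiltered training
# rows, a train_batch_size of 64 and 2,500,000 total training steps.
train_size, train_batch_size, total_steps = 1_000_000, 64, 2_500_000
steps_per_epoch = train_size // train_batch_size  # 15,625 steps per epoch
estimated_epochs = total_steps / steps_per_epoch  # ~160 epochs
assert steps_per_epoch == 15_625 and estimated_epochs == 160.0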
235 @abstractmethod
236 def setup_model(self) -> nn.Module:
237 """Setup the model."""
238 ...
240 @abstractmethod
241 def setup_optimizer(self) -> torch.optim.Optimizer:
242 """Setup the optimizer."""
243 ...
245 @abstractmethod
246 def setup_decoder(self) -> Decoder:
247 """Setup the decoder."""
248 ...
250 @abstractmethod
251 def setup_data_processors(self) -> tuple[DataProcessor, DataProcessor]:
252 """Setup the data processor."""
253 ...
255 @abstractmethod
256 def save_model(self, is_best_checkpoint: bool = False) -> None:
257 """Save the model."""
258 ...
260 @abstractmethod
261 def forward(self, batch: Any) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
262 """Forward pass for the model to calculate loss."""
263 ...
265 @abstractmethod
266 def get_predictions(self, batch: Any) -> tuple[list[str] | list[list[str]], list[str] | list[list[str]]]:
267 """Get the predictions for a batch."""
268 ...
270 @staticmethod
271 def convert_interval_to_steps(interval: float | int, steps_per_epoch: int) -> int:
272 """Convert an interval to steps.
274 Args:
275 interval (float | int): Fraction of an epoch (float) or an absolute number of steps (int).
276 steps_per_epoch (int): The number of steps per epoch.
278 Returns:
279 int: The number of steps.
280 """
281 if isinstance(interval, float):
282 return int(interval * steps_per_epoch)
283 if isinstance(interval, int):
284 return interval  # assumed: an int interval is already an absolute step count
285 raise ValueError(f"Invalid interval: {interval}")
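# Hedged usage sketch of the static helper above: a float interval is read as a fraction of
# an epoch, while an int is assumed to already be an absolute step count.
assert AccelerateDeNovoTrainer.convert_interval_to_steps(0.25, 10_000) == 2_500
assert AccelerateDeNovoTrainer.convert_interval_to_steps(500, 10_000) == 500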
286 def log_if_verbose(self, message: str, level: str = "info") -> None:
287 """Log a message if verbose logging is enabled."""
288 if self.enable_verbose_logging:
289 if level == "info":
290 logger.info(message)
291 elif level == "warning":
292 logger.warning(message)
293 elif level == "error":
294 logger.error(message)
295 elif level == "debug":
296 logger.debug(message)
297 else:
298 raise ValueError(f"Invalid level: {level}")
300 def setup_metrics(self) -> Metrics:
301 """Setup the metrics."""
302 return Metrics(self.residue_set, self.config.get("max_isotope_error", 1))
304 def setup_accelerator(self) -> Accelerator:
305 """Setup the accelerator."""
306 timeout = timedelta(seconds=self.config.get("timeout", 3600))
307 validate_and_configure_device(self.config)
308 accelerator = Accelerator(
309 cpu=self.config.get("force_cpu", False),
310 mixed_precision="fp16" if torch.cuda.is_available() and not self.config.get("force_cpu", False) else "no",
311 gradient_accumulation_steps=self.config.get("grad_accumulation", 1),
312 dataloader_config=DataLoaderConfiguration(split_batches=True),
313 kwargs_handlers=[InitProcessGroupKwargs(timeout=timeout)],
314 )
316 device = accelerator.device # Important, this forces ranks to choose a device.
318 if accelerator.num_processes > 1:
319 set_rank(accelerator.local_process_index)
321 if accelerator.is_main_process:
322 logger.info(f"Python version: {sys.version}")
323 logger.info(f"Torch version: {torch.__version__}")
324 logger.info(f"CUDA version: {torch.version.cuda}")
325 logger.info(f"Training with {accelerator.num_processes} devices")
326 logger.info(f"Per-device batch size: {self.config['train_batch_size']}")
327 logger.info(f"Gradient accumulation steps: {self.config['grad_accumulation']}")
328 effective_batch_size = self.config["train_batch_size"] * accelerator.num_processes * self.config["grad_accumulation"]
329 logger.info(f"Effective batch size: {effective_batch_size}")
331 logger.info(f"Using device: {device}")
333 return accelerator
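# Sketch of the effective batch size arithmetic logged above, assuming 2 processes,
# a per-device train_batch_size of 32 and grad_accumulation of 4.
train_batch_size, num_processes, grad_accumulation = 32, 2, 4
effective_batch_size = train_batch_size * num_processes * grad_accumulation
assert effective_batch_size == 256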
335 def build_dataloaders(self, train_dataset: Dataset, valid_dataset: Dataset) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
336 """Setup the dataloaders."""
337 train_processor, valid_processor = self.setup_data_processors()
339 valid_processor.add_metadata_columns(["prediction_id"])
340 if self.using_validation_groups:
341 valid_processor.add_metadata_columns(["validation_group"])
343 if self.config.get("use_shuffle_buffer", True):
344 buffer_size = self.config.get("shuffle_buffer_size", SHUFFLE_BUFFER_SIZE)
345 train_dataset = train_dataset.shuffle(buffer_size=buffer_size, seed=42)
347 train_dataset = train_dataset.map(
348 train_processor.process_row,
349 )
350 valid_dataset = valid_processor.process_dataset(valid_dataset)
352 pin_memory = self.config.get("pin_memory", False)
353 if self.accelerator.device == torch.device("cpu") or self.config.get("mps", False):
354 pin_memory = False
355 # Scale batch size by number of processes when using split_batches
356 train_dataloader = torch.utils.data.DataLoader(
357 train_dataset,
358 batch_size=self.config["train_batch_size"] * self.accelerator.num_processes,
359 collate_fn=train_processor.collate_fn,
360 num_workers=self.config.get("num_workers", 8),
361 pin_memory=pin_memory,
362 prefetch_factor=self.config.get("prefetch_factor", None),
363 drop_last=True,
364 )
365 valid_dataloader = torch.utils.data.DataLoader(
366 valid_dataset,
367 batch_size=self.config["predict_batch_size"] * self.accelerator.num_processes,
368 collate_fn=valid_processor.collate_fn,
369 num_workers=self.config.get("num_workers", 8),
370 pin_memory=pin_memory,
371 prefetch_factor=self.config.get("prefetch_factor", None),
372 drop_last=False,
373 )
374 return train_dataloader, valid_dataloader
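# With DataLoaderConfiguration(split_batches=True) the dataloader yields one global batch that
# Accelerate splits across processes, which is why the batch sizes above are multiplied by
# num_processes. A minimal sketch of that arithmetic, assuming 4 processes:
train_batch_size, num_processes = 64, 4
global_batch = train_batch_size * num_processes   # rows yielded by the DataLoader
per_device_batch = global_batch // num_processes  # rows each rank sees after splitting
assert per_device_batch == train_batch_size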
376 def setup_scheduler(self) -> torch.optim.lr_scheduler.LRScheduler:
377 """Setup the learning rate scheduler.
379 Returns:
380 torch.optim.lr_scheduler.LRScheduler: The learning rate scheduler
381 """
382 # Note: if split_batches is False, the scheduler will be called num_processes times
383 # in each optimizer step. Therefore, we need to scale the scheduler steps by num_processes.
384 # Default is split_batches is True
385 if self.config.get("lr_scheduler", "warmup") == "warmup":
386 warmup_steps = self.config.get("warmup_iters", 1000) # * num_processes
387 return WarmupScheduler(self.optimizer, warmup_steps)
388 elif self.config.get("lr_scheduler", None) == "cosine":
389 # Scale max_iters based on accumulation
390 # train_dataloader is already scaled by num_processes
391 max_iters = self.config.get("training_steps", 2_500_000) # * num_processes
392 warmup_steps = self.config.get("warmup_iters", 1000) # * num_processes
393 return CosineWarmupScheduler(self.optimizer, warmup_steps, max_iters)
394 else:
395 raise ValueError(f"Unknown lr_scheduler type '{self.config.get('lr_scheduler', None)}'")
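# Hedged sketch of a typical linear-warmup + cosine-decay learning-rate factor; the exact
# shape of this project's CosineWarmupScheduler may differ, this only illustrates the two
# config knobs used above (warmup_iters and training_steps).
import math

def lr_factor(step: int, warmup_steps: int, max_iters: int) -> float:
    if step < warmup_steps:
        return step / max(1, warmup_steps)  # linear warmup from 0 to 1
    progress = (step - warmup_steps) / max(1, max_iters - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))  # cosine decay to 0

assert lr_factor(0, 1_000, 10_000) == 0.0
assert abs(lr_factor(10_000, 1_000, 10_000)) < 1e-9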
397 def setup_neptune(self) -> None:
398 """Setup the neptune."""
399 if not self.accelerator.is_main_process:
400 self.neptune_run = None
401 return
403 if not self.config.get("use_neptune", True):
404 self.neptune_run = None
405 return
407 _set_author_neptune_api_token()
408 try:
409 self.neptune_run = neptune.init_run(
410 with_id=None,
411 name=self.run_id,
412 dependencies=str(Path(__file__).parent.parent.parent / "uv.lock"),
413 tags=OmegaConf.to_object(self.config.get("tags", [])),
414 )
415 self.neptune_run.assign({"config": OmegaConf.to_yaml(self.config)})
416 logger.addHandler(NeptuneHandler(run=self.neptune_run))
417 except Exception as e:
418 logger.warning(f"Failed to initialise neptune: {e}")
419 self.neptune_run = None
421 def setup_tensorboard(self) -> None:
422 """Setup the tensorboard."""
423 if not self.accelerator.is_main_process:
424 self.sw = None
425 return
427 if S3FileHandler.register_tb():
428 logs_path = os.environ["AICHOR_LOGS_PATH"]
429 else:
430 logs_path = self.config.get("tb_summarywriter", "runs") + self.run_id
432 if self.neptune_run is not None:
433 self.sw = NeptuneSummaryWriter(logs_path, self.neptune_run)
434 else:
435 self.sw = SummaryWriter(logs_path)
436 logger.info(f"TensorBoard logs will be saved to {logs_path}")
438 @property
439 def _is_main_process_on_aichor(self) -> bool:
440 """Return True if monitoring logging is configured and this is the main process."""
441 return ("AICHOR_LOGS_PATH" in os.environ) and self.accelerator.is_main_process
443 def _add_commit_message_to_monitoring_platform(self, commit_id_length: int = 7) -> None:
444 """Add the git commit message to the monitoring platform."""
445 if not self._is_main_process_on_aichor:
446 logger.debug("Skipping config summary upload to Neptune: 'AICHOR_LOGS_PATH' not set or current process is not the main AICHOR process")
447 return
449 try:
450 # Remove 'exp:' prefix if present in the commit message and also remove the space after it.
451 git_commit_msg = os.environ["VCS_COMMIT_MESSAGE"].removeprefix("exp:").removeprefix(" ")
452 commit_short_hash = os.environ["VCS_SHA"][:commit_id_length]
453 self.sw.add_text( # type: ignore[union-attr]
454 "git/commit_message", f"{git_commit_msg} ({commit_short_hash})"
455 )
456 except (AttributeError, KeyError) as exc:
457 logger.warning("Failed to write config summary to the monitoring plateform", exc_info=exc)
459 def _add_config_summary_to_monitoring_platform(self) -> None:
460 """Add the config summary to the monitoring platform."""
461 if not self._is_main_process_on_aichor:
462 logger.debug("Skipping config summary upload to Neptune: 'AICHOR_LOGS_PATH' not set or current process is not the main AICHOR process")
463 return
464 # https://github.com/pytorch/pytorch/blob/daca611465c93ac6b8147e6b7070ce2b4254cfc5/torch/utils/tensorboard/summary.py#L244 # noqa
465 self.sw.add_hparams( # type: ignore[union-attr]
466 {k: v for k, v in self.config.items() if isinstance(v, (int, float, str))}, {}
467 )
469 def load_datasets(self) -> tuple[Dataset, Dataset, int, int]:
470 """Load the training and validation datasets.
472 Returns:
473 tuple[Dataset, Dataset, int, int]:
474 The training and validation datasets and their unfiltered sizes
475 """
476 validation_group_mapping = None
477 dataset_config = self.config.get("dataset", {})
478 try:
479 logger.info("Loading training dataset...")
480 train_sdf = SpectrumDataFrame.load(
481 source=dataset_config.get("train_path"),
482 source_type=dataset_config.get("source_type", "default"),
483 lazy=dataset_config.get("lazy_loading", True),
484 is_annotated=True,
485 shuffle=True,
486 partition=dataset_config.get("train_partition", None),
487 column_mapping=dataset_config.get("column_remapping", None),
488 max_shard_size=dataset_config.get("max_shard_size", 100_000),
489 preshuffle_across_shards=dataset_config.get("preshuffle_shards", False),
490 verbose=dataset_config.get("verbose_loading", True),
491 )
493 valid_path = dataset_config.get("valid_path", None)
494 if valid_path is not None:
495 if OmegaConf.is_dict(valid_path):
496 logger.info("Found grouped validation datasets.")
497 validation_group_mapping = _get_filepath_mapping(valid_path)
498 _valid_path = list(valid_path.values())
499 else:
500 _valid_path = valid_path
501 else:
502 _valid_path = dataset_config.get("train_path")
504 logger.info("Loading validation dataset...")
505 valid_sdf = SpectrumDataFrame.load(
506 _valid_path,
507 lazy=dataset_config.get("lazy_loading", True),
508 is_annotated=True,
509 shuffle=False,
510 partition=dataset_config.get("valid_partition", None),
511 column_mapping=dataset_config.get("column_remapping", None),
512 max_shard_size=dataset_config.get("max_shard_size", 100_000),
513 add_source_file_column=True, # used to track validation groups
514 verbose=dataset_config.get("verbose_loading", True),
515 )
516 except ValueError as e:
517 # More descriptive error message in predict mode.
518 if str(e) == ANNOTATION_ERROR:
519 raise ValueError("The sequence column is missing annotations, are you trying to run de novo prediction? Add the --denovo flag") from e
520 else:
521 raise
523 # Split data if needed
524 if dataset_config.get("valid_path", None) is None:
525 logger.info("Validation path not specified, generating from training set.")
526 sequences = list(train_sdf.get_unique_sequences())
527 sequences = sorted({DataProcessor.remove_modifications(x) for x in sequences})
528 train_unique, valid_unique = train_test_split(
529 sequences,
530 test_size=dataset_config.get("valid_subset_of_train"),
531 random_state=42,
532 )
533 train_unique = set(train_unique)
534 valid_unique = set(valid_unique)
536 train_sdf.filter_rows(lambda row: DataProcessor.remove_modifications(row[ANNOTATED_COLUMN]) in train_unique)
537 valid_sdf.filter_rows(lambda row: DataProcessor.remove_modifications(row[ANNOTATED_COLUMN]) in valid_unique)
539 # Save splits
540 if self.accelerator.is_main_process:
541 split_path = os.path.join(self.config.get("model_save_folder_path", "./checkpoints"), "splits.csv")
542 os.makedirs(os.path.dirname(split_path), exist_ok=True)
543 splits_df = pd.DataFrame(
544 {
545 ANNOTATED_COLUMN: list(train_unique) + list(valid_unique),
546 "split": ["train"] * len(train_unique) + ["valid"] * len(valid_unique),
547 }
548 )
549 self.s3.upload_to_s3_wrapper(splits_df.to_csv, split_path, index=False)
550 logger.info(f"Data splits saved to {split_path}")
552 train_ds = train_sdf.to_dataset(force_unified_schema=True)
553 valid_ds = valid_sdf.to_dataset(in_memory=True)
555 # Sample a validation subset if needed
556 valid_subset = self.config.get("valid_subset", 1.0)
557 if valid_subset < 1.0:
558 valid_ds = valid_ds.train_test_split(test_size=valid_subset, seed=42)["test"]
560 # Check residues
561 if self.config.get("perform_data_checks", True):
562 logger.info(f"Checking for unknown residues in {len(train_sdf) + len(valid_sdf):,d} rows.")
563 supported_residues = set(self.residue_set.vocab)
564 supported_residues.update(set(self.residue_set.residue_remapping.keys()))
565 data_residues = set()
566 data_residues.update(train_sdf.get_vocabulary(self.residue_set.tokenize))
567 data_residues.update(valid_sdf.get_vocabulary(self.residue_set.tokenize))
568 if len(data_residues - supported_residues) > 0:
569 logger.warning(f"Found {len(data_residues - supported_residues):,d} unsupported residues! These rows will be dropped.")
570 self.log_if_verbose(f"New residues found: \n{data_residues - supported_residues}")
571 self.log_if_verbose(f"Residues supported: \n{supported_residues}")
573 train_ds = train_ds.filter(
574 lambda row: all(residue in supported_residues for residue in set(self.residue_set.tokenize(row[ANNOTATED_COLUMN])))
575 )
576 valid_ds = valid_ds.filter(
577 lambda row: all(residue in supported_residues for residue in set(self.residue_set.tokenize(row[ANNOTATED_COLUMN])))
578 )
580 logger.info("Checking charge values...")
581 # Check charge values
582 precursor_charge_col = MSColumns.PRECURSOR_CHARGE.value
584 if not train_sdf.check_values(1, self.config.get("max_charge", 10), precursor_charge_col):
585 logger.warning("Found charge values out of range in training set. These rows will be dropped.")
587 train_ds = train_ds.filter(
588 lambda row: (row[precursor_charge_col] <= self.config.get("max_charge", 10)) and (row[precursor_charge_col] > 0)
589 )
591 if not valid_sdf.check_values(1, self.config.get("max_charge", 10), precursor_charge_col):
592 logger.warning("Found charge values out of range in validation set. These rows will be dropped.")
593 valid_ds = valid_ds.filter(
594 lambda row: (row[precursor_charge_col] <= self.config.get("max_charge", 10)) and (row[precursor_charge_col] > 0)
595 )
597 # Create validation groups
598 # Initialize validation groups if needed
599 if validation_group_mapping is not None:
600 logger.info("Computing validation groups.")
601 validation_groups = [validation_group_mapping.get(row.get("source_file"), "no_group") for row in valid_ds]
602 valid_ds = valid_ds.add_column("validation_group", validation_groups)
604 logger.info("Sequences per validation group:")
605 group_counts = Counter(validation_groups)
606 for group, count in group_counts.items():
607 logger.info(f" - {group}: {count:,d}")
609 self.using_validation_groups = True
610 else:
611 self.using_validation_groups = False
613 # Force add a unique prediction_id column
614 # This will be used to order predictions and remove duplicates
615 valid_ds = valid_ds.add_column("prediction_id", np.arange(len(valid_ds)), feature=Value("int32"))
617 # Keep track of the train_sdf directory so it isn't garbage collected
618 self._train_sdf = train_sdf
620 return train_ds, valid_ds, len(train_sdf), len(valid_sdf)
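# Minimal sketch of the leakage-safe split above, assuming remove_modifications strips
# modification tokens: the split is made on unique unmodified sequences, then rows are
# filtered, so the same peptide can never land in both train and valid.
from sklearn.model_selection import train_test_split

unique_sequences = sorted({"PEPTIDE", "PROTEIN", "SEQUENCE", "ANTIBODY"})
train_unique, valid_unique = train_test_split(unique_sequences, test_size=0.25, random_state=42)
assert set(train_unique).isdisjoint(valid_unique)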
622 def print_sample_batch(self) -> None:
623 """Print a sample batch of the training data."""
624 if self.accelerator.is_main_process:
626 sample_batch = next(iter(self.train_dataloader))
627 logger.info("Sample batch:")
628 for key, value in sample_batch.items():
629 if isinstance(value, torch.Tensor):
630 value_shape = value.shape
631 value_type = value.dtype
632 else:
633 value_shape = len(value)
634 value_type = type(value)
636 logger.info(f" - {key}: {value_type}, {value_shape}")
638 def save_accelerator_state(self, is_best_checkpoint: bool = False) -> None:
639 """Save the accelerator state."""
640 checkpoint_dir = self.config.get("model_save_folder_path", "./checkpoints")
642 if self.config.get("keep_accelerator_every_interval", False):
643 checkpoint_path = os.path.join(
644 checkpoint_dir,
645 "accelerator_state",
646 f"epoch_{self.epoch}_step_{self.global_step + 1}",
647 )
648 else:
649 checkpoint_path = os.path.join(checkpoint_dir, "accelerator_state", "latest")
650 if self.accelerator.is_main_process and Path(checkpoint_path).exists() and Path(checkpoint_path).is_dir():
651 shutil.rmtree(checkpoint_path)
653 if self.accelerator.is_main_process:
654 os.makedirs(checkpoint_path, exist_ok=True)
656 self.accelerator.save_state(checkpoint_path)
658 logger.info(f"Saved accelerator state to {checkpoint_path}")
660 if self.accelerator.is_main_process and S3FileHandler._aichor_enabled():
661 for file in os.listdir(checkpoint_path):
662 self.s3.upload(
663 os.path.join(checkpoint_path, file),
664 S3FileHandler.convert_to_s3_output(os.path.join(checkpoint_path, file)),
665 )
667 # Save best checkpoint and upload to S3
668 if is_best_checkpoint and self.accelerator.is_main_process:
669 best_checkpoint_path = os.path.join(checkpoint_dir, "accelerator_state", "best")
670 if Path(best_checkpoint_path).exists() and Path(best_checkpoint_path).is_dir():
671 shutil.rmtree(best_checkpoint_path)
673 os.makedirs(best_checkpoint_path, exist_ok=True)
675 for file in os.listdir(checkpoint_path):
676 full_file = os.path.join(checkpoint_path, file)
677 best_file = os.path.join(best_checkpoint_path, file)
678 shutil.copy(full_file, best_file)
679 if S3FileHandler._aichor_enabled():
680 self.s3.upload(
681 full_file,
682 S3FileHandler.convert_to_s3_output(best_file),
683 )
685 def check_if_best_checkpoint(self) -> bool:
686 """Check if the last validation metric is the best metric."""
687 if self.config.get("checkpoint_metric", None) is None:
688 return False
690 if self.best_checkpoint_metric is None:
691 self.best_checkpoint_metric = self.last_validation_metric
692 return True
694 if self.config.get("checkpoint_metric_mode", "min") == "min":
695 is_best = self.last_validation_metric <= self.best_checkpoint_metric
696 elif self.config.get("checkpoint_metric_mode", "min") == "max":
697 is_best = self.last_validation_metric >= self.best_checkpoint_metric
698 else:
699 raise ValueError(f"Unknown checkpoint metric mode: {self.config.get('checkpoint_metric_mode', 'min')}")
701 if is_best:
702 self.best_checkpoint_metric = self.last_validation_metric
704 return is_best
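# Hedged sketch of the best-checkpoint rule above: in "min" mode a metric less than or equal
# to the best so far wins, in "max" mode greater than or equal wins; the first value always wins.
def is_best(new: float, best: float | None, mode: str = "min") -> bool:
    if best is None:
        return True
    return new <= best if mode == "min" else new >= best

assert is_best(0.4, None)
assert is_best(0.3, 0.4, "min") and not is_best(0.5, 0.4, "min")
assert is_best(0.9, 0.8, "max")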
706 def load_accelerator_state(self) -> None:
707 """Load the accelerator state."""
708 checkpoint_path = self.config.get("resume_accelerator_state", None)
709 if checkpoint_path is None:
710 return
712 if not os.path.isdir(checkpoint_path) and not checkpoint_path.startswith("s3://"):
713 raise ValueError(f"Accelerator state should be a directory of state files, got {checkpoint_path}")
715 if S3FileHandler._aichor_enabled() and checkpoint_path.startswith("s3://"):
716 # raise NotImplementedError("Loading accelerator state from S3 is not implemented.")
718 if self.accelerator.is_main_process:
719 local_path = os.path.join(self.s3.temp_dir.name, "accelerator_state")
720 os.makedirs(local_path, exist_ok=True)
721 logger.info(f"Downloading checkpoint files from {checkpoint_path} to {local_path}")
723 # Download all files from the checkpoint folder
724 checkpoint_files = self.s3.listdir(checkpoint_path)
725 logger.info(f"Found {len(checkpoint_files)} files")
726 for file in checkpoint_files:
727 if file.endswith("/"): # Skip subdirectories
728 continue
729 local_file = os.path.join(local_path, os.path.basename(file))
730 self.s3.download(f"s3://{file}", local_file)
731 else:
732 local_path = None
734 checkpoint_path = broadcast_object_list([local_path])[0]
735 logger.info(f"Received checkpoint path: {checkpoint_path}")
737 assert checkpoint_path is not None, "Failed to broadcast accelerator state across ranks"
739 # Add safe globals
740 torch.serialization.add_safe_globals(
741 [
742 np._core.multiarray.scalar,
743 np.dtypes.Float64DType,
744 ]
745 )
747 self.accelerator.load_state(checkpoint_path)
748 logger.info(f"Loaded accelerator state from {checkpoint_path}")
750 def load_model_state(self) -> None:
751 """Load the model state."""
752 checkpoint_path = self.config.get("resume_checkpoint_path", None)
753 if checkpoint_path is None:
754 return
756 if os.path.isdir(checkpoint_path) and not checkpoint_path.startswith("s3://"):
757 raise ValueError(f"Checkpoint path should be a file, got {checkpoint_path}")
759 if self.accelerator.is_main_process:
760 logger.info(f"Resuming model state from {checkpoint_path}")
761 local_path = self.s3.get_local_path(checkpoint_path)
762 else:
763 local_path = None
765 local_path = broadcast_object_list([local_path])[0]
767 assert local_path is not None, "Failed to broadcast model state across ranks"
769 # TODO: Switch to model.load(), implement model schema
770 model_data = torch.load(local_path, weights_only=False, map_location="cpu")
771 # TODO: Remove, only use state_dict
772 if "model" in model_data:
773 model_state = model_data["model"]
774 else:
775 model_state = model_data["state_dict"]
777 # Check residues
778 if "residues" in model_data:
779 model_residues = dict(model_data["residues"].get("residues", {}))
780 else:
781 # Legacy format
782 model_residues = dict(model_data["config"]["residues"])
784 current_residues = self.config.residues.get("residues")
785 if model_residues != current_residues:
786 logger.warning(
787 f"Checkpoint residues do not match current residues.\nCheckpoint residues: {model_residues}\nCurrent residues: {current_residues}"
788 )
789 logger.warning("Updating model state to match current residues.")
790 model_state = self.update_vocab(model_state)
792 model_state = self.update_model_state(model_state, model_data["config"])
793 self.model.load_state_dict(model_state, strict=False)
794 logger.info(f"Loaded model state from {local_path}")
796 def update_model_state(self, model_state: dict[str, torch.Tensor], model_config: DictConfig) -> dict[str, torch.Tensor]:
797 """Update the model state."""
798 return model_state
800 def update_vocab(self, model_state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
801 """Update the vocabulary of the model."""
802 # This should call `self._update_vocab` based on model implementation.
803 raise NotImplementedError("Updating vocabulary is not implemented for the base trainer.")
805 def _update_vocab(
806 self,
807 model_state: dict[str, torch.Tensor],
808 target_layers: list[str],
809 resolution: str = "delete",
810 ) -> dict[str, torch.Tensor]:
811 """Update the target heads of the model."""
812 target_vocab_size = len(self.residue_set)
813 current_model_state = self.model.state_dict()
814 hidden_size = self.config.model.get("dim_model", 768)
816 for layer in target_layers:
817 if layer not in current_model_state:
818 logger.warning(f"Layer {layer} not found in current model state.")
819 continue
820 tmp = torch.normal(
821 mean=0,
822 std=1.0 / np.sqrt(hidden_size),
823 size=current_model_state[layer].shape,
824 dtype=current_model_state[layer].dtype,
825 )
826 if "bias" in layer:
827 # initialise bias to zeros
828 tmp = torch.zeros_like(tmp)
830 if resolution == "delete":
831 del model_state[layer]
832 elif resolution == "random":
833 model_state[layer] = tmp
834 elif resolution == "partial":
835 tmp[:target_vocab_size] = model_state[layer][: min(tmp.shape[0], target_vocab_size)]
836 model_state[layer] = tmp
837 else:
838 raise ValueError(f"Unknown residue_conflict_resolution type '{resolution}'")
839 return model_state
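# Sketch of the "partial" resolution above, assuming a projection head of shape
# (vocab_size, hidden): the checkpoint rows for the overlapping vocabulary are copied and the
# remaining rows keep their fresh random initialisation.
import torch

old_head = torch.randn(5, 8)          # checkpoint head, old vocabulary of 5 residues
new_head = torch.randn(7, 8) * 0.02   # freshly initialised head, new vocabulary of 7 residues
new_head[:5] = old_head[: min(new_head.shape[0], 5)]
assert torch.equal(new_head[:5], old_head)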
841 def train(self) -> None:
842 """Train the model."""
843 num_sanity_steps = self.config.get("num_sanity_val_steps", 0)
844 if num_sanity_steps > 0:
845 logger.info(f"Running sanity validation for {num_sanity_steps} steps...")
846 self.validate_epoch(num_sanity_steps=num_sanity_steps, calculate_metrics=False)
847 logger.info("Sanity validation complete.")
849 if self.config.get("validate_before_training", False):
850 logger.info("Running pre-validation...")
851 self.validate_epoch()
852 logger.info("Pre-validation complete.")
854 self.train_timer = Timer(self.total_steps)
855 is_first_epoch = True
856 logger.info("Starting training...")
857 while self.global_step < self.total_steps:
858 self.train_epoch()
859 self.training_state.step_epoch()
860 if self.accelerator.is_main_process and is_first_epoch:
861 is_first_epoch = False
862 logger.info("First epoch complete:")
863 logger.info(f"- Actual steps per epoch: {self.global_step}")
864 logger.info(f"- Actual total epochs: {self.total_steps / self.global_step:.1f}")
866 logger.info("Training complete.")
868 def prepare_batch(self, batch: Iterable[Any]) -> Any:
869 """Prepare a batch for training.
871 Manually move tensors to accelerator.device since we do not
872 prepare our dataloaders with the accelerator.
874 Args:
875 batch (Iterable[Any]): The batch to prepare.
877 Returns:
878 Any: The prepared batch
879 """
880 if isinstance(batch, dict):
881 return {k: v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
882 elif isinstance(batch, (list, tuple)):
883 return [v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v for v in batch]
884 else:
885 raise ValueError(f"Unsupported batch type: {type(batch)}")
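# Usage sketch of prepare_batch: tensors are moved to the accelerator device while everything
# else (e.g. lists of strings) passes through unchanged. Illustrated on CPU for portability.
import torch

device = torch.device("cpu")
batch = {"spectra": torch.zeros(2, 10), "sequence": ["PEPTIDE", "PROTEIN"]}
moved = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
assert moved["spectra"].device == device and moved["sequence"] == batch["sequence"]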
887 def train_epoch(self) -> None:
888 """Train the model for one epoch."""
889 total_loss = 0
891 self.model.train()
892 self.optimizer.zero_grad()
893 self.running_loss = None
895 epoch_timer = Timer()
897 print_batch_size = True
898 for batch_count, batch in enumerate(self.train_dataloader):
899 if print_batch_size:
900 # Confirm batch size during debugging
901 self.log_if_verbose(f"Batch {batch_count} shape: {batch['spectra'].shape[0]}")
902 print_batch_size = False
904 with self.accelerator.accumulate(self.model):
905 # Forward pass
906 loss, loss_components = self.forward(batch)
908 # Check for NaN/Inf in loss immediately after forward pass
909 loss_value = loss.item()
910 if math.isnan(loss_value) or math.isinf(loss_value):
911 error_msg = (
912 f"Invalid loss value detected: {loss_value} (NaN: {math.isnan(loss_value)}, Inf: {math.isinf(loss_value)}). "
913 f"This occurred at step {self.global_step + 1}, epoch {self.epoch}, batch {batch_count}. "
914 f"This indicates a serious training problem (e.g., exploding gradients, division by zero, numerical instability). "
915 f"Stopping training to prevent further issues.\n\n"
916 f"Loss components: {[f'{k}={v.item()}' for k, v in loss_components.items()]}\n\n"
917 f"Traceback showing where the loss was computed:\n"
918 )
919 stack_trace = traceback.format_stack()
920 # Show frames from the forward pass
921 relevant_frames = stack_trace[:-3][-8:] # Show more frames to see the forward pass
922 error_msg += "".join(relevant_frames)
923 raise ValueError(error_msg)
925 # Backward pass
926 self.accelerator.backward(loss)
928 # Update weights
929 if self.accelerator.sync_gradients:
930 self.accelerator.clip_grad_norm_(self.model.parameters(), self.config.get("gradient_clip_val", 10.0))
931 self.optimizer.step()
932 self.lr_scheduler.step()
933 self.optimizer.zero_grad()
935 # Update timer
936 self.train_timer.step()
938 # Update running loss
939 # Exponentially weighted moving average to smooth noisy batch losses
940 if self.running_loss is None:
941 self.running_loss = loss.item()
942 else:
943 self.running_loss = 0.99 * self.running_loss + 0.01 * loss.item()
945 total_loss += loss.item()
947 # Log progress
948 if (self.global_step + 1) % int(self.config.get("console_logging_steps", 2000)) == 0:
949 lr = self.lr_scheduler.get_last_lr()[0]
951 logger.info(
952 f"[TRAIN] "
953 f"[Epoch {self.epoch:02d}] "
954 f"[Step {self.global_step + 1:06d}/{self.total_steps:06d}] "
955 f"[{self.train_timer.get_time_str()}/{self.train_timer.get_eta_str()}, "
956 f"{self.train_timer.get_step_time_rate_str()}]: "
957 f"train_loss_raw={loss.item():.4f}, "
958 f"running_loss={self.running_loss:.4f}, LR={lr:.6f}"
959 )
961 # Log to tensorboard
962 if (
963 self.accelerator.is_main_process
964 and self.sw is not None
965 and (self.global_step + 1) % int(self.config.get("tensorboard_logging_steps", 500)) == 0
966 ):
967 lr = self.lr_scheduler.get_last_lr()[0]
968 self.sw.add_scalar("train/loss_raw", loss.item(), self.global_step + 1)
969 if self.running_loss is not None:
970 self.sw.add_scalar("train/loss_smooth", self.running_loss, self.global_step + 1)
971 for k, v in loss_components.items():
972 if k == "loss":
973 continue
974 self.sw.add_scalar(f"train/{k}", v.item(), self.global_step + 1)
975 self.sw.add_scalar("optim/lr", lr, self.global_step + 1)
976 self.sw.add_scalar("optim/epoch", self.epoch, self.global_step + 1)
978 if (self.global_step + 1) % self.steps_per_validation == 0:
979 self.model.eval()
980 self.validate_epoch()
981 logger.info("Validation complete, resuming training...")
982 self.model.train()
984 if (self.global_step + 1) % self.steps_per_checkpoint == 0:
985 is_best_checkpoint = self.check_if_best_checkpoint()
986 self.save_model(is_best_checkpoint)
987 if self.config.get("save_accelerator_state", False):
988 self.save_accelerator_state(is_best_checkpoint)
990 self.training_state.step()
992 # Update finetuning scheduler
993 if self.finetune_scheduler is not None:
994 self.finetune_scheduler.step(self.global_step)
996 if self.global_step >= self.total_steps:
997 break
999 # Epoch complete
1000 self.accelerator.wait_for_everyone()
1002 epoch_timer.step()
1004 # Gather losses from all devices
1005 gathered_losses = self.accelerator.gather_for_metrics(torch.tensor(total_loss, device=self.accelerator.device))
1006 gathered_num_batches = self.accelerator.gather_for_metrics(torch.tensor(batch_count, device=self.accelerator.device))
1008 if self.accelerator.is_main_process and self.sw is not None:
1009 # Sum the losses and batch counts from all devices
1010 total_loss_all_devices = gathered_losses.sum().item()
1011 total_batches_all_devices = gathered_num_batches.sum().item()
1012 avg_loss = total_loss_all_devices / total_batches_all_devices
1014 self.sw.add_scalar("eval/train_loss", avg_loss, self.epoch)
1016 logger.info(f"[TRAIN] [Epoch {self.epoch:02d}] Epoch complete, total time {epoch_timer.get_time_str()}")
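# Sketch of the exponentially weighted moving average used for running_loss above (decay 0.99):
# each new raw batch loss nudges the smoothed value by 1%, which is what the train logs report
# as running_loss.
running = None
for loss_value in [2.0, 1.6, 1.2]:
    running = loss_value if running is None else 0.99 * running + 0.01 * loss_value
assert abs(running - 1.98804) < 1e-9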
1018 def validate_epoch(self, num_sanity_steps: int | None = None, calculate_metrics: bool = True) -> None:
1019 """Validate for one epoch."""
1020 if self.valid_dataloader is None:
1021 return
1023 if self.accelerator.is_main_process:
1024 logger.info(f"[VALIDATION] [Epoch {self.epoch:02d}] Starting validation.")
1026 valid_epoch_step = 0
1027 valid_predictions: List[List[str] | str] = []
1028 valid_targets: List[List[str] | str] = []
1029 valid_groups: List[str] = []
1030 valid_prediction_ids: List[int] = []
1032 valid_metrics: Dict[str, List[float]] = {x: [] for x in ["valid_loss", "aa_er", "aa_prec", "aa_recall", "pep_recall"]}
1034 num_batches = len(self.valid_dataloader)
1036 valid_timer = Timer(num_batches)
1038 for batch_idx, batch in enumerate(self.valid_dataloader):
1039 if num_sanity_steps is not None and batch_idx >= num_sanity_steps:
1040 break
1042 with torch.no_grad(), self.accelerator.autocast():
1043 # Loss calculation
1044 loss, _ = self.forward(batch)
1045 # Get actual predictions
1046 y, targets = self.get_predictions(batch)
1048 valid_predictions.extend(y)
1049 valid_targets.extend(targets)
1050 valid_prediction_ids.extend([x.item() for x in batch["prediction_id"]])
1052 # Store validation groups if available
1053 if self.using_validation_groups:
1054 valid_groups.extend(batch["validation_group"])
1056 # Update metrics
1057 if self.metrics is not None:
1058 aa_prec, aa_recall, pep_recall, _ = self.metrics.compute_precision_recall(targets, y)
1059 aa_er = self.metrics.compute_aa_er(targets, y)
1061 valid_metrics["valid_loss"].append(loss.item())
1062 valid_metrics["aa_er"].append(aa_er)
1063 valid_metrics["aa_prec"].append(aa_prec)
1064 valid_metrics["aa_recall"].append(aa_recall)
1065 valid_metrics["pep_recall"].append(pep_recall)
1067 valid_epoch_step += 1
1069 valid_timer.step()
1071 # Log progress
1072 if (valid_epoch_step + 1) % int(self.config.get("console_logging_steps", 2000)) == 0:
1073 epoch_step = valid_epoch_step % num_batches
1075 logger.info(
1076 f"[VALIDATION] "
1077 f"[Epoch {self.epoch:02d}] "
1078 f"[Step {self.global_step + 1:06d}] "
1079 f"[Batch {epoch_step:05d}/{num_batches:05d}] "
1080 f"[{valid_timer.get_time_str()}/{valid_timer.get_total_time_str()}, "
1081 f"{valid_timer.get_step_time_rate_str()}]"
1082 )
1084 # Synchronize all processes at the end of validation
1085 # This ensures all ranks wait for the slowest rank to finish
1086 self.accelerator.wait_for_everyone()
1088 if not calculate_metrics:
1089 return
1091 # Gather predictions from all devices
1092 if valid_predictions:
1093 self.log_if_verbose("Gathering predictions from all devices")
1094 # Use use_gather_object=True for Python lists to ensure proper gathering
1095 valid_predictions = self.accelerator.gather_for_metrics(valid_predictions, use_gather_object=True)
1096 valid_targets = self.accelerator.gather_for_metrics(valid_targets, use_gather_object=True)
1097 valid_prediction_ids = self.accelerator.gather_for_metrics(valid_prediction_ids, use_gather_object=True)
1099 # Flatten nested lists if gather_for_metrics returned nested structure (one per device)
1100 # gather_for_metrics with use_gather_object=True returns a list of lists when num_processes > 1
1101 # Structure: [[pred1, pred2, ...], [pred3, pred4, ...], ...] where each inner list is from one device
1102 # We detect this by checking if we have num_processes or fewer top-level lists, and all are lists
1103 if self.accelerator.num_processes > 1 and valid_predictions:
1104 # If the length matches num_processes (or less, if some processes had no data)
1105 # and all elements are lists, it's likely the nested structure from gathering
1106 if len(valid_predictions) <= self.accelerator.num_processes and all(isinstance(item, list) for item in valid_predictions):
1107 # Flatten the nested structure
1108 valid_predictions = [item for sublist in valid_predictions for item in sublist]
1109 valid_targets = [item for sublist in valid_targets for item in sublist]
1110 valid_prediction_ids = [item for sublist in valid_prediction_ids for item in sublist] # type: ignore[attr-defined]
1112 # Validate that all gathered lists have matching lengths after flattening
1113 if len(valid_predictions) != len(valid_targets) or len(valid_predictions) != len(valid_prediction_ids):
1114 raise ValueError(
1115 f"Length mismatch after gathering predictions from all devices. "
1116 f"valid_predictions: {len(valid_predictions)}, "
1117 f"valid_targets: {len(valid_targets)}, "
1118 f"valid_prediction_ids: {len(valid_prediction_ids)}. "
1119 f"num_processes: {self.accelerator.num_processes}"
1120 )
1122 # Convert to numpy array for np.unique
1123 valid_prediction_ids_array = np.array(valid_prediction_ids)
1125 # Use valid_prediction_ids to remove duplicates
1126 # Find the indices of the first occurrence of each unique prediction_id
1127 _, idx = np.unique(valid_prediction_ids_array, return_index=True)
1129 # Store original length before deduplication for validation
1130 original_length = len(valid_predictions)
1132 # Validate indices are within bounds - this should never happen if lengths match
1133 max_idx = len(valid_predictions) - 1
1134 if len(idx) > 0 and idx.max() > max_idx:
1135 raise IndexError(
1136 f"IndexError: max index {idx.max()} exceeds valid_predictions length {len(valid_predictions)}. "
1137 f"valid_prediction_ids length: {len(valid_prediction_ids)}, "
1138 f"unique prediction_ids: {len(idx)}, "
1139 f"num_processes: {self.accelerator.num_processes}. "
1140 )
1142 valid_predictions = [valid_predictions[i] for i in idx]
1143 valid_targets = [valid_targets[i] for i in idx]
1145 self.log_if_verbose(f"Gathered {len(valid_predictions)} predictions")
1147 # Gather validation groups if available
1148 if self.using_validation_groups:
1149 valid_groups = self.accelerator.gather_for_metrics(valid_groups, use_gather_object=True)
1150 # Flatten nested structure if needed (same logic as above)
1151 if self.accelerator.num_processes > 1 and valid_groups:
1152 if len(valid_groups) <= self.accelerator.num_processes and all(isinstance(item, list) for item in valid_groups):
1153 valid_groups = [item for sublist in valid_groups for item in sublist]
1155 # Validate length matches the original length before deduplication
1156 if len(valid_groups) != original_length:
1157 raise ValueError(
1158 f"Length mismatch for valid_groups. "
1159 f"valid_groups length: {len(valid_groups)}, "
1160 f"expected length (before dedup): {original_length}, "
1161 f"deduplicated predictions length: {len(idx)}. "
1162 )
1163 valid_groups = [valid_groups[i] for i in idx]
1165 # Gather valid_metrics from all devices
1166 for metric, values in valid_metrics.items():
1167 valid_metrics[metric] = self.accelerator.gather_for_metrics(values)
1169 # Keep validation metrics for checkpointing
1170 checkpoint_metric = self.config.get("checkpoint_metric", None)
1171 if checkpoint_metric is not None:
1172 self.last_validation_metric = np.mean(valid_metrics[checkpoint_metric])
1174 # Log validation metrics
1175 if self.accelerator.is_main_process and self.metrics is not None:
1176 # Validation metrics are logged by epoch
1177 validation_step = self.global_step + 1
1179 if self.sw is not None:
1180 for k, v in valid_metrics.items():
1181 self.sw.add_scalar(f"eval/{k}", np.mean(v), validation_step)
1183 logger.info(
1184 f"[VALIDATION] [Epoch {self.epoch:02d}] "
1185 f"[Step {self.global_step + 1:06d}] "
1186 f"train_loss={self.running_loss if self.running_loss else 0:.5f}, "
1187 f"valid_loss={np.mean(valid_metrics['valid_loss']):.5f}"
1188 )
1189 logger.info(f"[VALIDATION] [Epoch {self.epoch:02d}] [Step {self.global_step + 1:06d}] Metrics:")
1190 for metric in ["aa_er", "aa_prec", "aa_recall", "pep_recall"]:
1191 val = np.mean(valid_metrics[metric])
1192 logger.info(f"[VALIDATION] [Epoch {self.epoch:02d}] [Step {self.global_step + 1:06d}] - {metric:11s}{val:.3f}")
1194 # Validation group logging
1195 if self.using_validation_groups and valid_groups and self.sw is not None:
1196 preds = pl.Series(valid_predictions)
1197 targs = pl.Series(valid_targets)
1198 groups = pl.Series(valid_groups)
1200 assert len(preds) == len(groups)
1201 assert len(targs) == len(groups)
1203 for group in groups.unique():
1204 idx = groups == group
1205 logger.info(f"Computing group {group} with {idx.sum()} samples")
1206 if idx.sum() > 0: # Only compute metrics if we have samples for this group
1207 aa_prec, aa_recall, pep_recall, _ = self.metrics.compute_precision_recall(targs.filter(idx), preds.filter(idx))
1208 aa_er = self.metrics.compute_aa_er(targs.filter(idx), preds.filter(idx))
1209 self.sw.add_scalar(f"eval/{group}_aa_er", aa_er, validation_step)
1210 self.sw.add_scalar(f"eval/{group}_aa_prec", aa_prec, validation_step)
1211 self.sw.add_scalar(f"eval/{group}_aa_recall", aa_recall, validation_step)
1212 self.sw.add_scalar(f"eval/{group}_pep_recall", pep_recall, validation_step)
1214 self.accelerator.wait_for_everyone()
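# Sketch of the duplicate removal used above: when the validation set is not evenly divisible
# across ranks the gathered lists can contain repeated rows (an assumption about the sampler's
# padding), so only the first occurrence of each prediction_id is kept.
import numpy as np

prediction_ids = np.array([3, 1, 2, 1, 3])
_, first_idx = np.unique(prediction_ids, return_index=True)
deduplicated = prediction_ids[first_idx]
assert deduplicated.tolist() == [1, 2, 3]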