Coverage for instanovo/common/trainer.py: 44%

643 statements  

coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

1from __future__ import annotations 

2 

3import datetime 

4import logging 

5import math 

6import os 

7import shutil 

8import sys 

9import traceback 

10from abc import ABCMeta, abstractmethod 

11from collections import Counter 

12from datetime import timedelta 

13from pathlib import Path 

14from typing import Any, Dict, Iterable, List 

15 

16import neptune 

17import numpy as np 

18import pandas as pd 

19import polars as pl 

20import torch 

21import torch.nn as nn 

22from accelerate import Accelerator 

23from accelerate.utils import DataLoaderConfiguration, InitProcessGroupKwargs, broadcast_object_list 

24from datasets import Dataset, Value 

25from datasets.utils.logging import disable_progress_bar 

26from dotenv import load_dotenv 

27from neptune.integrations.python_logger import NeptuneHandler 

28from omegaconf import DictConfig, OmegaConf 

29from sklearn.model_selection import train_test_split 

30from torch.utils.tensorboard import SummaryWriter 

31 

32from instanovo.__init__ import console, set_rank 

33from instanovo.common.dataset import DataProcessor 

34from instanovo.common.scheduler import CosineWarmupScheduler, FinetuneScheduler, WarmupScheduler 

35from instanovo.common.utils import ( 

36 NeptuneSummaryWriter, 

37 Timer, 

38 TrainingState, 

39 _get_filepath_mapping, 

40 _set_author_neptune_api_token, 

41) 

42from instanovo.constants import ANNOTATED_COLUMN, ANNOTATION_ERROR, SHUFFLE_BUFFER_SIZE, MSColumns 

43from instanovo.inference import Decoder 

44from instanovo.utils.colorlogging import ColorLog 

45from instanovo.utils.data_handler import SpectrumDataFrame 

46from instanovo.utils.device_handler import validate_and_configure_device 

47from instanovo.utils.metrics import Metrics 

48from instanovo.utils.residues import ResidueSet 

49from instanovo.utils.s3 import S3FileHandler 

50 

51load_dotenv() 

52 

53# Automatic rank logger 

54logger = ColorLog(console, __name__).logger 

55 

56 

57class AccelerateDeNovoTrainer(metaclass=ABCMeta): 

58 """Trainer class that uses the Accelerate library.""" 

59 

60 @property 

61 def run_id(self) -> str: 

62 """Get the run ID. 

63 

64 Returns: 

65 str: The run ID 

66 """ 

67 return str(self._run_id) 

68 

69 @property 

70 def s3(self) -> S3FileHandler: 

71 """Get the S3 file handler. 

72 

73 Returns: 

74 S3FileHandler: The S3 file handler 

75 """ 

76 return self._s3 

77 

78 @property 

79 def global_step(self) -> int: 

80 """Get the current global training step. 

81 

82 This represents the total number of training steps across all epochs. 

83 

84 Returns: 

85 int: The current global step number 

86 """ 

87 return int(self._training_state.global_step) 

88 

89 @property 

90 def epoch(self) -> int: 

91 """Get the current training epoch. 

92 

93 This represents the current epoch number in the training process. 

94 

95 Returns: 

96 int: The current epoch number 

97 """ 

98 return int(self._training_state.epoch) 

99 

100 @property 

101 def training_state(self) -> TrainingState: 

102 """Get the training state.""" 

103 return self._training_state 

104 

105 def __init__( 

106 self, 

107 config: DictConfig, 

108 ) -> None: 

109 self.config = config 

110 self.enable_verbose_logging = self.config.get("enable_verbose_logging", True) 

111 if not self.config.get("enable_verbose_accelerate", True): 

112 logging.getLogger("accelerate").setLevel(logging.WARNING) 

113 

114 # Hide progress bar from HF datasets 

115 disable_progress_bar() 

116 

117 # Training state 

118 # Keeps track of the global step and epoch 

119 # Used for accelerate training state checkpointing 

120 self._training_state = TrainingState() 

121 

122 self._run_id = self.config.get("run_name", "instanovo") + datetime.datetime.now().strftime("_%y_%m_%d_%H_%M") 

123 

124 self.accelerator = self.setup_accelerator() 

125 

126 self.log_if_verbose("Verbose logging enabled") 

127 

128 if self.accelerator.is_main_process: 

129 logger.info(f"Config:\n{OmegaConf.to_yaml(self.config)}") 

130 

131 self.residue_set = ResidueSet( 

132 residue_masses=self.config.residues.get("residues"), 

133 residue_remapping=self.config.dataset.get("residue_remapping", None), 

134 ) 

135 logger.info(f"Vocab: {self.residue_set.index_to_residue}") 

136 

137 # Initialise S3 file handler 

138 self._s3: S3FileHandler = S3FileHandler(verbose=self.config.get("enable_verbose_s3", True)) 

139 

140 self.train_dataset, self.valid_dataset, train_size, valid_size = self.load_datasets() 

141 

142 logger.info(f"Data loaded from {train_size:,} training samples and {valid_size:,} validation samples (unfiltered values)") 

143 

144 self.train_dataloader, self.valid_dataloader = self.build_dataloaders(self.train_dataset, self.valid_dataset) 

145 logger.info("Data loaders built") 

146 

147 # Print sample batch 

148 self.print_sample_batch() 

149 

150 logger.info("Setting up model...") 

151 self.model = self.setup_model() 

152 

153 if self.accelerator.is_main_process: 

154 logger.info(f"Model has {sum(p.numel() for p in self.model.parameters()):,d} parameters") 

155 

156 self.optimizer = self.setup_optimizer() 

157 self.lr_scheduler = self.setup_scheduler() 

158 

159 self.decoder = self.setup_decoder() 

160 self.metrics = self.setup_metrics() 

161 

162 # Optionally load a model state for fine-tuning 

163 # Note: will be overwritten by the accelerator state if resuming 

164 if self.config.get("resume_checkpoint_path", None) is not None: 

165 self.load_model_state() # TODO check for loading on mps 

166 

167 # Prepare for accelerated training 

168 ( 

169 self.model, 

170 self.optimizer, 

171 self.lr_scheduler, 

172 self.train_dataloader, 

173 self.valid_dataloader, 

174 ) = self.accelerator.prepare( 

175 self.model, 

176 self.optimizer, 

177 self.lr_scheduler, 

178 self.train_dataloader, 

179 self.valid_dataloader, 

180 ) 

181 # Make sure the training state is checkpointed 

182 self.accelerator.register_for_checkpointing(self._training_state) 

183 

184 # Optionally load states if resuming a training run 

185 if self.config.get("resume_accelerator_state", None): 

186 # Resuming from an existing run 

187 self.load_accelerator_state() 

188 

189 # Setup logging 

190 self.setup_neptune() 

191 self.setup_tensorboard() 

192 self._add_commit_message_to_monitoring_platform() 

193 self._add_config_summary_to_monitoring_platform() 

194 

195 # Training control variables 

196 self.running_loss = None 

197 

198 self.total_steps = self.config.get("training_steps", 2_500_000) 

199 

200 # Setup finetuning scheduler 

201 if self.config.get("finetune", None): 

202 self.finetune_scheduler: FinetuneScheduler | None = FinetuneScheduler( 

203 self.model.state_dict(), 

204 self.config.get("finetune"), 

205 ) 

206 else: 

207 self.finetune_scheduler = None 

208 

209 self.steps_per_validation = self.config.get("validation_interval", 100_000) 

210 self.steps_per_checkpoint = self.config.get("checkpoint_interval", 100_000) 

211 

212 # Print training control variables 

213 if self.accelerator.is_main_process: 

214 steps_per_epoch = train_size // self.config["train_batch_size"] 

215 logger.info("Training setup complete.") 

216 logger.info(f" - Steps per validation: {self.steps_per_validation:,d} ") 

217 logger.info(f" - Steps per checkpoint: {self.steps_per_checkpoint:,d} ") 

218 logger.info(f" - Total training steps: {self.total_steps:,d}") 

219 logger.info("Estimating steps per epoch based on unfiltered training set size:") 

220 logger.info(f" - Estimated steps per epoch: {steps_per_epoch:,d}") 

221 logger.info(f" - Estimated total epochs: {self.total_steps / steps_per_epoch:.1f}") 

222 

223 if self.total_steps < steps_per_epoch: 

224 logger.warning("Total steps is less than estimated steps per epoch, this may result in less than one epoch during training") 

225 

226 if self.global_step > 0: 

227 logger.info(f"Training will resume from epoch {self.epoch}, global_step {self.global_step}") 

228 

229 self.last_validation_metric = None 

230 self.best_checkpoint_metric = None 

231 

232 # Final sync after setup 

233 self.accelerator.wait_for_everyone() 

234 

235 @abstractmethod 

236 def setup_model(self) -> nn.Module: 

237 """Setup the model.""" 

238 ... 

239 

240 @abstractmethod 

241 def setup_optimizer(self) -> torch.optim.Optimizer: 

242 """Setup the optimizer.""" 

243 ... 

244 

245 @abstractmethod 

246 def setup_decoder(self) -> Decoder: 

247 """Setup the decoder.""" 

248 ... 

249 

250 @abstractmethod 

251 def setup_data_processors(self) -> tuple[DataProcessor, DataProcessor]: 

252 """Setup the data processor.""" 

253 ... 

254 

255 @abstractmethod 

256 def save_model(self, is_best_checkpoint: bool = False) -> None: 

257 """Save the model.""" 

258 ... 

259 

260 @abstractmethod 

261 def forward(self, batch: Any) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: 

262 """Forward pass for the model to calculate loss.""" 

263 ... 

264 

265 @abstractmethod 

266 def get_predictions(self, batch: Any) -> tuple[list[str] | list[list[str]], list[str] | list[list[str]]]: 

267 """Get the predictions for a batch.""" 

268 ... 

269 

270 @staticmethod 

271 def convert_interval_to_steps(interval: float | int, steps_per_epoch: int) -> int: 

272 """Convert an interval to steps. 

273 

274 Args: 

275 interval (float | int): The interval to convert. A float is treated as a fraction of an epoch; an int as an absolute step count.

276 steps_per_epoch (int): The number of steps per epoch. 

277 

278 Returns: 

279 int: The number of steps. 

280 """ 

281 if isinstance(interval, float):

282 return int(interval * steps_per_epoch)  # fraction of an epoch

283 if isinstance(interval, int):

284 return interval  # already an absolute number of steps

285 raise ValueError(f"Invalid interval: {interval}")
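
A minimal usage sketch of the conversion above, assuming a float is a fraction of an epoch and an int is an absolute step count; the steps_per_epoch value below is illustrative, not a project default:

from instanovo.common.trainer import AccelerateDeNovoTrainer

# Illustrative values only.
steps_per_epoch = 10_000
AccelerateDeNovoTrainer.convert_interval_to_steps(0.5, steps_per_epoch)    # -> 5000 (half an epoch)
AccelerateDeNovoTrainer.convert_interval_to_steps(2_500, steps_per_epoch)  # -> 2500 (absolute steps)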

286 def log_if_verbose(self, message: str, level: str = "info") -> None: 

287 """Log a message if verbose logging is enabled.""" 

288 if self.enable_verbose_logging: 

289 if level == "info": 

290 logger.info(message) 

291 elif level == "warning": 

292 logger.warning(message) 

293 elif level == "error": 

294 logger.error(message) 

295 elif level == "debug": 

296 logger.debug(message) 

297 else: 

298 raise ValueError(f"Invalid level: {level}") 

299 

300 def setup_metrics(self) -> Metrics: 

301 """Setup the metrics.""" 

302 return Metrics(self.residue_set, self.config.get("max_isotope_error", 1)) 

303 

304 def setup_accelerator(self) -> Accelerator: 

305 """Setup the accelerator.""" 

306 timeout = timedelta(seconds=self.config.get("timeout", 3600)) 

307 validate_and_configure_device(self.config) 

308 accelerator = Accelerator( 

309 cpu=self.config.get("force_cpu", False), 

310 mixed_precision="fp16" if torch.cuda.is_available() and not self.config.get("force_cpu", False) else "no", 

311 gradient_accumulation_steps=self.config.get("grad_accumulation", 1), 

312 dataloader_config=DataLoaderConfiguration(split_batches=True), 

313 kwargs_handlers=[InitProcessGroupKwargs(timeout=timeout)], 

314 ) 

315 

316 device = accelerator.device # Important, this forces ranks to choose a device. 

317 

318 if accelerator.num_processes > 1: 

319 set_rank(accelerator.local_process_index) 

320 

321 if accelerator.is_main_process: 

322 logger.info(f"Python version: {sys.version}") 

323 logger.info(f"Torch version: {torch.__version__}") 

324 logger.info(f"CUDA version: {torch.version.cuda}") 

325 logger.info(f"Training with {accelerator.num_processes} devices") 

326 logger.info(f"Per-device batch size: {self.config['train_batch_size']}") 

327 logger.info(f"Gradient accumulation steps: {self.config['grad_accumulation']}") 

328 effective_batch_size = self.config["train_batch_size"] * accelerator.num_processes * self.config["grad_accumulation"] 

329 logger.info(f"Effective batch size: {effective_batch_size}") 

330 

331 logger.info(f"Using device: {device}") 

332 

333 return accelerator 

334 
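
A worked example of the effective batch size logged above, using assumed values rather than project defaults:

# Assumed example values; the real numbers come from the training config.
train_batch_size = 32    # per-device batch size
num_processes = 2        # e.g. two GPUs
grad_accumulation = 4    # gradients accumulated over 4 batches per optimizer step
effective_batch_size = train_batch_size * num_processes * grad_accumulation
assert effective_batch_size == 256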

335 def build_dataloaders(self, train_dataset: Dataset, valid_dataset: Dataset) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]: 

336 """Setup the dataloaders.""" 

337 train_processor, valid_processor = self.setup_data_processors() 

338 

339 valid_processor.add_metadata_columns(["prediction_id"]) 

340 if self.using_validation_groups: 

341 valid_processor.add_metadata_columns(["validation_group"]) 

342 

343 if self.config.get("use_shuffle_buffer", True): 

344 buffer_size = self.config.get("shuffle_buffer_size", SHUFFLE_BUFFER_SIZE) 

345 train_dataset = train_dataset.shuffle(buffer_size=buffer_size, seed=42) 

346 

347 train_dataset = train_dataset.map( 

348 train_processor.process_row, 

349 ) 

350 valid_dataset = valid_processor.process_dataset(valid_dataset) 

351 

352 pin_memory = self.config.get("pin_memory", False) 

353 if self.accelerator.device == torch.device("cpu") or self.config.get("mps", False): 

354 pin_memory = False 

355 # Scale batch size by number of processes when using split_batches 

356 train_dataloader = torch.utils.data.DataLoader( 

357 train_dataset, 

358 batch_size=self.config["train_batch_size"] * self.accelerator.num_processes, 

359 collate_fn=train_processor.collate_fn, 

360 num_workers=self.config.get("num_workers", 8), 

361 pin_memory=pin_memory, 

362 prefetch_factor=self.config.get("prefetch_factor", None), 

363 drop_last=True, 

364 ) 

365 valid_dataloader = torch.utils.data.DataLoader( 

366 valid_dataset, 

367 batch_size=self.config["predict_batch_size"] * self.accelerator.num_processes, 

368 collate_fn=valid_processor.collate_fn, 

369 num_workers=self.config.get("num_workers", 8), 

370 pin_memory=pin_memory, 

371 prefetch_factor=self.config.get("prefetch_factor", None), 

372 drop_last=False, 

373 ) 

374 return train_dataloader, valid_dataloader 

375 
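
Because the Accelerator is configured with DataLoaderConfiguration(split_batches=True), each DataLoader is built with the global batch size and Accelerate splits it across ranks, which is why the batch sizes above are multiplied by num_processes. A small sketch with assumed values:

# Assumed values for illustration only.
per_device_batch_size = 32
num_processes = 4
global_batch_size = per_device_batch_size * num_processes       # what the DataLoader is built with
assert global_batch_size // num_processes == per_device_batch_size  # what each rank sees after splitting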

376 def setup_scheduler(self) -> torch.optim.lr_scheduler.LRScheduler: 

377 """Setup the learning rate scheduler. 

378 

379 Returns: 

380 torch.optim.lr_scheduler.LRScheduler: The learning rate scheduler 

381 """ 

382 # Note: if split_batches is False, the scheduler will be called num_processes times 

383 # in each optimizer step. Therefore, we need to scale the scheduler steps by num_processes. 

384 # By default, split_batches is True, so no scaling is needed here.

385 if self.config.get("lr_scheduler", "warmup") == "warmup": 

386 warmup_steps = self.config.get("warmup_iters", 1000) # * num_processes 

387 return WarmupScheduler(self.optimizer, warmup_steps) 

388 elif self.config.get("lr_scheduler", None) == "cosine": 

389 # Scale max_iters based on accumulation 

390 # train_dataloader is already scaled by num_processes 

391 max_iters = self.config.get("training_steps", 2_500_000) # * num_processes 

392 warmup_steps = self.config.get("warmup_iters", 1000) # * num_processes 

393 return CosineWarmupScheduler(self.optimizer, warmup_steps, max_iters) 

394 else: 

395 raise ValueError(f"Unknown lr_scheduler type '{self.config.get('lr_scheduler', None)}'") 

396 
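
For reference, a minimal sketch of the usual warmup-then-cosine learning-rate shape. The project's WarmupScheduler and CosineWarmupScheduler may differ in detail, so treat this purely as an illustration:

import math

def cosine_warmup_factor(step: int, warmup_steps: int, max_iters: int) -> float:
    """Multiplier applied to the base learning rate at a given step (illustrative sketch)."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)  # linear warmup from 0 to 1
    progress = (step - warmup_steps) / max(1, max_iters - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))  # cosine decay towards 0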

397 def setup_neptune(self) -> None: 

398 """Setup the neptune.""" 

399 if not self.accelerator.is_main_process: 

400 self.neptune_run = None 

401 return 

402 

403 if not self.config.get("use_neptune", True): 

404 self.neptune_run = None 

405 return 

406 

407 _set_author_neptune_api_token() 

408 try: 

409 self.neptune_run = neptune.init_run( 

410 with_id=None, 

411 name=self.run_id, 

412 dependencies=str(Path(__file__).parent.parent.parent / "uv.lock"), 

413 tags=OmegaConf.to_object(self.config.get("tags", [])), 

414 ) 

415 self.neptune_run.assign({"config": OmegaConf.to_yaml(self.config)}) 

416 logger.addHandler(NeptuneHandler(run=self.neptune_run)) 

417 except Exception as e: 

418 logger.warning(f"Failed to initialise neptune: {e}") 

419 self.neptune_run = None 

420 

421 def setup_tensorboard(self) -> None: 

422 """Setup the tensorboard.""" 

423 if not self.accelerator.is_main_process: 

424 self.sw = None 

425 return 

426 

427 if S3FileHandler.register_tb(): 

428 logs_path = os.environ["AICHOR_LOGS_PATH"] 

429 else: 

430 logs_path = os.path.join(self.config.get("tb_summarywriter", "runs"), self.run_id)

431 

432 if self.neptune_run is not None: 

433 self.sw = NeptuneSummaryWriter(logs_path, self.neptune_run) 

434 else: 

435 self.sw = SummaryWriter(logs_path) 

436 logger.info(f"TensorBoard logs will be saved to {logs_path}") 

437 

438 @property 

439 def _is_main_process_on_aichor(self) -> bool: 

440 """Return True if monitoring logging is configured and this is the main process.""" 

441 return ("AICHOR_LOGS_PATH" in os.environ) and self.accelerator.is_main_process 

442 

443 def _add_commit_message_to_monitoring_platform(self, commit_id_length: int = 7) -> None: 

444 """Add the git commit message to the monitoring platform.""" 

445 if not self._is_main_process_on_aichor: 

446 logger.debug("Skipping config summary upload to Neptune: 'AICHOR_LOGS_PATH' not set or current process is not the main AICHOR process") 

447 return 

448 

449 try: 

450 # Remove 'exp:' prefix if present in the commit message and also remove the space after it. 

451 git_commit_msg = os.environ["VCS_COMMIT_MESSAGE"].removeprefix("exp:").removeprefix(" ") 

452 commit_short_hash = os.environ["VCS_SHA"][:commit_id_length] 

453 self.sw.add_text( # type: ignore[union-attr] 

454 "git/commit_message", f"{git_commit_msg} ({commit_short_hash})" 

455 ) 

456 except (AttributeError, KeyError) as exc: 

457 logger.warning("Failed to write config summary to the monitoring plateform", exc_info=exc) 

458 

459 def _add_config_summary_to_monitoring_platform(self) -> None: 

460 """Add the config summary to the monitoring platform.""" 

461 if not self._is_main_process_on_aichor: 

462 logger.debug("Skipping config summary upload to Neptune: 'AICHOR_LOGS_PATH' not set or current process is not the main AICHOR process") 

463 return 

464 # https://github.com/pytorch/pytorch/blob/daca611465c93ac6b8147e6b7070ce2b4254cfc5/torch/utils/tensorboard/summary.py#L244 # noqa 

465 self.sw.add_hparams( # type: ignore[union-attr] 

466 {k: v for k, v in self.config.items() if isinstance(v, (int, float, str))}, {} 

467 ) 

468 

469 def load_datasets(self) -> tuple[Dataset, Dataset, int, int]: 

470 """Load the training and validation datasets. 

471 

472 Returns: 

473 tuple[Dataset, Dataset, int, int]:

474 The training and validation datasets plus their unfiltered sizes.

475 """ 

476 validation_group_mapping = None 

477 dataset_config = self.config.get("dataset", {}) 

478 try: 

479 logger.info("Loading training dataset...") 

480 train_sdf = SpectrumDataFrame.load( 

481 source=dataset_config.get("train_path"), 

482 source_type=dataset_config.get("source_type", "default"), 

483 lazy=dataset_config.get("lazy_loading", True), 

484 is_annotated=True, 

485 shuffle=True, 

486 partition=dataset_config.get("train_partition", None), 

487 column_mapping=dataset_config.get("column_remapping", None), 

488 max_shard_size=dataset_config.get("max_shard_size", 100_000), 

489 preshuffle_across_shards=dataset_config.get("preshuffle_shards", False), 

490 verbose=dataset_config.get("verbose_loading", True), 

491 ) 

492 

493 valid_path = dataset_config.get("valid_path", None) 

494 if valid_path is not None: 

495 if OmegaConf.is_dict(valid_path): 

496 logger.info("Found grouped validation datasets.") 

497 validation_group_mapping = _get_filepath_mapping(valid_path) 

498 _valid_path = list(valid_path.values()) 

499 else: 

500 _valid_path = valid_path 

501 else: 

502 _valid_path = dataset_config.get("train_path") 

503 

504 logger.info("Loading validation dataset...") 

505 valid_sdf = SpectrumDataFrame.load( 

506 _valid_path, 

507 lazy=dataset_config.get("lazy_loading", True), 

508 is_annotated=True, 

509 shuffle=False, 

510 partition=dataset_config.get("valid_partition", None), 

511 column_mapping=dataset_config.get("column_remapping", None), 

512 max_shard_size=dataset_config.get("max_shard_size", 100_000), 

513 add_source_file_column=True, # used to track validation groups 

514 verbose=dataset_config.get("verbose_loading", True), 

515 ) 

516 except ValueError as e: 

517 # Raise a more descriptive error when the dataset is missing annotations.

518 if str(e) == ANNOTATION_ERROR:

519 raise ValueError("The sequence column is missing annotations. Are you trying to run de novo prediction? If so, add the --denovo flag.") from e

520 else: 

521 raise 

522 

523 # Split data if needed 

524 if dataset_config.get("valid_path", None) is None: 

525 logger.info("Validation path not specified, generating from training set.") 

526 sequences = list(train_sdf.get_unique_sequences()) 

527 sequences = sorted({DataProcessor.remove_modifications(x) for x in sequences}) 

528 train_unique, valid_unique = train_test_split( 

529 sequences, 

530 test_size=dataset_config.get("valid_subset_of_train"), 

531 random_state=42, 

532 ) 

533 train_unique = set(train_unique) 

534 valid_unique = set(valid_unique) 

535 

536 train_sdf.filter_rows(lambda row: DataProcessor.remove_modifications(row[ANNOTATED_COLUMN]) in train_unique) 

537 valid_sdf.filter_rows(lambda row: DataProcessor.remove_modifications(row[ANNOTATED_COLUMN]) in valid_unique) 

538 

539 # Save splits 

540 if self.accelerator.is_main_process: 

541 split_path = os.path.join(self.config.get("model_save_folder_path", "./checkpoints"), "splits.csv") 

542 os.makedirs(os.path.dirname(split_path), exist_ok=True) 

543 splits_df = pd.DataFrame( 

544 { 

545 ANNOTATED_COLUMN: list(train_unique) + list(valid_unique), 

546 "split": ["train"] * len(train_unique) + ["valid"] * len(valid_unique), 

547 } 

548 ) 

549 self.s3.upload_to_s3_wrapper(splits_df.to_csv, split_path, index=False) 

550 logger.info(f"Data splits saved to {split_path}") 

551 

552 train_ds = train_sdf.to_dataset(force_unified_schema=True) 

553 valid_ds = valid_sdf.to_dataset(in_memory=True) 

554 

555 # Sample subsets if needed

556 valid_subset = self.config.get("valid_subset", 1.0) 

557 if valid_subset < 1.0: 

558 valid_ds = valid_ds.train_test_split(test_size=valid_subset, seed=42)["test"] 

559 

560 # Check residues 

561 if self.config.get("perform_data_checks", True): 

562 logger.info(f"Checking for unknown residues in {len(train_sdf) + len(valid_sdf):,d} rows.") 

563 supported_residues = set(self.residue_set.vocab) 

564 supported_residues.update(set(self.residue_set.residue_remapping.keys())) 

565 data_residues = set() 

566 data_residues.update(train_sdf.get_vocabulary(self.residue_set.tokenize)) 

567 data_residues.update(valid_sdf.get_vocabulary(self.residue_set.tokenize)) 

568 if len(data_residues - supported_residues) > 0: 

569 logger.warning(f"Found {len(data_residues - supported_residues):,d} unsupported residues! These rows will be dropped.") 

570 self.log_if_verbose(f"New residues found: \n{data_residues - supported_residues}") 

571 self.log_if_verbose(f"Residues supported: \n{supported_residues}") 

572 

573 train_ds = train_ds.filter( 

574 lambda row: all(residue in supported_residues for residue in set(self.residue_set.tokenize(row[ANNOTATED_COLUMN]))) 

575 ) 

576 valid_ds = valid_ds.filter( 

577 lambda row: all(residue in supported_residues for residue in set(self.residue_set.tokenize(row[ANNOTATED_COLUMN]))) 

578 ) 

579 

580 logger.info("Checking charge values...") 

581 # Check charge values 

582 precursor_charge_col = MSColumns.PRECURSOR_CHARGE.value 

583 

584 if not train_sdf.check_values(1, self.config.get("max_charge", 10), precursor_charge_col): 

585 logger.warning("Found charge values out of range in training set. These rows will be dropped.") 

586 

587 train_ds = train_ds.filter( 

588 lambda row: (row[precursor_charge_col] <= self.config.get("max_charge", 10)) and (row[precursor_charge_col] > 0) 

589 ) 

590 

591 if not valid_sdf.check_values(1, self.config.get("max_charge", 10), precursor_charge_col): 

592 logger.warning("Found charge values out of range in validation set. These rows will be dropped.") 

593 valid_ds = valid_ds.filter( 

594 lambda row: (row[precursor_charge_col] <= self.config.get("max_charge", 10)) and (row[precursor_charge_col] > 0) 

595 ) 

596 

597 # Create validation groups 

598 # Initialize validation groups if needed 

599 if validation_group_mapping is not None: 

600 logger.info("Computing validation groups.") 

601 validation_groups = [validation_group_mapping.get(row.get("source_file"), "no_group") for row in valid_ds] 

602 valid_ds = valid_ds.add_column("validation_group", validation_groups) 

603 

604 logger.info("Sequences per validation group:") 

605 group_counts = Counter(validation_groups) 

606 for group, count in group_counts.items(): 

607 logger.info(f" - {group}: {count:,d}") 

608 

609 self.using_validation_groups = True 

610 else: 

611 self.using_validation_groups = False 

612 

613 # Force add a unique prediction_id column 

614 # This will be used to order predictions and remove duplicates 

615 valid_ds = valid_ds.add_column("prediction_id", np.arange(len(valid_ds)), feature=Value("int32")) 

616 

617 # Keep a reference to train_sdf so its backing directory isn't garbage collected

618 self._train_sdf = train_sdf 

619 

620 return train_ds, valid_ds, len(train_sdf), len(valid_sdf) 

621 
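
The fallback split above operates on unique, modification-stripped peptide sequences rather than on individual spectra, so no peptide appears in both sets. A small sketch with assumed toy sequences:

from sklearn.model_selection import train_test_split

# Assumed toy peptides; in the code above, modifications are stripped before splitting.
unique_peptides = sorted({"PEPTIDEA", "PEPTIDEC", "PEPTIDEK", "PEPTIDER", "PEPTIDES"})
train_peps, valid_peps = train_test_split(unique_peptides, test_size=0.2, random_state=42)
assert set(train_peps).isdisjoint(valid_peps)  # no peptide-level leakage between splits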

622 def print_sample_batch(self) -> None: 

623 """Print a sample batch of the training data.""" 

624 if self.accelerator.is_main_process: 

625

626 sample_batch = next(iter(self.train_dataloader)) 

627 logger.info("Sample batch:") 

628 for key, value in sample_batch.items(): 

629 if isinstance(value, torch.Tensor): 

630 value_shape = value.shape 

631 value_type = value.dtype 

632 else: 

633 value_shape = len(value) 

634 value_type = type(value) 

635 

636 logger.info(f" - {key}: {value_type}, {value_shape}") 

637 

638 def save_accelerator_state(self, is_best_checkpoint: bool = False) -> None: 

639 """Save the accelerator state.""" 

640 checkpoint_dir = self.config.get("model_save_folder_path", "./checkpoints") 

641 

642 if self.config.get("keep_accelerator_every_interval", False): 

643 checkpoint_path = os.path.join( 

644 checkpoint_dir, 

645 "accelerator_state", 

646 f"epoch_{self.epoch}_step_{self.global_step + 1}", 

647 ) 

648 else: 

649 checkpoint_path = os.path.join(checkpoint_dir, "accelerator_state", "latest") 

650 if self.accelerator.is_main_process and Path(checkpoint_path).exists() and Path(checkpoint_path).is_dir(): 

651 shutil.rmtree(checkpoint_path) 

652 

653 if self.accelerator.is_main_process: 

654 os.makedirs(checkpoint_path, exist_ok=True) 

655 

656 self.accelerator.save_state(checkpoint_path) 

657 

658 logger.info(f"Saved accelerator state to {checkpoint_path}") 

659 

660 if self.accelerator.is_main_process and S3FileHandler._aichor_enabled(): 

661 for file in os.listdir(checkpoint_path): 

662 self.s3.upload( 

663 os.path.join(checkpoint_path, file), 

664 S3FileHandler.convert_to_s3_output(os.path.join(checkpoint_path, file)), 

665 ) 

666 

667 # Save best checkpoint and upload to S3 

668 if is_best_checkpoint and self.accelerator.is_main_process: 

669 best_checkpoint_path = os.path.join(checkpoint_dir, "accelerator_state", "best") 

670 if Path(best_checkpoint_path).exists() and Path(best_checkpoint_path).is_dir(): 

671 shutil.rmtree(best_checkpoint_path) 

672 

673 os.makedirs(best_checkpoint_path, exist_ok=True) 

674 

675 for file in os.listdir(checkpoint_path): 

676 full_file = os.path.join(checkpoint_path, file) 

677 best_file = os.path.join(best_checkpoint_path, file) 

678 shutil.copy(full_file, best_file) 

679 if S3FileHandler._aichor_enabled(): 

680 self.s3.upload( 

681 full_file, 

682 S3FileHandler.convert_to_s3_output(best_file), 

683 ) 

684 

685 def check_if_best_checkpoint(self) -> bool: 

686 """Check if the last validation metric is the best metric.""" 

687 if self.config.get("checkpoint_metric", None) is None: 

688 return False 

689 

690 if self.best_checkpoint_metric is None: 

691 self.best_checkpoint_metric = self.last_validation_metric 

692 return True 

693 

694 if self.config.get("checkpoint_metric_mode", "min") == "min": 

695 is_best = self.last_validation_metric <= self.best_checkpoint_metric 

696 elif self.config.get("checkpoint_metric_mode", "min") == "max": 

697 is_best = self.last_validation_metric >= self.best_checkpoint_metric 

698 else: 

699 raise ValueError(f"Unknown checkpoint metric mode: {self.config.get('checkpoint_metric_mode', 'min')}") 

700 

701 if is_best: 

702 self.best_checkpoint_metric = self.last_validation_metric 

703 

704 return is_best 

705 
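
A hedged example of the config keys this method reads; the key names are taken from the code above, while the values shown are assumptions ("valid_loss" matches one of the metrics collected in validate_epoch):

from omegaconf import OmegaConf

# Example values only; key names come from check_if_best_checkpoint above.
checkpoint_cfg = OmegaConf.create(
    {
        "checkpoint_metric": "valid_loss",  # metric tracked during validation
        "checkpoint_metric_mode": "min",    # "min" keeps the lowest value, "max" the highest
    }
)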

706 def load_accelerator_state(self) -> None: 

707 """Load the accelerator state.""" 

708 checkpoint_path = self.config.get("resume_accelerator_state", None) 

709 if checkpoint_path is None: 

710 return 

711 

712 if not os.path.isdir(checkpoint_path) and not checkpoint_path.startswith("s3://"): 

713 raise ValueError(f"Accelerator state should be a directory of state files, got {checkpoint_path}") 

714 

715 if S3FileHandler._aichor_enabled() and checkpoint_path.startswith("s3://"): 

716

717 

718 if self.accelerator.is_main_process: 

719 local_path = os.path.join(self.s3.temp_dir.name, "accelerator_state") 

720 os.makedirs(local_path, exist_ok=True) 

721 logger.info(f"Downloading checkpoint files from {checkpoint_path} to {local_path}") 

722 

723 # Download all files from the checkpoint folder 

724 checkpoint_files = self.s3.listdir(checkpoint_path) 

725 logger.info(f"Found {len(checkpoint_files)} files") 

726 for file in checkpoint_files: 

727 if file.endswith("/"): # Skip subdirectories 

728 continue 

729 local_file = os.path.join(local_path, os.path.basename(file)) 

730 self.s3.download(f"s3://{file}", local_file) 

731 else: 

732 local_path = None 

733 

734 checkpoint_path = broadcast_object_list([local_path])[0] 

735 logger.info(f"Received checkpoint path: {checkpoint_path}") 

736 

737 assert checkpoint_path is not None, "Failed to broadcast accelerator state across ranks" 

738 

739 # Add safe globals 

740 torch.serialization.add_safe_globals( 

741 [ 

742 np._core.multiarray.scalar, 

743 np.dtypes.Float64DType, 

744 ] 

745 ) 

746 

747 self.accelerator.load_state(checkpoint_path) 

748 logger.info(f"Loaded accelerator state from {checkpoint_path}") 

749 

750 def load_model_state(self) -> None: 

751 """Load the model state.""" 

752 checkpoint_path = self.config.get("resume_checkpoint_path", None) 

753 if checkpoint_path is None: 

754 return 

755 

756 if os.path.isdir(checkpoint_path) and not checkpoint_path.startswith("s3://"): 

757 raise ValueError(f"Checkpoint path should be a file, got {checkpoint_path}") 

758 

759 if self.accelerator.is_main_process: 

760 logger.info(f"Resuming model state from {checkpoint_path}") 

761 local_path = self.s3.get_local_path(checkpoint_path) 

762 else: 

763 local_path = None 

764 

765 local_path = broadcast_object_list([local_path])[0] 

766 

767 assert local_path is not None, "Failed to broadcast model state across ranks" 

768 

769 # TODO: Switch to model.load(), implement model schema 

770 model_data = torch.load(local_path, weights_only=False, map_location="cpu") 

771 # TODO: Remove, only use state_dict 

772 if "model" in model_data: 

773 model_state = model_data["model"] 

774 else: 

775 model_state = model_data["state_dict"] 

776 

777 # Check residues 

778 if "residues" in model_data: 

779 model_residues = dict(model_data["residues"].get("residues", {})) 

780 else: 

781 # Legacy format 

782 model_residues = dict(model_data["config"]["residues"]) 

783 

784 current_residues = self.config.residues.get("residues") 

785 if model_residues != current_residues: 

786 logger.warning( 

787 f"Checkpoint residues do not match current residues.\nCheckpoint residues: {model_residues}\nCurrent residues: {current_residues}" 

788 ) 

789 logger.warning("Updating model state to match current residues.") 

790 model_state = self.update_vocab(model_state) 

791 

792 model_state = self.update_model_state(model_state, model_data["config"]) 

793 self.model.load_state_dict(model_state, strict=False) 

794 logger.info(f"Loaded model state from {local_path}") 

795 

796 def update_model_state(self, model_state: dict[str, torch.Tensor], model_config: DictConfig) -> dict[str, torch.Tensor]: 

797 """Update the model state.""" 

798 return model_state 

799 

800 def update_vocab(self, model_state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: 

801 """Update the vocabulary of the model.""" 

802 # This should call `self._update_vocab` based on model implementation. 

803 raise NotImplementedError("Updating vocabulary is not implemented for the base trainer.") 

804 

805 def _update_vocab( 

806 self, 

807 model_state: dict[str, torch.Tensor], 

808 target_layers: list[str], 

809 resolution: str = "delete", 

810 ) -> dict[str, torch.Tensor]: 

811 """Update the target heads of the model.""" 

812 target_vocab_size = len(self.residue_set) 

813 current_model_state = self.model.state_dict() 

814 hidden_size = self.config.model.get("dim_model", 768) 

815 

816 for layer in target_layers: 

817 if layer not in current_model_state: 

818 logger.warning(f"Layer {layer} not found in current model state.") 

819 continue 

820 tmp = torch.normal( 

821 mean=0, 

822 std=1.0 / np.sqrt(hidden_size), 

823 size=current_model_state[layer].shape, 

824 dtype=current_model_state[layer].dtype, 

825 ) 

826 if "bias" in layer: 

827 # initialise bias to zeros 

828 tmp = torch.zeros_like(tmp) 

829 

830 if resolution == "delete": 

831 del model_state[layer] 

832 elif resolution == "random": 

833 model_state[layer] = tmp 

834 elif resolution == "partial": 

835 tmp[:target_vocab_size] = model_state[layer][: min(tmp.shape[0], target_vocab_size)] 

836 model_state[layer] = tmp 

837 else: 

838 raise ValueError(f"Unknown residue_conflict_resolution type '{resolution}'") 

839 return model_state 

840 
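
A standalone sketch of the "partial" resolution above: rows of the old projection that overlap with the new vocabulary are copied into a freshly initialised tensor (all sizes below are illustrative assumptions):

import numpy as np
import torch

old_vocab, new_vocab, hidden = 30, 33, 8        # assumed sizes
old_weight = torch.randn(old_vocab, hidden)     # stands in for the checkpoint layer
new_weight = torch.normal(0.0, 1.0 / np.sqrt(hidden), size=(new_vocab, hidden))
n = min(old_vocab, new_vocab)
new_weight[:n] = old_weight[:n]                 # keep learned rows for shared residues
assert new_weight.shape == (new_vocab, hidden)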

841 def train(self) -> None: 

842 """Train the model.""" 

843 num_sanity_steps = self.config.get("num_sanity_val_steps", 0) 

844 if num_sanity_steps > 0: 

845 logger.info(f"Running sanity validation for {num_sanity_steps} steps...") 

846 self.validate_epoch(num_sanity_steps=num_sanity_steps, calculate_metrics=False) 

847 logger.info("Sanity validation complete.") 

848 

849 if self.config.get("validate_before_training", False): 

850 logger.info("Running pre-validation...") 

851 self.validate_epoch() 

852 logger.info("Pre-validation complete.") 

853 

854 self.train_timer = Timer(self.total_steps) 

855 is_first_epoch = True 

856 logger.info("Starting training...") 

857 while self.global_step < self.total_steps: 

858 self.train_epoch() 

859 self.training_state.step_epoch() 

860 if self.accelerator.is_main_process and is_first_epoch: 

861 is_first_epoch = False 

862 logger.info("First epoch complete:") 

863 logger.info(f"- Actual steps per epoch: {self.global_step}") 

864 logger.info(f"- Actual total epochs: {self.total_steps / self.global_step:.1f}") 

865 

866 logger.info("Training complete.") 

867 

868 def prepare_batch(self, batch: Iterable[Any]) -> Any: 

869 """Prepare a batch for training. 

870 

871 Manually move tensors to accelerator.device as a safeguard for

872 batches that are not already on the correct device.

873 

874 Args: 

875 batch (Iterable[Any]): The batch to prepare. 

876 

877 Returns: 

878 Any: The prepared batch 

879 """ 

880 if isinstance(batch, dict): 

881 return {k: v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} 

882 elif isinstance(batch, (list, tuple)): 

883 return [v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v for v in batch] 

884 else: 

885 raise ValueError(f"Unsupported batch type: {type(batch)}") 

886 

887 def train_epoch(self) -> None: 

888 """Train the model for one epoch.""" 

889 total_loss = 0 

890 

891 self.model.train() 

892 self.optimizer.zero_grad() 

893 self.running_loss = None 

894 

895 epoch_timer = Timer() 

896 

897 print_batch_size = True 

898 for batch_count, batch in enumerate(self.train_dataloader): 

899 if print_batch_size: 

900 # Confirm batch size during debugging 

901 self.log_if_verbose(f"Batch {batch_count} size: {batch['spectra'].shape[0]}")

902 print_batch_size = False 

903 

904 with self.accelerator.accumulate(self.model): 

905 # Forward pass 

906 loss, loss_components = self.forward(batch) 

907 

908 # Check for NaN/Inf in loss immediately after forward pass 

909 loss_value = loss.item() 

910 if math.isnan(loss_value) or math.isinf(loss_value): 

911 error_msg = ( 

912 f"Invalid loss value detected: {loss_value} (NaN: {math.isnan(loss_value)}, Inf: {math.isinf(loss_value)}). " 

913 f"This occurred at step {self.global_step + 1}, epoch {self.epoch}, batch {batch_count}. " 

914 f"This indicates a serious training problem (e.g., exploding gradients, division by zero, numerical instability). " 

915 f"Stopping training to prevent further issues.\n\n" 

916 f"Loss components: {[f'{k}={v.item()}' for k, v in loss_components.items()]}\n\n" 

917 f"Traceback showing where the loss was computed:\n" 

918 ) 

919 stack_trace = traceback.format_stack() 

920 # Show frames from the forward pass 

921 relevant_frames = stack_trace[:-3][-8:] # Show more frames to see the forward pass 

922 error_msg += "".join(relevant_frames) 

923 raise ValueError(error_msg) 

924 

925 # Backward pass 

926 self.accelerator.backward(loss) 

927 

928 # Update weights 

929 if self.accelerator.sync_gradients: 

930 self.accelerator.clip_grad_norm_(self.model.parameters(), self.config.get("gradient_clip_val", 10.0)) 

931 self.optimizer.step() 

932 self.lr_scheduler.step() 

933 self.optimizer.zero_grad() 

934 

935 # Update timer 

936 self.train_timer.step() 

937 

938 # Update running loss 

939 # Exponentially weighted moving average to smooth noisy batch losses 

940 if self.running_loss is None: 

941 self.running_loss = loss.item() 

942 else: 

943 self.running_loss = 0.99 * self.running_loss + 0.01 * loss.item() 

944 

945 total_loss += loss.item() 

946 

947 # Log progress 

948 if (self.global_step + 1) % int(self.config.get("console_logging_steps", 2000)) == 0: 

949 lr = self.lr_scheduler.get_last_lr()[0] 

950 

951 logger.info( 

952 f"[TRAIN] " 

953 f"[Epoch {self.epoch:02d}] " 

954 f"[Step {self.global_step + 1:06d}/{self.total_steps:06d}] " 

955 f"[{self.train_timer.get_time_str()}/{self.train_timer.get_eta_str()}, " 

956 f"{self.train_timer.get_step_time_rate_str()}]: " 

957 f"train_loss_raw={loss.item():.4f}, " 

958 f"running_loss={self.running_loss:.4f}, LR={lr:.6f}" 

959 ) 

960 

961 # Log to tensorboard 

962 if ( 

963 self.accelerator.is_main_process 

964 and self.sw is not None 

965 and (self.global_step + 1) % int(self.config.get("tensorboard_logging_steps", 500)) == 0 

966 ): 

967 lr = self.lr_scheduler.get_last_lr()[0] 

968 self.sw.add_scalar("train/loss_raw", loss.item(), self.global_step + 1) 

969 if self.running_loss is not None: 

970 self.sw.add_scalar("train/loss_smooth", self.running_loss, self.global_step + 1) 

971 for k, v in loss_components.items(): 

972 if k == "loss": 

973 continue 

974 self.sw.add_scalar(f"train/{k}", v.item(), self.global_step + 1) 

975 self.sw.add_scalar("optim/lr", lr, self.global_step + 1) 

976 self.sw.add_scalar("optim/epoch", self.epoch, self.global_step + 1) 

977 

978 if (self.global_step + 1) % self.steps_per_validation == 0: 

979 self.model.eval() 

980 self.validate_epoch() 

981 logger.info("Validation complete, resuming training...") 

982 self.model.train() 

983 

984 if (self.global_step + 1) % self.steps_per_checkpoint == 0: 

985 is_best_checkpoint = self.check_if_best_checkpoint() 

986 self.save_model(is_best_checkpoint) 

987 if self.config.get("save_accelerator_state", False): 

988 self.save_accelerator_state(is_best_checkpoint) 

989 

990 self.training_state.step() 

991 

992 # Update finetuning scheduler 

993 if self.finetune_scheduler is not None: 

994 self.finetune_scheduler.step(self.global_step) 

995 

996 if self.global_step >= self.total_steps: 

997 break 

998 

999 # Epoch complete 

1000 self.accelerator.wait_for_everyone() 

1001 

1002 epoch_timer.step() 

1003 

1004 # Gather losses from all devices 

1005 gathered_losses = self.accelerator.gather_for_metrics(torch.tensor(total_loss, device=self.accelerator.device)) 

1006 gathered_num_batches = self.accelerator.gather_for_metrics(torch.tensor(batch_count, device=self.accelerator.device)) 

1007 

1008 if self.accelerator.is_main_process and self.sw is not None: 

1009 # Sum the losses and batch counts from all devices 

1010 total_loss_all_devices = gathered_losses.sum().item() 

1011 total_batches_all_devices = gathered_num_batches.sum().item() 

1012 avg_loss = total_loss_all_devices / total_batches_all_devices 

1013 

1014 self.sw.add_scalar("eval/train_loss", avg_loss, self.epoch) 

1015 

1016 logger.info(f"[TRAIN] [Epoch {self.epoch:02d}] Epoch complete, total time {epoch_timer.get_time_str()}") 

1017 
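
The running loss logged above is an exponentially weighted moving average with decay 0.99, i.e. roughly a 100-batch effective window; a minimal restatement of that update:

def update_running_loss(running: float | None, batch_loss: float, decay: float = 0.99) -> float:
    """EWMA used for the smoothed training loss (illustrative restatement of the loop above)."""
    if running is None:
        return batch_loss
    return decay * running + (1.0 - decay) * batch_loss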

1018 def validate_epoch(self, num_sanity_steps: int | None = None, calculate_metrics: bool = True) -> None: 

1019 """Validate for one epoch.""" 

1020 if self.valid_dataloader is None: 

1021 return 

1022 

1023 if self.accelerator.is_main_process: 

1024 logger.info(f"[VALIDATION] [Epoch {self.epoch:02d}] Starting validation.") 

1025 

1026 valid_epoch_step = 0 

1027 valid_predictions: List[List[str] | str] = [] 

1028 valid_targets: List[List[str] | str] = [] 

1029 valid_groups: List[str] = [] 

1030 valid_prediction_ids: List[int] = [] 

1031 

1032 valid_metrics: Dict[str, List[float]] = {x: [] for x in ["valid_loss", "aa_er", "aa_prec", "aa_recall", "pep_recall"]} 

1033 

1034 num_batches = len(self.valid_dataloader) 

1035 

1036 valid_timer = Timer(num_batches) 

1037 

1038 for batch_idx, batch in enumerate(self.valid_dataloader): 

1039 if num_sanity_steps is not None and batch_idx >= num_sanity_steps: 

1040 break 

1041 

1042 with torch.no_grad(), self.accelerator.autocast(): 

1043 # Loss calculation 

1044 loss, _ = self.forward(batch) 

1045 # Get actual predictions 

1046 y, targets = self.get_predictions(batch) 

1047 

1048 valid_predictions.extend(y) 

1049 valid_targets.extend(targets) 

1050 valid_prediction_ids.extend([x.item() for x in batch["prediction_id"]]) 

1051 

1052 # Store validation groups if available 

1053 if self.using_validation_groups: 

1054 valid_groups.extend(batch["validation_group"]) 

1055 

1056 # Update metrics 

1057 if self.metrics is not None: 

1058 aa_prec, aa_recall, pep_recall, _ = self.metrics.compute_precision_recall(targets, y) 

1059 aa_er = self.metrics.compute_aa_er(targets, y) 

1060 

1061 valid_metrics["valid_loss"].append(loss.item()) 

1062 valid_metrics["aa_er"].append(aa_er) 

1063 valid_metrics["aa_prec"].append(aa_prec) 

1064 valid_metrics["aa_recall"].append(aa_recall) 

1065 valid_metrics["pep_recall"].append(pep_recall) 

1066 

1067 valid_epoch_step += 1 

1068 

1069 valid_timer.step() 

1070 

1071 # Log progress 

1072 if (valid_epoch_step + 1) % int(self.config.get("console_logging_steps", 2000)) == 0: 

1073 epoch_step = valid_epoch_step % num_batches 

1074 

1075 logger.info( 

1076 f"[VALIDATION] " 

1077 f"[Epoch {self.epoch:02d}] " 

1078 f"[Step {self.global_step + 1:06d}] " 

1079 f"[Batch {epoch_step:05d}/{num_batches:05d}] " 

1080 f"[{valid_timer.get_time_str()}/{valid_timer.get_total_time_str()}, " 

1081 f"{valid_timer.get_step_time_rate_str()}]" 

1082 ) 

1083 

1084 # Synchronize all processes at the end of validation 

1085 # This ensures all ranks wait for the slowest rank to finish 

1086 self.accelerator.wait_for_everyone() 

1087 

1088 if not calculate_metrics: 

1089 return 

1090 

1091 # Gather predictions from all devices 

1092 if valid_predictions: 

1093 self.log_if_verbose("Gathering predictions from all devices") 

1094 # Use use_gather_object=True for Python lists to ensure proper gathering 

1095 valid_predictions = self.accelerator.gather_for_metrics(valid_predictions, use_gather_object=True) 

1096 valid_targets = self.accelerator.gather_for_metrics(valid_targets, use_gather_object=True) 

1097 valid_prediction_ids = self.accelerator.gather_for_metrics(valid_prediction_ids, use_gather_object=True) 

1098 

1099 # Flatten nested lists if gather_for_metrics returned nested structure (one per device) 

1100 # gather_for_metrics with use_gather_object=True returns a list of lists when num_processes > 1 

1101 # Structure: [[pred1, pred2, ...], [pred3, pred4, ...], ...] where each inner list is from one device 

1102 # We detect this by checking if we have num_processes or fewer top-level lists, and all are lists 

1103 if self.accelerator.num_processes > 1 and valid_predictions: 

1104 # If the length matches num_processes (or less, if some processes had no data) 

1105 # and all elements are lists, it's likely the nested structure from gathering 

1106 if len(valid_predictions) <= self.accelerator.num_processes and all(isinstance(item, list) for item in valid_predictions): 

1107 # Flatten the nested structure 

1108 valid_predictions = [item for sublist in valid_predictions for item in sublist] 

1109 valid_targets = [item for sublist in valid_targets for item in sublist] 

1110 valid_prediction_ids = [item for sublist in valid_prediction_ids for item in sublist] # type: ignore[attr-defined] 

1111 

1112 # Validate that all gathered lists have matching lengths after flattening 

1113 if len(valid_predictions) != len(valid_targets) or len(valid_predictions) != len(valid_prediction_ids): 

1114 raise ValueError( 

1115 f"Length mismatch after gathering predictions from all devices. " 

1116 f"valid_predictions: {len(valid_predictions)}, " 

1117 f"valid_targets: {len(valid_targets)}, " 

1118 f"valid_prediction_ids: {len(valid_prediction_ids)}. " 

1119 f"num_processes: {self.accelerator.num_processes}" 

1120 ) 

1121 

1122 # Convert to numpy array for np.unique 

1123 valid_prediction_ids_array = np.array(valid_prediction_ids) 

1124 

1125 # Use valid_prediction_ids to remove duplicates 

1126 # Find the indices of the first occurrence of each unique prediction_id 

1127 _, idx = np.unique(valid_prediction_ids_array, return_index=True) 

1128 

1129 # Store original length before deduplication for validation 

1130 original_length = len(valid_predictions) 

1131 

1132 # Validate indices are within bounds - this should never happen if lengths match 

1133 max_idx = len(valid_predictions) - 1 

1134 if len(idx) > 0 and idx.max() > max_idx: 

1135 raise IndexError( 

1136 f"IndexError: max index {idx.max()} exceeds valid_predictions length {len(valid_predictions)}. " 

1137 f"valid_prediction_ids length: {len(valid_prediction_ids)}, " 

1138 f"unique prediction_ids: {len(idx)}, " 

1139 f"num_processes: {self.accelerator.num_processes}. " 

1140 ) 

1141 

1142 valid_predictions = [valid_predictions[i] for i in idx] 

1143 valid_targets = [valid_targets[i] for i in idx] 

1144 

1145 self.log_if_verbose(f"Gathered {len(valid_predictions)} predictions") 

1146 

1147 # Gather validation groups if available 

1148 if self.using_validation_groups: 

1149 valid_groups = self.accelerator.gather_for_metrics(valid_groups, use_gather_object=True) 

1150 # Flatten nested structure if needed (same logic as above) 

1151 if self.accelerator.num_processes > 1 and valid_groups: 

1152 if len(valid_groups) <= self.accelerator.num_processes and all(isinstance(item, list) for item in valid_groups): 

1153 valid_groups = [item for sublist in valid_groups for item in sublist] 

1154 

1155 # Validate length matches the original length before deduplication 

1156 if len(valid_groups) != original_length: 

1157 raise ValueError( 

1158 f"Length mismatch for valid_groups. " 

1159 f"valid_groups length: {len(valid_groups)}, " 

1160 f"expected length (before dedup): {original_length}, " 

1161 f"deduplicated predictions length: {len(idx)}. " 

1162 ) 

1163 valid_groups = [valid_groups[i] for i in idx] 

1164 

1165 # Gather valid_metrics from all devices 

1166 for metric, values in valid_metrics.items(): 

1167 valid_metrics[metric] = self.accelerator.gather_for_metrics(values) 

1168 

1169 # Keep validation metrics for checkpointing 

1170 checkpoint_metric = self.config.get("checkpoint_metric", None) 

1171 if checkpoint_metric is not None: 

1172 self.last_validation_metric = np.mean(valid_metrics[checkpoint_metric]) 

1173 

1174 # Log validation metrics 

1175 if self.accelerator.is_main_process and self.metrics is not None: 

1176 # Validation metrics are logged against the current global step

1177 validation_step = self.global_step + 1 

1178 

1179 if self.sw is not None: 

1180 for k, v in valid_metrics.items(): 

1181 self.sw.add_scalar(f"eval/{k}", np.mean(v), validation_step) 

1182 

1183 logger.info( 

1184 f"[VALIDATION] [Epoch {self.epoch:02d}] " 

1185 f"[Step {self.global_step + 1:06d}] " 

1186 f"train_loss={self.running_loss if self.running_loss else 0:.5f}, " 

1187 f"valid_loss={np.mean(valid_metrics['valid_loss']):.5f}" 

1188 ) 

1189 logger.info(f"[VALIDATION] [Epoch {self.epoch:02d}] [Step {self.global_step + 1:06d}] Metrics:") 

1190 for metric in ["aa_er", "aa_prec", "aa_recall", "pep_recall"]: 

1191 val = np.mean(valid_metrics[metric]) 

1192 logger.info(f"[VALIDATION] [Epoch {self.epoch:02d}] [Step {self.global_step + 1:06d}] - {metric:11s}{val:.3f}") 

1193 

1194 # Validation group logging 

1195 if self.using_validation_groups and valid_groups and self.sw is not None: 

1196 preds = pl.Series(valid_predictions) 

1197 targs = pl.Series(valid_targets) 

1198 groups = pl.Series(valid_groups) 

1199 

1200 assert len(preds) == len(groups) 

1201 assert len(targs) == len(groups) 

1202 

1203 for group in groups.unique(): 

1204 idx = groups == group 

1205 logger.info(f"Computing group {group} with {idx.sum()} samples") 

1206 if idx.sum() > 0: # Only compute metrics if we have samples for this group 

1207 aa_prec, aa_recall, pep_recall, _ = self.metrics.compute_precision_recall(targs.filter(idx), preds.filter(idx)) 

1208 aa_er = self.metrics.compute_aa_er(targs.filter(idx), preds.filter(idx)) 

1209 self.sw.add_scalar(f"eval/{group}_aa_er", aa_er, validation_step) 

1210 self.sw.add_scalar(f"eval/{group}_aa_prec", aa_prec, validation_step) 

1211 self.sw.add_scalar(f"eval/{group}_aa_recall", aa_recall, validation_step) 

1212 self.sw.add_scalar(f"eval/{group}_pep_recall", pep_recall, validation_step) 

1213 

1214 self.accelerator.wait_for_everyone()
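
A standalone illustration of the deduplication step above: np.unique(..., return_index=True) keeps the first occurrence of each prediction_id, which drops samples duplicated by distributed padding (the values below are made up for the example):

import numpy as np

prediction_ids = np.array([0, 1, 2, 3, 0, 1])                   # last two are padded duplicates
predictions = ["PEPA", "PEPB", "PEPC", "PEPD", "PEPA", "PEPB"]  # assumed peptide strings
_, first_idx = np.unique(prediction_ids, return_index=True)
deduplicated = [predictions[i] for i in first_idx]
assert deduplicated == ["PEPA", "PEPB", "PEPC", "PEPD"]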