Coverage for instanovo/diffusion/multinomial_diffusion.py: 50%

230 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2026-06-08 23:00 +0000

1from __future__ import annotations 

2 

3import json 

4import math 

5import os 

6import shutil 

7from importlib import resources 

8from pathlib import Path 

9from typing import Tuple 

10from urllib.parse import urlsplit 

11 

12import torch 

13from jaxtyping import Float, Integer 

14from omegaconf import DictConfig, OmegaConf, open_dict 

15from torch import nn 

16from torch.distributions import Categorical 

17from torch.nn.functional import log_softmax, one_hot 

18 

19from instanovo.__init__ import console 

20from instanovo.diffusion.model import MassSpectrumTransFusion 

21from instanovo.types import Peptide, ResidueLogProbabilities, TimeStep 

22from instanovo.utils.colorlogging import ColorLog 

23from instanovo.utils.file_downloader import download_file 

24from instanovo.utils.residues import ResidueSet 

25from instanovo.utils.s3 import S3FileHandler 

26 

27MODEL_TYPE = "diffusion" 

28 

29logger = ColorLog(console, __name__).logger 

30 

31 

def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> Float[torch.Tensor, " time"]:
    """Build the cosine noise schedule proposed in https://arxiv.org/abs/2102.09672 .

    Despite the function name, the result is derived from the per-step
    alpha ratios, NOT from the betas.

    Args:
        timesteps (int):
            Number of diffusion steps. This is also the length of the result.

        s (float, optional):
            Offset added to the normalised time before the cosine is taken.
            Defaults to 0.008.

    Returns:
        torch.FloatTensor[timesteps]:
            The square root of the per-step ratios of the cumulative cosine
            curve, each ratio having first been clamped to `[0.001, 1.0]`.
    """
    # `timesteps + 1` grid points give `timesteps` consecutive ratios below.
    grid = torch.linspace(0, timesteps, timesteps + 1)
    phase = ((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5
    cumulative = torch.cos(phase) ** 2
    # Normalise so that the curve starts at exactly 1.
    cumulative = cumulative / cumulative[0]
    step_ratios = cumulative[1:] / cumulative[:-1]
    return torch.clamp(step_ratios, 0.001, 1.0).sqrt()

44 

45 

46class InstaNovoPlus(nn.Module): 

47 r"""This class implements Multinomial Diffusion as described in Hoogeboom et al. 2021. 

48 

49 Args: 

50 config (omegaconf.DictConfig): 

51 The model configuration. This should have keys: 

52 - 'name': the model name identifier. 

53 - 'time_steps': the number of time steps in the diffusion process 

54 - 'max_length': the maximum sequence for the model 

55 - 'device': the device where the `Pytorch` model should be 

56 loaded e.g. `cpu`, `cuda:0` etc. 

57 - 'vocab_size': the number of residues in the vocabulary 

58 - 'transition_model': the `DictConfig` for the transition model 

59 

60 This information is necessary for saving and loading the model. 

61 

62 transition_model (nn.Module): 

63 The model that predicts the initial sequence given 

64 the sequence sampled the current time step and the 

65 sequence sampled the previous time step. This is 

66 just a sequence tagging model. 

67 

68 diffusion_schedule (torch.FloatTensor[time_steps]): 

69 The sequence of diffusion probabilities. Note 

70 that `diffusion_schedule[t]` is \alpha_t in 

71 the paper's terminology, not \beta_t. 

72 

73 residue_set (ResidueSet): 

74 The residue vocabulary. This holds a mapping between 

75 residues and indices and residue masses. 

76 """ 

77 

78 config_path: str 

79 schedule_path: str 

80 checkpoint_path: str 

81 

82 def __init__( 

83 self, 

84 config: DictConfig, 

85 transition_model: nn.Module, 

86 diffusion_schedule: Float[torch.Tensor, " time"], 

87 residue_set: ResidueSet, 

88 ) -> None: 

89 super().__init__() 

90 self.config = config 

91 self.time_steps = config.time_steps 

92 self.residue_set = residue_set 

93 self.transition_model = transition_model 

94 self.register_buffer("diffusion_schedule", torch.log(diffusion_schedule)) 

95 self.register_buffer("diffusion_schedule_complement", torch.log(1 - diffusion_schedule)) 

96 self.register_buffer("cumulative_schedule", torch.cumsum(self.diffusion_schedule, -1)) 

97 self.register_buffer( 

98 "cumulative_schedule_complement", 

99 torch.log(1 - torch.exp(self.cumulative_schedule)), 

100 ) 

101 

102 def save( 

103 self, 

104 path: str, 

105 ckpt_details: str, 

106 overwrite: bool = False, 

107 temp_dir: str | None = None, 

108 use_legacy_format: bool = False, 

109 ) -> None: 

110 """Save the model to a directory. 

111 

112 Args: 

113 path (str): 

114 Path to the base directory where the model is saved. 

115 The model is saved in a subdirectory with the model's 

116 name identifier. 

117 

118 ckpt_details (str): 

119 Additional checkpoint details to include in model save directory. 

120 

121 overwrite (bool, optional): 

122 Whether to overwrite the directory if one already exists 

123 for the model. Defaults to False. 

124 

125 temp_dir (str | None, optional): 

126 Temporary directory to save intermediate files to. 

127 Defaults to None. 

128 

129 use_legacy_format (bool, optional): 

130 Whether to save the model in the legacy folder format. 

131 If False, saves as a single file. Defaults to False. 

132 

133 Raises: 

134 FileExistsError: If `overwrite` is `False` and a directory already exists 

135 for the model identifier. 

136 """ 

137 model_dir = os.path.join(path, ckpt_details) 

138 

139 def save_file_local(filename: str, content: str) -> None: 

140 """Save a file locally (no upload).""" 

141 return 

142 

143 def save_file_s3(filename: str, content: str) -> None: 

144 """Upload a file to S3.""" 

145 # TODO: fix this 

146 s3 = S3FileHandler() 

147 return s3.upload( # type: ignore 

148 content, s3.convert_to_s3_output(model_dir + "/" + filename) 

149 ) 

150 

151 if temp_dir is None: 

152 if os.path.exists(model_dir) and os.path.isdir(model_dir): 

153 if overwrite: 

154 shutil.rmtree(model_dir) 

155 else: 

156 raise FileExistsError 

157 

158 if use_legacy_format: 

159 os.makedirs(model_dir, exist_ok=True) 

160 elif os.path.dirname(model_dir): 

161 os.makedirs(os.path.dirname(model_dir), exist_ok=True) 

162 

163 save_path = model_dir 

164 save_file = save_file_local 

165 

166 else: 

167 save_path = temp_dir 

168 save_file = save_file_s3 

169 

170 if use_legacy_format: 

171 # Save model as a folder 

172 # Save config 

173 config_path = os.path.join(save_path, "config.yaml") 

174 OmegaConf.save(config=self.config, f=config_path) 

175 save_file("config.yaml", config_path) 

176 

177 # Save schedule 

178 diff_schedule_path = os.path.join(save_path, "diffusion_schedule.pt") 

179 torch.save(torch.exp(self.diffusion_schedule), diff_schedule_path) 

180 save_file("diffusion_schedule.pt", diff_schedule_path) 

181 

182 # Save transition model 

183 self.transition_model.to("cpu") 

184 transition_model_path = os.path.join(save_path, "transition_model.ckpt") 

185 torch.save(self.transition_model.state_dict(), transition_model_path) 

186 save_file("transition_model.ckpt", transition_model_path) 

187 else: 

188 # Save model as a single file 

189 transition_model_state = {k: v.cpu() for k, v in self.transition_model.state_dict().items()} 

190 

191 model_data = { 

192 "config": OmegaConf.to_container(self.config), 

193 "diffusion_schedule": torch.exp(self.diffusion_schedule).tolist(), 

194 "transition_model": transition_model_state, 

195 } 

196 

197 if temp_dir: 

198 save_path = os.path.join(save_path, "instanovo_plus.ckpt") 

199 torch.save(model_data, save_path) 

200 save_file("instanovo_plus.ckpt", save_path) 

201 else: 

202 torch.save(model_data, save_path) 

203 save_file("instanovo_plus.ckpt", save_path) 

204 

205 @classmethod 

206 def load(cls, path: str, override_config: DictConfig | dict | None = None) -> Tuple[InstaNovoPlus, DictConfig]: 

207 """Load a saved model. 

208 

209 Args: 

210 path (str): 

211 Path to model checkpoint file or directory where model is saved. 

212 override_config (DictConfig | dict | None): Optional override config values with a DictConfig or dict, defaults to None. 

213 

214 Returns: 

215 (InstaNovoPlus, DictConfig): The loaded model and config. 

216 

217 """ 

218 is_legacy_format = False 

219 if os.path.isdir(path): 

220 # Load config 

221 cls.config_path = os.path.join(path, "config.yaml") 

222 config = OmegaConf.load(cls.config_path) 

223 if override_config is not None: 

224 with open_dict(config): 

225 config.update(override_config) 

226 

227 cls.schedule_path = os.path.join(path, "diffusion_schedule.pt") 

228 diffusion_schedule = torch.load(cls.schedule_path, map_location=torch.device("cpu"), weights_only=True) 

229 

230 cls.checkpoint_path = os.path.join(path, "transition_model.ckpt") 

231 transition_model_state = torch.load(cls.checkpoint_path, map_location=torch.device("cpu"), weights_only=True) 

232 

233 is_legacy_format = True 

234 else: 

235 # Load model from checkpoint file 

236 try: 

237 _whitelist_torch_omegaconf() 

238 model_data = torch.load(path, map_location=torch.device("cpu"), weights_only=True) 

239 except Exception as e: 

240 raise ValueError(f"Failed to load model from {path}: {str(e)}") from e 

241 

242 config = OmegaConf.create(model_data["config"]) 

243 if override_config is not None: 

244 with open_dict(config): 

245 config.update(override_config) 

246 

247 diffusion_schedule = torch.tensor(model_data["diffusion_schedule"]) 

248 if "transition_model" in model_data: 

249 transition_model_state = model_data["transition_model"] 

250 is_legacy_format = True 

251 else: 

252 transition_model_state = model_data["state_dict"] 

253 

254 if is_legacy_format: 

255 # Load residues 

256 residue_set = ResidueSet( 

257 residue_masses=config["residues"], 

258 residue_remapping=config["residue_remapping"], 

259 ) 

260 

261 # Load transition model 

262 transition_model = MassSpectrumTransFusion( 

263 config, 

264 config.max_length, 

265 ) 

266 transition_model.load_state_dict(transition_model_state) 

267 

268 return cls( 

269 config=config, 

270 transition_model=transition_model, 

271 diffusion_schedule=diffusion_schedule, 

272 residue_set=residue_set, 

273 ), config 

274 else: 

275 residues = model_data["residues"] 

276 residue_set = ResidueSet( 

277 residue_masses=residues, 

278 ) 

279 transition_model = MassSpectrumTransFusion( 

280 config, 

281 config.max_length, 

282 ) 

283 

284 model = cls( 

285 config=config, 

286 transition_model=transition_model, 

287 diffusion_schedule=diffusion_schedule, 

288 residue_set=residue_set, 

289 ) 

290 model.load_state_dict(model_data["state_dict"]) 

291 

292 return model, config 

293 

294 @staticmethod 

295 def get_pretrained() -> list[str]: 

296 """Get a list of pretrained model ids.""" 

297 # Load the models.json file 

298 with resources.files("instanovo").joinpath("models.json").open("r", encoding="utf-8") as f: 

299 models_config = json.load(f) 

300 

301 if MODEL_TYPE not in models_config: 

302 return [] 

303 

304 return list(models_config[MODEL_TYPE].keys()) 

305 

306 @classmethod 

307 def from_pretrained(cls, model_id: str, override_config: DictConfig | dict | None = None) -> Tuple["InstaNovoPlus", "DictConfig"]: 

308 """Download and load by model id or model path.""" 

309 # Check if model_id is a local directory 

310 expected_files = ["config.yaml", "diffusion_schedule.pt", "transition_model.ckpt"] 

311 if os.path.isdir(model_id): 

312 if all(os.path.exists(os.path.join(model_id, fn)) for fn in expected_files): 

313 return cls.load(model_id, override_config=override_config) 

314 else: 

315 missing_files = [fn for fn in expected_files if not os.path.exists(os.path.join(model_id, fn))] 

316 raise FileNotFoundError(f"InstaNovo+ model directory {model_id} is missing the expected file(s): {', '.join(missing_files)}.") 

317 elif os.path.exists(model_id): 

318 return cls.load(model_id, override_config=override_config) 

319 

320 # Load the models.json file 

321 with resources.files("instanovo").joinpath("models.json").open("r", encoding="utf-8") as f: 

322 models_config = json.load(f) 

323 

324 # Find the model in the config 

325 if MODEL_TYPE not in models_config or model_id not in models_config[MODEL_TYPE]: 

326 raise ValueError(f"Model {model_id} not found in models.json, options are [{', '.join(models_config[MODEL_TYPE].keys())}]") 

327 

328 # Create cache directory if it doesn't exist 

329 cache_dir = Path.home() / ".cache" / "instanovo" 

330 cache_dir.mkdir(parents=True, exist_ok=True) 

331 

332 model_info = models_config[MODEL_TYPE][model_id] 

333 

334 if "remote" in model_info: 

335 url = model_info["remote"] 

336 

337 # Generate a filename for the cached model 

338 file_name = urlsplit(url).path.split("/")[-1] 

339 cached_file = cache_dir / file_name 

340 

341 # Check if the file is already cached 

342 if not cached_file.exists(): 

343 download_file(url, cached_file, model_id, file_name) 

344 

345 else: 

346 logger.info(f"Model {model_id} already cached at {cached_file}") 

347 

348 try: 

349 # Load and return the model 

350 logger.info(f"Loading model {model_id} (remote)") 

351 return cls.load(str(cached_file), override_config=override_config) 

352 except Exception as e: 

353 logger.warning(f"Failed to load cached model {model_id}, it may be corrupted. Deleting and re-downloading. Error: {e}") 

354 if cached_file.exists(): 

355 cached_file.unlink() 

356 

357 download_file(url, cached_file, model_id, file_name) 

358 logger.info(f"Loading newly downloaded model {model_id}") 

359 return cls.load(str(cached_file), override_config=override_config) 

360 

361 elif "local" in model_info: 

362 instanovo_plus_model = model_info["local"] 

363 if os.path.isdir(instanovo_plus_model): 

364 if all(os.path.exists(os.path.join(instanovo_plus_model, fn)) for fn in expected_files): 

365 logger.info(f"Loading model {model_id} (local)") 

366 return cls.load(instanovo_plus_model, override_config=override_config) 

367 else: 

368 missing_files = [fn for fn in expected_files if not os.path.exists(os.path.join(instanovo_plus_model, fn))] 

369 raise FileNotFoundError( 

370 f"InstaNovo+ model directory {instanovo_plus_model} is missing the expected file(s): {', '.join(missing_files)}." 

371 ) 

372 elif os.path.exists(instanovo_plus_model): 

373 return cls.load(instanovo_plus_model, override_config=override_config) 

374 else: 

375 raise ValueError( 

376 f"Local model path '{instanovo_plus_model}' must exist, be a directory and containing the files {', '.join(expected_files)}." 

377 ) 

378 else: 

379 raise ValueError(f"Model {model_id} does not have a valid 'remote', 'local' entry in models.json") 

380 

381 def prepare_fine_tuning(self, residue_set: ResidueSet) -> None: 

382 """Prepare a model for fine-tuning on a dataset with a new residue vocabulary. 

383 

384 Args: 

385 residue_set (ResidueSet): The residue vocabulary for the new dataset. 

386 """ 

387 # 1. Update residue set 

388 self.residue_set = residue_set 

389 

390 num_residues = len(self.residue_set) 

391 model_dim = self.config.dim 

392 

393 # 2. Update config 

394 self.config.vocab_size = num_residues 

395 

396 # 3. Update modules 

397 self.transition_model.char_embedding = nn.Embedding(num_embeddings=num_residues, embedding_dim=model_dim) 

398 self.transition_model.head[1] = nn.Linear(model_dim, num_residues) 

399 

400 def mixture_categorical( 

401 self, 

402 log_x: Float[ResidueLogProbabilities, "batch token"], 

403 log_alpha: float, 

404 log_alpha_complement: float, 

405 ) -> Float[ResidueLogProbabilities, "batch token"]: 

406 """A categorical mixture between a base distribution and a uniform distribution. 

407 

408 Args: 

409 log_x (torch.FloatTensor[..., num_classes]): 

410 The base distribution. 

411 

412 log_alpha (float): 

413 The log of the mixture weight. 

414 

415 log_alpha_complement (float): 

416 The log of 1 minus the mixture weight. 

417 

418 Returns: 

419 torch.FloatTensor[..., num_classes]: 

420 The log-probabilities of the mixture. 

421 """ 

422 return torch.logaddexp( 

423 log_x + log_alpha, 

424 log_alpha_complement - math.log(len(self.residue_set)), 

425 ) 

426 

427 def forward( 

428 self, 

429 log_x_t: Float[ResidueLogProbabilities, "batch token"], 

430 log_x_0: Float[ResidueLogProbabilities, "batch token"], 

431 t: Integer[TimeStep, " batch"], 

432 ) -> Float[ResidueLogProbabilities, "batch token"]: 

433 """Calculate the log-posterior of `t-1`-th process values given the 0-th and t-th values. 

434 

435 Args: 

436 log_x_t (torch.FloatTensor[batch_size, sequence_length, num_classes]): 

437 The log one-hot representation of the process values at the `t`-th time step. 

438 

439 log_x_0 (torch.FloatTensor[batch_size, sequence_length, num_classes]): 

440 The log one-hot representation of the process values at the `t`-th time step. 

441 t (int): 

442 The time step. 

443 

444 Returns: 

445 torch.FloatTensor[batch_size, sequence_length, num_classes]: 

446 The log-posterior probabilities of the process values at the `t-1`-th 

447 time step given the values at the 0-th and `t`-th time step 

448 i.e. q( x_{t-1} | x_{t}, x_0 ). 

449 """ 

450 log_prior = self.mixture_categorical( 

451 log_x=log_x_0, 

452 log_alpha=self.cumulative_schedule[t - 1].unsqueeze(-1).unsqueeze(-1), 

453 log_alpha_complement=self.cumulative_schedule_complement[t - 1].unsqueeze(-1).unsqueeze(-1), 

454 ) 

455 log_likelihood = self.mixture_categorical( 

456 log_x=log_x_t, 

457 log_alpha=self.diffusion_schedule[t].unsqueeze(-1).unsqueeze(-1), 

458 log_alpha_complement=self.diffusion_schedule_complement[t].unsqueeze(-1).unsqueeze(-1), 

459 ) 

460 t_mask = (t == 0).unsqueeze(-1).unsqueeze(-1).expand_as(log_x_0) 

461 prior_term = torch.where(t_mask, log_x_0, log_prior) 

462 logits = log_likelihood + prior_term 

463 return torch.log_softmax(logits, -1) 

464 

465 def reverse_distribution( 

466 self, 

467 x_t: Integer[Peptide, "batch token"], 

468 time: Integer[TimeStep, " batch"], 

469 **kwargs: dict, 

470 ) -> Float[ResidueLogProbabilities, "batch token"]: 

471 """Calculate the reverse transition distribution of the diffusion process. 

472 

473 Args: 

474 x_t (torch.LongTensor[batch_size, sequence_length]): 

475 The values at the `t`-th time step of the reverse process. 

476 

477 time (int): 

478 The time step. 

479 

480 Returns: 

481 torch.FloatTensor[batch_size, sequence_length, num_classes]: 

482 The log-probabilities of values for the `t-1`-th time step given 

483 values at the `t`-th time step i.e. `log p( x_{t-1} | x_{t} )`. 

484 """ 

485 log_x_0 = log_softmax(self.transition_model(x_t, t=time, **kwargs), -1) 

486 return self.forward(log_x_t=torch.log(one_hot(x_t, len(self.residue_set))), log_x_0=log_x_0, t=time) 

487 

488 

class DiffusionLoss(nn.Module):
    """Holds logic for calculating the diffusion loss.

    Args:
        model (InstaNovoPlus):
            The multinomial diffusion class, optionally wrapped in an object
            that exposes the real model as `.module`.
    """

    def __init__(self, model: InstaNovoPlus) -> None:
        super().__init__()
        self.model = model

        # Unwrap a `.module` wrapper when present, so attributes and helper
        # methods are always read from the underlying model.
        self.base_model = getattr(model, "module", model)

        self.time_steps = self.base_model.time_steps

    @staticmethod
    def kl_divergence(
        log_probs_first: Float[ResidueLogProbabilities, "..."],
        log_probs_second: Float[ResidueLogProbabilities, "..."],
    ) -> Float[torch.Tensor, "..."]:
        """Calculate the Kullback-Leibler divergence between two multinomial distributions.

        Args:
            log_probs_first (torch.FloatTensor[..., num_classes]):
                The log-probabilities of the base distribution.

            log_probs_second (torch.FloatTensor[..., num_classes]):
                The log-probabilities of the comparison distribution.

        Returns:
            torch.FloatTensor[...]:
                The KL-divergence summed over the last two dimensions.
        """
        log_ratio = log_probs_first - log_probs_second
        weighted = torch.exp(log_probs_first) * log_ratio
        return weighted.sum(-1).sum(-1)

    def forward(self, x_0: Integer[Peptide, "batch token"], **kwargs: dict) -> Float[torch.Tensor, "1"]:
        """Calculate a single Monte Carlo estimate of the multinomial diffusion loss (-ELBO).

        Args:
            x_0 (torch.LongTensor[batch_size, sequence_length]):
                A batch of padded sequences.

            **kwargs:
                Extra keyword arguments forwarded to the model's reverse distribution.

        Returns:
            torch.FloatTensor[1]:
                The loss estimate.
        """
        batch_size = x_0.shape[0]
        num_classes = len(self.base_model.residue_set)
        last_step = self.time_steps - 1

        # 1. Sample one time step per sequence.
        # NOTE(review): the upper bound of `randint` is exclusive, so the index
        # `time_steps - 1` is never drawn -- confirm this is intended.
        t = torch.randint(0, self.time_steps - 1, (batch_size,)).to(x_0.device)

        # 2. Compute L_t
        step_loss = self._compute_loss(t=t, x_0=x_0, **kwargs).mean()

        # 3. Prior KL term: the fully diffused distribution against uniform.
        log_x_0 = torch.log(one_hot(x_0, num_classes=num_classes))
        final_log_probs = self.base_model.mixture_categorical(
            log_x=log_x_0,
            log_alpha=self.base_model.cumulative_schedule[last_step].unsqueeze(-1).unsqueeze(-1),
            log_alpha_complement=self.base_model.cumulative_schedule_complement[last_step].unsqueeze(-1).unsqueeze(-1),
        )
        uniform_log_probs = torch.log(torch.ones_like(final_log_probs) / num_classes)
        prior_loss = self.kl_divergence(final_log_probs, uniform_log_probs).mean()

        return step_loss + prior_loss

    def _compute_loss(
        self,
        x_0: Integer[Peptide, "batch token"],
        t: Integer[TimeStep, " batch"],
        **kwargs: dict,
    ) -> Float[torch.Tensor, " batch"]:
        """Per-sequence loss for the sampled time steps.

        Sequences with `t == 0` get the negative log-likelihood of `x_0` under
        the reverse distribution; all others get the KL divergence between the
        true posterior and the reverse distribution.
        """
        base = self.base_model
        num_classes = len(base.residue_set)

        # 1. Diffuse x_0 with the cumulative schedule at index t and sample.
        log_x_0 = torch.log(one_hot(x_0, num_classes=num_classes))
        noised_log_probs = base.mixture_categorical(
            log_x=log_x_0,
            log_alpha=base.cumulative_schedule[t].unsqueeze(-1).unsqueeze(-1),
            log_alpha_complement=base.cumulative_schedule_complement[t].unsqueeze(-1).unsqueeze(-1),
        )
        x_next = Categorical(logits=noised_log_probs).sample()

        # 2. Model estimate of the reverse step.
        log_dist = base.reverse_distribution(x_t=x_next, time=t, **kwargs)

        # Reconstruction term, used where t == 0.
        target = one_hot(x_0, num_classes=num_classes)
        nll_loss = -(target * log_dist).sum(-1).sum(-1)

        # Denoising term, used elsewhere. This goes through `self.model`
        # (the possibly wrapped model), not `self.base_model`.
        log_x_next = torch.log(one_hot(x_next, noised_log_probs.size(-1)))
        log_posterior = self.model(log_x_0=log_x_0, log_x_t=log_x_next, t=t)
        denoising_loss = self.kl_divergence(log_posterior, log_dist)

        return torch.where(t == 0, nll_loss, denoising_loss)

577 

578 

def _whitelist_torch_omegaconf() -> None:
    """Register the classes a checkpoint may reference as safe for ``weights_only=True`` loading.

    ``torch.load(..., weights_only=True)`` rejects globals that are not on
    PyTorch's allow-list. This adds a fixed set of OmegaConf container and
    node classes, plus the standard types listed below, via
    ``torch.serialization.add_safe_globals``. Nothing beyond this explicit
    list is allowed, so the restricted loading mode stays in effect for
    everything else.
    """
    from collections import defaultdict
    from typing import Any

    from omegaconf.base import ContainerMetadata, Metadata
    from omegaconf.listconfig import ListConfig
    from omegaconf.nodes import AnyNode

    omegaconf_types = [
        DictConfig,
        ContainerMetadata,
        Metadata,
        ListConfig,
        AnyNode,
    ]
    supporting_types = [
        Any,  # Only used for type hinting in omegaconf.
        defaultdict,
        dict,
        list,
        int,
    ]
    torch.serialization.add_safe_globals(omegaconf_types + supporting_types)

609 )