Coverage for instanovo/diffusion/multinomial_diffusion.py: 48%

222 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

1from __future__ import annotations 

2 

3import json 

4import math 

5import os 

6import shutil 

7from importlib import resources 

8from pathlib import Path 

9from typing import Tuple 

10from urllib.parse import urlsplit 

11 

12import torch 

13from jaxtyping import Float, Integer 

14from omegaconf import DictConfig, OmegaConf, open_dict 

15from torch import nn 

16from torch.distributions import Categorical 

17from torch.nn.functional import log_softmax, one_hot 

18 

19from instanovo.__init__ import console 

20from instanovo.diffusion.model import MassSpectrumTransFusion 

21from instanovo.types import Peptide, ResidueLogProbabilities, TimeStep 

22from instanovo.utils.colorlogging import ColorLog 

23from instanovo.utils.file_downloader import download_file 

24from instanovo.utils.residues import ResidueSet 

25from instanovo.utils.s3 import S3FileHandler 

26 

# Section key used to look up this model family's entries in the packaged
# ``models.json`` (see ``get_pretrained`` / ``from_pretrained``).
MODEL_TYPE = "diffusion"

# Module-level logger built from the package console through ``ColorLog``.
logger = ColorLog(console, __name__).logger

30 

31 

def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> Float[torch.Tensor, " time"]:
    """Build the cosine diffusion schedule from https://arxiv.org/abs/2102.09672 .

    Despite the function name, the returned values are derived from the alpha
    parameters, NOT the betas: each entry is the square root of the ratio of two
    consecutive cumulative alpha products, with the ratio clamped to
    ``[0.001, 1.0]`` before the square root is taken.

    Args:
        timesteps (int): Number of diffusion steps; also the length of the output.
        s (float, optional): Offset added inside the cosine. Defaults to 0.008.

    Returns:
        torch.FloatTensor[timesteps]: One schedule value per diffusion step.
    """
    # `timesteps + 1` grid points are needed to form `timesteps` consecutive ratios.
    grid = torch.linspace(0, timesteps, timesteps + 1)
    phase = ((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5
    cumulative = torch.cos(phase) ** 2
    # Normalise so that the cumulative product starts at exactly 1.
    cumulative = cumulative / cumulative[0]
    per_step = torch.clamp(cumulative[1:] / cumulative[:-1], 0.001, 1.0)
    return torch.sqrt(per_step)

44 

45 

46class InstaNovoPlus(nn.Module): 

47 r"""This class implements Multinomial Diffusion as described in Hoogeboom et al. 2021. 

48 

49 Args: 

50 config (omegaconf.DictConfig): 

51 The model configuration. This should have keys: 

52 - 'name': the model name identifier. 

53 - 'time_steps': the number of time steps in the diffusion process 

54 - 'max_length': the maximum sequence for the model 

55 - 'device': the device where the `Pytorch` model should be 

56 loaded e.g. `cpu`, `cuda:0` etc. 

57 - 'vocab_size': the number of residues in the vocabulary 

58 - 'transition_model': the `DictConfig` for the transition model 

59 

60 This information is necessary for saving and loading the model. 

61 

62 transition_model (nn.Module): 

63 The model that predictions the initial sequence given 

64 the sequence sampled the current time step and the 

65 sequence sampled the previous time step. This is 

66 just a sequence tagging model. 

67 

68 diffusion_schedule (torch.FloatTensor[time_steps]): 

69 The sequence of diffusion probabilities. Note 

70 that `diffusion_schedule[t]` is \alpha_t in 

71 the paper's terminology, not \beta_t. 

72 

73 residue_set (ResidueSet): 

74 The residue vocabulary. This holds a mapping between 

75 residues and indices and residue masses. 

76 """ 

77 

78 config_path: str 

79 schedule_path: str 

80 checkpoint_path: str 

81 

82 def __init__( 

83 self, 

84 config: DictConfig, 

85 transition_model: nn.Module, 

86 diffusion_schedule: Float[torch.Tensor, " time"], 

87 residue_set: ResidueSet, 

88 ) -> None: 

89 super().__init__() 

90 self.config = config 

91 self.time_steps = config.time_steps 

92 self.residue_set = residue_set 

93 self.transition_model = transition_model 

94 self.register_buffer("diffusion_schedule", torch.log(diffusion_schedule)) 

95 self.register_buffer("diffusion_schedule_complement", torch.log(1 - diffusion_schedule)) 

96 self.register_buffer("cumulative_schedule", torch.cumsum(self.diffusion_schedule, -1)) 

97 self.register_buffer( 

98 "cumulative_schedule_complement", 

99 torch.log(1 - torch.exp(self.cumulative_schedule)), 

100 ) 

101 

102 def save( 

103 self, 

104 path: str, 

105 ckpt_details: str, 

106 overwrite: bool = False, 

107 temp_dir: str | None = None, 

108 use_legacy_format: bool = False, 

109 ) -> None: 

110 """Save the model to a directory. 

111 

112 Args: 

113 path (str): 

114 Path to the base directory where the model is saved. 

115 The model is saved in a subdirectory with the model's 

116 name identifier. 

117 

118 ckpt_details (str): 

119 Additional checkpoint details to include in model save directory. 

120 

121 overwrite (bool, optional): 

122 Whether to overwrite the directory if one already exists 

123 for the model. Defaults to False. 

124 

125 temp_dir (str | None, optional): 

126 Temporary directory to save intermediate files to. 

127 Defaults to None. 

128 

129 use_legacy_format (bool, optional): 

130 Whether to save the model in the legacy folder format. 

131 If False, saves as a single file. Defaults to False. 

132 

133 Raises: 

134 FileExistsError: If `overwrite` is `False` and a directory already exists 

135 for the model identifier. 

136 """ 

137 model_dir = os.path.join(path, ckpt_details) 

138 

139 def save_file_local(filename: str, content: str) -> None: 

140 """Save a file locally (no upload).""" 

141 return 

142 

143 def save_file_s3(filename: str, content: str) -> None: 

144 """Upload a file to S3.""" 

145 # TODO: fix this 

146 s3 = S3FileHandler() 

147 return s3.upload( # type: ignore 

148 content, s3.convert_to_s3_output(model_dir + "/" + filename) 

149 ) 

150 

151 if temp_dir is None: 

152 if os.path.exists(model_dir) and os.path.isdir(model_dir): 

153 if overwrite: 

154 shutil.rmtree(model_dir) 

155 else: 

156 raise FileExistsError 

157 

158 if use_legacy_format: 

159 os.makedirs(model_dir, exist_ok=True) 

160 elif os.path.dirname(model_dir): 

161 os.makedirs(os.path.dirname(model_dir), exist_ok=True) 

162 

163 save_path = model_dir 

164 save_file = save_file_local 

165 

166 else: 

167 save_path = temp_dir 

168 save_file = save_file_s3 

169 

170 if use_legacy_format: 

171 # Save model as a folder 

172 # Save config 

173 config_path = os.path.join(save_path, "config.yaml") 

174 OmegaConf.save(config=self.config, f=config_path) 

175 save_file("config.yaml", config_path) 

176 

177 # Save schedule 

178 diff_schedule_path = os.path.join(save_path, "diffusion_schedule.pt") 

179 torch.save(torch.exp(self.diffusion_schedule), diff_schedule_path) 

180 save_file("diffusion_schedule.pt", diff_schedule_path) 

181 

182 # Save transition model 

183 self.transition_model.to("cpu") 

184 transition_model_path = os.path.join(save_path, "transition_model.ckpt") 

185 torch.save(self.transition_model.state_dict(), transition_model_path) 

186 save_file("transition_model.ckpt", transition_model_path) 

187 else: 

188 # Save model as a single file 

189 transition_model_state = {k: v.cpu() for k, v in self.transition_model.state_dict().items()} 

190 

191 model_data = { 

192 "config": OmegaConf.to_container(self.config), 

193 "diffusion_schedule": torch.exp(self.diffusion_schedule).tolist(), 

194 "transition_model": transition_model_state, 

195 } 

196 

197 if temp_dir: 

198 save_path = os.path.join(save_path, "instanovo_plus.ckpt") 

199 torch.save(model_data, save_path) 

200 save_file("instanovo_plus.ckpt", save_path) 

201 else: 

202 torch.save(model_data, save_path) 

203 save_file("instanovo_plus.ckpt", save_path) 

204 

205 @classmethod 

206 def load(cls, path: str, override_config: DictConfig | dict | None = None) -> Tuple[InstaNovoPlus, DictConfig]: 

207 """Load a saved model. 

208 

209 Args: 

210 path (str): 

211 Path to model checkpoint file or directory where model is saved. 

212 override_config (DictConfig | dict | None): Optional override config values with a DictConfig or dict, defaults to None. 

213 

214 Returns: 

215 (InstaNovoPlus, DictConfig): The loaded model and config. 

216 

217 """ 

218 is_legacy_format = False 

219 if os.path.isdir(path): 

220 # Load config 

221 cls.config_path = os.path.join(path, "config.yaml") 

222 config = OmegaConf.load(cls.config_path) 

223 if override_config is not None: 

224 with open_dict(config): 

225 config.update(override_config) 

226 

227 cls.schedule_path = os.path.join(path, "diffusion_schedule.pt") 

228 diffusion_schedule = torch.load(cls.schedule_path, map_location=torch.device("cpu"), weights_only=True) 

229 

230 cls.checkpoint_path = os.path.join(path, "transition_model.ckpt") 

231 transition_model_state = torch.load(cls.checkpoint_path, map_location=torch.device("cpu"), weights_only=True) 

232 

233 is_legacy_format = True 

234 else: 

235 # Load model from checkpoint file 

236 try: 

237 model_data = torch.load(path, map_location=torch.device("cpu"), weights_only=False) 

238 except Exception as e: 

239 raise ValueError(f"Failed to load model from {path}: {str(e)}") from e 

240 

241 config = OmegaConf.create(model_data["config"]) 

242 if override_config is not None: 

243 with open_dict(config): 

244 config.update(override_config) 

245 

246 diffusion_schedule = torch.tensor(model_data["diffusion_schedule"]) 

247 if "transition_model" in model_data: 

248 transition_model_state = model_data["transition_model"] 

249 is_legacy_format = True 

250 else: 

251 transition_model_state = model_data["state_dict"] 

252 

253 if is_legacy_format: 

254 # Load residues 

255 residue_set = ResidueSet( 

256 residue_masses=config["residues"], 

257 residue_remapping=config["residue_remapping"], 

258 ) 

259 

260 # Load transition model 

261 transition_model = MassSpectrumTransFusion( 

262 config, 

263 config.max_length, 

264 ) 

265 transition_model.load_state_dict(transition_model_state) 

266 

267 return cls( 

268 config=config, 

269 transition_model=transition_model, 

270 diffusion_schedule=diffusion_schedule, 

271 residue_set=residue_set, 

272 ), config 

273 else: 

274 residues = model_data["residues"] 

275 residue_set = ResidueSet( 

276 residue_masses=residues, 

277 ) 

278 transition_model = MassSpectrumTransFusion( 

279 config, 

280 config.max_length, 

281 ) 

282 

283 model = cls( 

284 config=config, 

285 transition_model=transition_model, 

286 diffusion_schedule=diffusion_schedule, 

287 residue_set=residue_set, 

288 ) 

289 model.load_state_dict(model_data["state_dict"]) 

290 

291 return model, config 

292 

293 @staticmethod 

294 def get_pretrained() -> list[str]: 

295 """Get a list of pretrained model ids.""" 

296 # Load the models.json file 

297 with resources.files("instanovo").joinpath("models.json").open("r", encoding="utf-8") as f: 

298 models_config = json.load(f) 

299 

300 if MODEL_TYPE not in models_config: 

301 return [] 

302 

303 return list(models_config[MODEL_TYPE].keys()) 

304 

305 @classmethod 

306 def from_pretrained(cls, model_id: str, override_config: DictConfig | dict | None = None) -> Tuple["InstaNovoPlus", "DictConfig"]: 

307 """Download and load by model id or model path.""" 

308 # Check if model_id is a local directory 

309 expected_files = ["config.yaml", "diffusion_schedule.pt", "transition_model.ckpt"] 

310 if os.path.isdir(model_id): 

311 if all(os.path.exists(os.path.join(model_id, fn)) for fn in expected_files): 

312 return cls.load(model_id, override_config=override_config) 

313 else: 

314 missing_files = [fn for fn in expected_files if not os.path.exists(os.path.join(model_id, fn))] 

315 raise FileNotFoundError(f"InstaNovo+ model directory {model_id} is missing the expected file(s): {', '.join(missing_files)}.") 

316 elif os.path.exists(model_id): 

317 return cls.load(model_id, override_config=override_config) 

318 

319 # Load the models.json file 

320 with resources.files("instanovo").joinpath("models.json").open("r", encoding="utf-8") as f: 

321 models_config = json.load(f) 

322 

323 # Find the model in the config 

324 if MODEL_TYPE not in models_config or model_id not in models_config[MODEL_TYPE]: 

325 raise ValueError(f"Model {model_id} not found in models.json, options are [{', '.join(models_config[MODEL_TYPE].keys())}]") 

326 

327 # Create cache directory if it doesn't exist 

328 cache_dir = Path.home() / ".cache" / "instanovo" 

329 cache_dir.mkdir(parents=True, exist_ok=True) 

330 

331 model_info = models_config[MODEL_TYPE][model_id] 

332 

333 if "remote" in model_info: 

334 url = model_info["remote"] 

335 

336 # Generate a filename for the cached model 

337 file_name = urlsplit(url).path.split("/")[-1] 

338 cached_file = cache_dir / file_name 

339 

340 # Check if the file is already cached 

341 if not cached_file.exists(): 

342 download_file(url, cached_file, model_id, file_name) 

343 

344 else: 

345 logger.info(f"Model {model_id} already cached at {cached_file}") 

346 

347 try: 

348 # Load and return the model 

349 logger.info(f"Loading model {model_id} (remote)") 

350 return cls.load(str(cached_file), override_config=override_config) 

351 except Exception as e: 

352 logger.warning(f"Failed to load cached model {model_id}, it may be corrupted. Deleting and re-downloading. Error: {e}") 

353 if cached_file.exists(): 

354 cached_file.unlink() 

355 

356 download_file(url, cached_file, model_id, file_name) 

357 logger.info(f"Loading newly downloaded model {model_id}") 

358 return cls.load(str(cached_file), override_config=override_config) 

359 

360 elif "local" in model_info: 

361 instanovo_plus_model = model_info["local"] 

362 if os.path.isdir(instanovo_plus_model): 

363 if all(os.path.exists(os.path.join(instanovo_plus_model, fn)) for fn in expected_files): 

364 logger.info(f"Loading model {model_id} (local)") 

365 return cls.load(instanovo_plus_model, override_config=override_config) 

366 else: 

367 missing_files = [fn for fn in expected_files if not os.path.exists(os.path.join(instanovo_plus_model, fn))] 

368 raise FileNotFoundError( 

369 f"InstaNovo+ model directory {instanovo_plus_model} is missing the expected file(s): {', '.join(missing_files)}." 

370 ) 

371 elif os.path.exists(instanovo_plus_model): 

372 return cls.load(instanovo_plus_model, override_config=override_config) 

373 else: 

374 raise ValueError( 

375 f"Local model path '{instanovo_plus_model}' must exist, be a directory and containing the files {', '.join(expected_files)}." 

376 ) 

377 else: 

378 raise ValueError(f"Model {model_id} does not have a valid 'remote', 'local' entry in models.json") 

379 

380 def prepare_fine_tuning(self, residue_set: ResidueSet) -> None: 

381 """Prepare a model for fine-tuning on a dataset with a new residue vocabulary. 

382 

383 Args: 

384 residue_set (ResidueSet): The residue vocabulary for the new dataset. 

385 """ 

386 # 1. Update residue set 

387 self.residue_set = residue_set 

388 

389 num_residues = len(self.residue_set) 

390 model_dim = self.config.dim 

391 

392 # 2. Update config 

393 self.config.vocab_size = num_residues 

394 

395 # 3. Update modules 

396 self.transition_model.char_embedding = nn.Embedding(num_embeddings=num_residues, embedding_dim=model_dim) 

397 self.transition_model.head[1] = nn.Linear(model_dim, num_residues) 

398 

399 def mixture_categorical( 

400 self, 

401 log_x: Float[ResidueLogProbabilities, "batch token"], 

402 log_alpha: float, 

403 log_alpha_complement: float, 

404 ) -> Float[ResidueLogProbabilities, "batch token"]: 

405 """A categorical mixture between a base distribution and a uniform distribution. 

406 

407 Args: 

408 log_x (torch.FloatTensor[..., num_classes]): 

409 The base distribution. 

410 

411 log_alpha (float): 

412 The log of the mixture weight. 

413 

414 log_alpha_complement (float): 

415 The log of 1 minus the mixture weight. 

416 

417 Returns: 

418 torch.FloatTensor[..., num_classes]: 

419 The log-probabilities of the mixture. 

420 """ 

421 return torch.logaddexp( 

422 log_x + log_alpha, 

423 log_alpha_complement - math.log(len(self.residue_set)), 

424 ) 

425 

426 def forward( 

427 self, 

428 log_x_t: Float[ResidueLogProbabilities, "batch token"], 

429 log_x_0: Float[ResidueLogProbabilities, "batch token"], 

430 t: Integer[TimeStep, " batch"], 

431 ) -> Float[ResidueLogProbabilities, "batch token"]: 

432 """Calculate the log-posterior of `t-1`-th process values given the 0-th and t-th values. 

433 

434 Args: 

435 log_x_t (torch.FloatTensor[batch_size, sequence_length, num_classes]): 

436 The log one-hot representation of the process values at the `t`-th time step. 

437 

438 log_x_0 (torch.FloatTensor[batch_size, sequence_length, num_classes]): 

439 The log one-hot representation of the process values at the `t`-th time step. 

440 t (int): 

441 The time step. 

442 

443 Returns: 

444 torch.FloatTensor[batch_size, sequence_length, num_classes]: 

445 The log-posterior probabilities of the process values at the `t-1`-th 

446 time step given the values at the 0-th and `t`-th time step 

447 i.e. q( x_{t-1} | x_{t}, x_0 ). 

448 """ 

449 log_prior = self.mixture_categorical( 

450 log_x=log_x_0, 

451 log_alpha=self.cumulative_schedule[t - 1].unsqueeze(-1).unsqueeze(-1), 

452 log_alpha_complement=self.cumulative_schedule_complement[t - 1].unsqueeze(-1).unsqueeze(-1), 

453 ) 

454 log_likelihood = self.mixture_categorical( 

455 log_x=log_x_t, 

456 log_alpha=self.diffusion_schedule[t].unsqueeze(-1).unsqueeze(-1), 

457 log_alpha_complement=self.diffusion_schedule_complement[t].unsqueeze(-1).unsqueeze(-1), 

458 ) 

459 t_mask = (t == 0).unsqueeze(-1).unsqueeze(-1).expand_as(log_x_0) 

460 prior_term = torch.where(t_mask, log_x_0, log_prior) 

461 logits = log_likelihood + prior_term 

462 return torch.log_softmax(logits, -1) 

463 

464 def reverse_distribution( 

465 self, 

466 x_t: Integer[Peptide, "batch token"], 

467 time: Integer[TimeStep, " batch"], 

468 **kwargs: dict, 

469 ) -> Float[ResidueLogProbabilities, "batch token"]: 

470 """Calculate the reverse transition distribution of the diffusion process. 

471 

472 Args: 

473 x_t (torch.LongTensor[batch_size, sequence_length]): 

474 The values at the `t`-th time step of the reverse process. 

475 

476 time (int): 

477 The time step. 

478 

479 Returns: 

480 torch.FloatTensor[batch_size, sequence_length, num_classes]: 

481 The log-probabilities of values for the `t-1`-th time step given 

482 values at the `t`-th time step i.e. `log p( x_{t-1} | x_{t} )`. 

483 """ 

484 log_x_0 = log_softmax(self.transition_model(x_t, t=time, **kwargs), -1) 

485 return self.forward(log_x_t=torch.log(one_hot(x_t, len(self.residue_set))), log_x_0=log_x_0, t=time) 

486 

487 

class DiffusionLoss(nn.Module):
    """Computes the training loss of the multinomial diffusion model.

    Args:
        model (InstaNovoPlus):
            The multinomial diffusion class. It may be wrapped in an object that
            exposes the real model as `.module`; the unwrapped model is kept in
            `base_model`, while posterior evaluation goes through `model` itself.
    """

    def __init__(self, model: InstaNovoPlus) -> None:
        super().__init__()
        self.model = model

        # Unwrap when the model sits behind a `.module` attribute.
        self.base_model = getattr(model, "module", model)

        self.time_steps = self.base_model.time_steps

    @staticmethod
    def kl_divergence(
        log_probs_first: Float[ResidueLogProbabilities, "..."],
        log_probs_second: Float[ResidueLogProbabilities, "..."],
    ) -> Float[torch.Tensor, "..."]:
        """Calculate the Kullback-Leibler divergence between two multinomial distributions.

        Args:
            log_probs_first (torch.FloatTensor[..., num_classes]):
                The log-probabilities of the base distribution.

            log_probs_second (torch.FloatTensor[..., num_classes]):
                The log-probabilities of the comparison distribution.

        Returns:
            torch.FloatTensor[...]:
                The KL-divergence summed over the final two dimensions
                (the class dimension, then the one before it).
        """
        log_ratio = log_probs_first - log_probs_second
        per_class = torch.exp(log_probs_first) * log_ratio
        per_position = per_class.sum(-1)
        return per_position.sum(-1)

    def forward(self, x_0: Integer[Peptide, "batch token"], **kwargs: dict) -> Float[torch.Tensor, "1"]:
        """Calculate a single Monte Carlo estimate of the multinomial diffusion loss (-ELBO).

        Args:
            x_0 (torch.LongTensor[batch_size, sequence_length]):
                A batch of padded sequences.

            **kwargs:
                Extra keyword arguments forwarded to the model's reverse distribution.

        Returns:
            torch.FloatTensor[1]:
                The loss estimate.
        """
        diffusion = self.base_model
        batch_size = x_0.shape[0]

        # 1. Sample one time step per sequence.
        # NOTE(review): the upper bound of `torch.randint` is exclusive, so the sampled
        # steps lie in [0, time_steps - 2] - confirm that skipping the last step is intended.
        t = torch.randint(0, self.time_steps - 1, (batch_size,)).to(x_0.device)

        # 2. Per-step term, averaged over the batch.
        step_loss = self._compute_loss(t=t, x_0=x_0, **kwargs).mean()

        # 3. Prior KL term: fully diffused x_0 against the uniform distribution.
        vocab_size = len(diffusion.residue_set)
        last_step = self.time_steps - 1
        log_x_0 = torch.log(one_hot(x_0, num_classes=vocab_size))
        final_log_probs = diffusion.mixture_categorical(
            log_x=log_x_0,
            log_alpha=diffusion.cumulative_schedule[last_step].unsqueeze(-1).unsqueeze(-1),
            log_alpha_complement=diffusion.cumulative_schedule_complement[last_step].unsqueeze(-1).unsqueeze(-1),
        )
        uniform_log_probs = torch.log(torch.ones_like(final_log_probs) / vocab_size)
        prior_kl = self.kl_divergence(final_log_probs, uniform_log_probs).mean()
        return step_loss + prior_kl

    def _compute_loss(
        self,
        x_0: Integer[Peptide, "batch token"],
        t: Integer[TimeStep, " batch"],
        **kwargs: dict,
    ) -> Float[torch.Tensor, " batch"]:
        """Compute the per-sequence loss term for the sampled time steps `t`."""
        diffusion = self.base_model
        vocab_size = len(diffusion.residue_set)

        # 1. sample x_{t+1} by diffusing x_0 with the cumulative schedule at `t`
        target_one_hot = one_hot(x_0, num_classes=vocab_size)
        log_x_0 = torch.log(target_one_hot)
        log_probs = diffusion.mixture_categorical(
            log_x=log_x_0,
            log_alpha=diffusion.cumulative_schedule[t].unsqueeze(-1).unsqueeze(-1),
            log_alpha_complement=diffusion.cumulative_schedule_complement[t].unsqueeze(-1).unsqueeze(-1),
        )
        x_next = Categorical(logits=log_probs).sample()

        # 2. Calculate loss
        log_dist = diffusion.reverse_distribution(x_t=x_next, time=t, **kwargs)

        # Reconstruction term: negative log-likelihood of x_0 under the model.
        nll_loss = -(target_one_hot * log_dist).sum(-1).sum(-1)

        # Denoising term: KL between the true posterior and the model's distribution.
        log_x_next = torch.log(one_hot(x_next, log_probs.size(-1)))
        log_posterior = self.model(log_x_0=log_x_0, log_x_t=log_x_next, t=t)
        denoising_loss = self.kl_divergence(log_posterior, log_dist)

        # Sequences at step 0 use the reconstruction term, all others the KL term.
        return torch.where(t == 0, nll_loss, denoising_loss)