from __future__ import annotations

import json
import math
import os
import shutil
from importlib import resources
from pathlib import Path
from typing import Tuple
from urllib.parse import urlsplit

import torch
from jaxtyping import Float, Integer
from omegaconf import DictConfig, OmegaConf, open_dict
from torch import nn
from torch.distributions import Categorical
from torch.nn.functional import log_softmax, one_hot

from instanovo.__init__ import console
from instanovo.diffusion.model import MassSpectrumTransFusion
from instanovo.types import Peptide, ResidueLogProbabilities, TimeStep
from instanovo.utils.colorlogging import ColorLog
from instanovo.utils.file_downloader import download_file
from instanovo.utils.residues import ResidueSet
from instanovo.utils.s3 import S3FileHandler

MODEL_TYPE = "diffusion"

logger = ColorLog(console, __name__).logger


def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> Float[torch.Tensor, " time"]:
    """Cosine schedule as proposed in https://arxiv.org/abs/2102.09672.

    Returns the square roots of the alpha parameters, NOT beta.
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    alphas = alphas_cumprod[1:] / alphas_cumprod[:-1]
    alphas = torch.clamp(alphas, 0.001, 1.0)
    return torch.sqrt(alphas)
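

# A hedged sanity-check sketch (illustrative, not part of the library): the
# returned values lie in (0, 1], and their squared cumulative product decays,
# so later time steps carry more noise.
#
#     schedule = cosine_beta_schedule(timesteps=20)
#     assert schedule.shape == (20,)
#     assert ((schedule > 0) & (schedule <= 1)).all()
#     cumulative = torch.cumprod(schedule**2, dim=-1)
#     assert (cumulative[1:] <= cumulative[:-1]).all()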


class InstaNovoPlus(nn.Module):
    r"""This class implements Multinomial Diffusion as described in Hoogeboom et al. 2021.

    Args:
        config (omegaconf.DictConfig):
            The model configuration. This should have keys:
                - 'name': the model name identifier.
                - 'time_steps': the number of time steps in the diffusion process.
                - 'max_length': the maximum sequence length for the model.
                - 'device': the device where the `PyTorch` model should be
                  loaded e.g. `cpu`, `cuda:0` etc.
                - 'vocab_size': the number of residues in the vocabulary.
                - 'transition_model': the `DictConfig` for the transition model.

            This information is necessary for saving and loading the model.

        transition_model (nn.Module):
            The model that predicts the initial sequence given the sequence
            sampled at the current time step and the sequence sampled at the
            previous time step. This is just a sequence tagging model.

        diffusion_schedule (torch.FloatTensor[time_steps]):
            The sequence of diffusion probabilities. Note
            that `diffusion_schedule[t]` is \alpha_t in
            the paper's terminology, not \beta_t.

        residue_set (ResidueSet):
            The residue vocabulary. This holds a mapping between
            residues and indices and residue masses.
    """

    config_path: str
    schedule_path: str
    checkpoint_path: str

    def __init__(
        self,
        config: DictConfig,
        transition_model: nn.Module,
        diffusion_schedule: Float[torch.Tensor, " time"],
        residue_set: ResidueSet,
    ) -> None:
        super().__init__()
        self.config = config
        self.time_steps = config.time_steps
        self.residue_set = residue_set
        self.transition_model = transition_model
        self.register_buffer("diffusion_schedule", torch.log(diffusion_schedule))
        self.register_buffer("diffusion_schedule_complement", torch.log(1 - diffusion_schedule))
        self.register_buffer("cumulative_schedule", torch.cumsum(self.diffusion_schedule, -1))
        self.register_buffer(
            "cumulative_schedule_complement",
            torch.log(1 - torch.exp(self.cumulative_schedule)),
        )
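
    # The buffers store the schedule entirely in log space. A hedged sketch of
    # the identities they satisfy, writing alpha_t for the schedule value at
    # step t (the paper's notation):
    #
    #     diffusion_schedule[t]             == log(alpha_t)
    #     diffusion_schedule_complement[t]  == log(1 - alpha_t)
    #     cumulative_schedule[t]            == log(prod_{s<=t} alpha_s)
    #     cumulative_schedule_complement[t] == log(1 - prod_{s<=t} alpha_s)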

    def save(
        self,
        path: str,
        ckpt_details: str,
        overwrite: bool = False,
        temp_dir: str | None = None,
        use_legacy_format: bool = False,
    ) -> None:
        """Save the model to a directory.

        Args:
            path (str):
                Path to the base directory where the model is saved.
                The model is saved in a subdirectory with the model's
                name identifier.

            ckpt_details (str):
                Additional checkpoint details to include in the model save directory.

            overwrite (bool, optional):
                Whether to overwrite the directory if one already exists
                for the model. Defaults to False.

            temp_dir (str | None, optional):
                Temporary directory to save intermediate files to.
                Defaults to None.

            use_legacy_format (bool, optional):
                Whether to save the model in the legacy folder format.
                If False, saves as a single file. Defaults to False.

        Raises:
            FileExistsError: If `overwrite` is `False` and a directory already exists
                for the model identifier.
        """
        model_dir = os.path.join(path, ckpt_details)

        def save_file_local(filename: str, content: str) -> None:
            """Save a file locally (no upload)."""
            return

        def save_file_s3(filename: str, content: str) -> None:
            """Upload a file to S3."""
            # TODO: fix this
            s3 = S3FileHandler()
            return s3.upload(  # type: ignore
                content, s3.convert_to_s3_output(model_dir + "/" + filename)
            )

        if temp_dir is None:
            if os.path.exists(model_dir) and os.path.isdir(model_dir):
                if overwrite:
                    shutil.rmtree(model_dir)
                else:
                    raise FileExistsError

            if use_legacy_format:
                os.makedirs(model_dir, exist_ok=True)
            elif os.path.dirname(model_dir):
                os.makedirs(os.path.dirname(model_dir), exist_ok=True)

            save_path = model_dir
            save_file = save_file_local
        else:
            save_path = temp_dir
            save_file = save_file_s3

        if use_legacy_format:
            # Save model as a folder
            # Save config
            config_path = os.path.join(save_path, "config.yaml")
            OmegaConf.save(config=self.config, f=config_path)
            save_file("config.yaml", config_path)

            # Save schedule
            diff_schedule_path = os.path.join(save_path, "diffusion_schedule.pt")
            torch.save(torch.exp(self.diffusion_schedule), diff_schedule_path)
            save_file("diffusion_schedule.pt", diff_schedule_path)

            # Save transition model
            self.transition_model.to("cpu")
            transition_model_path = os.path.join(save_path, "transition_model.ckpt")
            torch.save(self.transition_model.state_dict(), transition_model_path)
            save_file("transition_model.ckpt", transition_model_path)
        else:
            # Save model as a single file
            transition_model_state = {
                k: v.cpu() for k, v in self.transition_model.state_dict().items()
            }

            model_data = {
                "config": OmegaConf.to_container(self.config),
                "diffusion_schedule": torch.exp(self.diffusion_schedule).tolist(),
                "transition_model": transition_model_state,
            }

            if temp_dir:
                save_path = os.path.join(save_path, "instanovo_plus.ckpt")
            torch.save(model_data, save_path)
            save_file("instanovo_plus.ckpt", save_path)
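
    # A hedged usage sketch (paths and details are illustrative):
    #
    #     model.save("checkpoints", "run-01")  # single-file format
    #     model.save("checkpoints", "run-01", overwrite=True, use_legacy_format=True)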

    @classmethod
    def load(
        cls, path: str, override_config: DictConfig | dict | None = None
    ) -> Tuple[InstaNovoPlus, DictConfig]:
        """Load a saved model.

        Args:
            path (str):
                Path to the model checkpoint file or directory where the model is saved.
            override_config (DictConfig | dict | None):
                Optional config values to override, as a DictConfig or dict.
                Defaults to None.

        Returns:
            (InstaNovoPlus, DictConfig): The loaded model and config.
        """
        is_legacy_format = False
        if os.path.isdir(path):
            # Load config
            cls.config_path = os.path.join(path, "config.yaml")
            config = OmegaConf.load(cls.config_path)
            if override_config is not None:
                with open_dict(config):
                    config.update(override_config)

            cls.schedule_path = os.path.join(path, "diffusion_schedule.pt")
            diffusion_schedule = torch.load(
                cls.schedule_path, map_location=torch.device("cpu"), weights_only=True
            )

            cls.checkpoint_path = os.path.join(path, "transition_model.ckpt")
            transition_model_state = torch.load(
                cls.checkpoint_path, map_location=torch.device("cpu"), weights_only=True
            )

            is_legacy_format = True
        else:
            # Load model from a checkpoint file
            try:
                model_data = torch.load(path, map_location=torch.device("cpu"), weights_only=False)
            except Exception as e:
                raise ValueError(f"Failed to load model from {path}: {str(e)}") from e

            config = OmegaConf.create(model_data["config"])
            if override_config is not None:
                with open_dict(config):
                    config.update(override_config)

            diffusion_schedule = torch.tensor(model_data["diffusion_schedule"])
            if "transition_model" in model_data:
                transition_model_state = model_data["transition_model"]
                is_legacy_format = True
            else:
                transition_model_state = model_data["state_dict"]

        if is_legacy_format:
            # Load residues
            residue_set = ResidueSet(
                residue_masses=config["residues"],
                residue_remapping=config["residue_remapping"],
            )

            # Load transition model
            transition_model = MassSpectrumTransFusion(
                config,
                config.max_length,
            )
            transition_model.load_state_dict(transition_model_state)

            return cls(
                config=config,
                transition_model=transition_model,
                diffusion_schedule=diffusion_schedule,
                residue_set=residue_set,
            ), config
        else:
            residues = model_data["residues"]
            residue_set = ResidueSet(
                residue_masses=residues,
            )
            transition_model = MassSpectrumTransFusion(
                config,
                config.max_length,
            )

            model = cls(
                config=config,
                transition_model=transition_model,
                diffusion_schedule=diffusion_schedule,
                residue_set=residue_set,
            )
            model.load_state_dict(model_data["state_dict"])

            return model, config
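
    # A hedged usage sketch (paths are illustrative): `load` accepts either a
    # legacy model folder or a single-file checkpoint, with optional config
    # overrides.
    #
    #     model, config = InstaNovoPlus.load("checkpoints/run-01/instanovo_plus.ckpt")
    #     model, config = InstaNovoPlus.load(
    #         "checkpoints/run-01", override_config={"device": "cpu"}
    #     )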

    @staticmethod
    def get_pretrained() -> list[str]:
        """Get a list of pretrained model ids."""
        # Load the models.json file
        with resources.files("instanovo").joinpath("models.json").open("r", encoding="utf-8") as f:
            models_config = json.load(f)

        if MODEL_TYPE not in models_config:
            return []

        return list(models_config[MODEL_TYPE].keys())

    @classmethod
    def from_pretrained(
        cls, model_id: str, override_config: DictConfig | dict | None = None
    ) -> Tuple["InstaNovoPlus", "DictConfig"]:
        """Download and load a model by model id or model path."""
        # Check if model_id is a local directory
        expected_files = ["config.yaml", "diffusion_schedule.pt", "transition_model.ckpt"]
        if os.path.isdir(model_id):
            if all(os.path.exists(os.path.join(model_id, fn)) for fn in expected_files):
                return cls.load(model_id, override_config=override_config)
            else:
                missing_files = [
                    fn for fn in expected_files if not os.path.exists(os.path.join(model_id, fn))
                ]
                raise FileNotFoundError(
                    f"InstaNovo+ model directory {model_id} is missing the "
                    f"expected file(s): {', '.join(missing_files)}."
                )
        elif os.path.exists(model_id):
            return cls.load(model_id, override_config=override_config)

        # Load the models.json file
        with resources.files("instanovo").joinpath("models.json").open("r", encoding="utf-8") as f:
            models_config = json.load(f)

        # Find the model in the config
        if MODEL_TYPE not in models_config or model_id not in models_config[MODEL_TYPE]:
            raise ValueError(
                f"Model {model_id} not found in models.json, options are "
                f"[{', '.join(models_config.get(MODEL_TYPE, {}).keys())}]"
            )

        # Create cache directory if it doesn't exist
        cache_dir = Path.home() / ".cache" / "instanovo"
        cache_dir.mkdir(parents=True, exist_ok=True)

        model_info = models_config[MODEL_TYPE][model_id]

        if "remote" in model_info:
            url = model_info["remote"]

            # Generate a filename for the cached model
            file_name = urlsplit(url).path.split("/")[-1]
            cached_file = cache_dir / file_name

            # Check if the file is already cached
            if not cached_file.exists():
                download_file(url, cached_file, model_id, file_name)
            else:
                logger.info(f"Model {model_id} already cached at {cached_file}")

            try:
                # Load and return the model
                logger.info(f"Loading model {model_id} (remote)")
                return cls.load(str(cached_file), override_config=override_config)
            except Exception as e:
                logger.warning(
                    f"Failed to load cached model {model_id}, it may be corrupted. "
                    f"Deleting and re-downloading. Error: {e}"
                )
                if cached_file.exists():
                    cached_file.unlink()

                download_file(url, cached_file, model_id, file_name)
                logger.info(f"Loading newly downloaded model {model_id}")
                return cls.load(str(cached_file), override_config=override_config)
360 elif "local" in model_info:
361 instanovo_plus_model = model_info["local"]
362 if os.path.isdir(instanovo_plus_model):
363 if all(os.path.exists(os.path.join(instanovo_plus_model, fn)) for fn in expected_files):
364 logger.info(f"Loading model {model_id} (local)")
365 return cls.load(instanovo_plus_model, override_config=override_config)
366 else:
367 missing_files = [fn for fn in expected_files if not os.path.exists(os.path.join(instanovo_plus_model, fn))]
368 raise FileNotFoundError(
369 f"InstaNovo+ model directory {instanovo_plus_model} is missing the expected file(s): {', '.join(missing_files)}."
370 )
371 elif os.path.exists(instanovo_plus_model):
372 return cls.load(instanovo_plus_model, override_config=override_config)
373 else:
374 raise ValueError(
375 f"Local model path '{instanovo_plus_model}' must exist, be a directory and containing the files {', '.join(expected_files)}."
376 )
377 else:
378 raise ValueError(f"Model {model_id} does not have a valid 'remote', 'local' entry in models.json")
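
    # A hedged usage sketch: resolve a pretrained id against models.json, or
    # pass a local path directly. The id below is illustrative, not a
    # guaranteed entry; query `get_pretrained()` for the real options.
    #
    #     print(InstaNovoPlus.get_pretrained())
    #     model, config = InstaNovoPlus.from_pretrained("instanovoplus-v1.1")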

    def prepare_fine_tuning(self, residue_set: ResidueSet) -> None:
        """Prepare a model for fine-tuning on a dataset with a new residue vocabulary.

        Args:
            residue_set (ResidueSet): The residue vocabulary for the new dataset.
        """
        # 1. Update residue set
        self.residue_set = residue_set

        num_residues = len(self.residue_set)
        model_dim = self.config.dim

        # 2. Update config
        self.config.vocab_size = num_residues

        # 3. Update modules
        self.transition_model.char_embedding = nn.Embedding(
            num_embeddings=num_residues, embedding_dim=model_dim
        )
        self.transition_model.head[1] = nn.Linear(model_dim, num_residues)
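
    # A hedged usage sketch (the residue masses below are illustrative):
    #
    #     new_residues = ResidueSet(residue_masses={"G": 57.02146, "A": 71.03711})
    #     model.prepare_fine_tuning(new_residues)
    #     assert model.config.vocab_size == len(new_residues)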

    def mixture_categorical(
        self,
        log_x: Float[ResidueLogProbabilities, "batch token"],
        log_alpha: float,
        log_alpha_complement: float,
    ) -> Float[ResidueLogProbabilities, "batch token"]:
        """A categorical mixture between a base distribution and a uniform distribution.

        Args:
            log_x (torch.FloatTensor[..., num_classes]):
                The base distribution.

            log_alpha (float):
                The log of the mixture weight.

            log_alpha_complement (float):
                The log of 1 minus the mixture weight.

        Returns:
            torch.FloatTensor[..., num_classes]:
                The log-probabilities of the mixture.
        """
        return torch.logaddexp(
            log_x + log_alpha,
            log_alpha_complement - math.log(len(self.residue_set)),
        )
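
    # A minimal sketch of the mixture property, assuming `model` is an
    # InstaNovoPlus instance with vocabulary size V: mixing a normalized
    # distribution with the uniform distribution stays normalized.
    #
    #     V = len(model.residue_set)
    #     log_x = torch.log_softmax(torch.randn(2, 5, V), dim=-1)
    #     out = model.mixture_categorical(
    #         log_x, log_alpha=math.log(0.9), log_alpha_complement=math.log(0.1)
    #     )
    #     assert torch.allclose(out.exp().sum(-1), torch.ones(2, 5), atol=1e-5)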

    def forward(
        self,
        log_x_t: Float[ResidueLogProbabilities, "batch token"],
        log_x_0: Float[ResidueLogProbabilities, "batch token"],
        t: Integer[TimeStep, " batch"],
    ) -> Float[ResidueLogProbabilities, "batch token"]:
        """Calculate the log-posterior of the `t-1`-th process values given the 0-th and `t`-th values.

        Args:
            log_x_t (torch.FloatTensor[batch_size, sequence_length, num_classes]):
                The log one-hot representation of the process values at the `t`-th time step.

            log_x_0 (torch.FloatTensor[batch_size, sequence_length, num_classes]):
                The log one-hot representation of the process values at the 0-th time step.

            t (int):
                The time step.

        Returns:
            torch.FloatTensor[batch_size, sequence_length, num_classes]:
                The log-posterior probabilities of the process values at the `t-1`-th
                time step given the values at the 0-th and `t`-th time steps,
                i.e. q( x_{t-1} | x_{t}, x_0 ).
        """
        log_prior = self.mixture_categorical(
            log_x=log_x_0,
            log_alpha=self.cumulative_schedule[t - 1].unsqueeze(-1).unsqueeze(-1),
            log_alpha_complement=self.cumulative_schedule_complement[t - 1].unsqueeze(-1).unsqueeze(-1),
        )
        log_likelihood = self.mixture_categorical(
            log_x=log_x_t,
            log_alpha=self.diffusion_schedule[t].unsqueeze(-1).unsqueeze(-1),
            log_alpha_complement=self.diffusion_schedule_complement[t].unsqueeze(-1).unsqueeze(-1),
        )
        t_mask = (t == 0).unsqueeze(-1).unsqueeze(-1).expand_as(log_x_0)
        prior_term = torch.where(t_mask, log_x_0, log_prior)
        logits = log_likelihood + prior_term
        return torch.log_softmax(logits, -1)
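
    # A hedged sketch of querying the posterior directly (shapes illustrative;
    # V = len(model.residue_set)). The output normalizes over the vocabulary.
    #
    #     x_0 = torch.randint(0, V, (batch, length))
    #     x_t = torch.randint(0, V, (batch, length))
    #     t = torch.randint(0, model.time_steps, (batch,))
    #     log_post = model(torch.log(one_hot(x_t, V)), torch.log(one_hot(x_0, V)), t)
    #     assert torch.allclose(log_post.exp().sum(-1), torch.ones(batch, length))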

    def reverse_distribution(
        self,
        x_t: Integer[Peptide, "batch token"],
        time: Integer[TimeStep, " batch"],
        **kwargs: dict,
    ) -> Float[ResidueLogProbabilities, "batch token"]:
        """Calculate the reverse transition distribution of the diffusion process.

        Args:
            x_t (torch.LongTensor[batch_size, sequence_length]):
                The values at the `t`-th time step of the reverse process.

            time (int):
                The time step.

        Returns:
            torch.FloatTensor[batch_size, sequence_length, num_classes]:
                The log-probabilities of values for the `t-1`-th time step given
                values at the `t`-th time step i.e. `log p( x_{t-1} | x_{t} )`.
        """
        log_x_0 = log_softmax(self.transition_model(x_t, t=time, **kwargs), -1)
        return self.forward(
            log_x_t=torch.log(one_hot(x_t, len(self.residue_set))), log_x_0=log_x_0, t=time
        )


class DiffusionLoss(nn.Module):
    """Holds logic for calculating the diffusion loss.

    Args:
        model (InstaNovoPlus):
            The multinomial diffusion model.
    """

    def __init__(self, model: InstaNovoPlus) -> None:
        super().__init__()
        self.model = model

        self.base_model = model.module if hasattr(model, "module") else model

        self.time_steps = self.base_model.time_steps

    @staticmethod
    def kl_divergence(
        log_probs_first: Float[ResidueLogProbabilities, "..."],
        log_probs_second: Float[ResidueLogProbabilities, "..."],
    ) -> Float[torch.Tensor, "..."]:
        """Calculate the Kullback-Leibler divergence between two multinomial distributions.

        Args:
            log_probs_first (torch.FloatTensor[..., num_classes]):
                The log-probabilities of the base distribution.

            log_probs_second (torch.FloatTensor[..., num_classes]):
                The log-probabilities of the comparison distribution.

        Returns:
            torch.FloatTensor[...]:
                The KL-divergence summed over the final two dimensions.
        """
        return (torch.exp(log_probs_first) * (log_probs_first - log_probs_second)).sum(-1).sum(-1)
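
    # A minimal sketch: the KL divergence of a distribution with itself is
    # zero, and the result is summed over the final two dimensions.
    #
    #     log_p = torch.log_softmax(torch.randn(2, 5, 8), dim=-1)
    #     kl = DiffusionLoss.kl_divergence(log_p, log_p)
    #     assert torch.allclose(kl, torch.zeros(2), atol=1e-6)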

    def forward(self, x_0: Integer[Peptide, "batch token"], **kwargs: dict) -> Float[torch.Tensor, "1"]:
        """Calculate a single Monte Carlo estimate of the multinomial diffusion loss (-ELBO).

        Args:
            x_0 (torch.LongTensor[batch_size, sequence_length]):
                A batch of padded sequences.

        Returns:
            torch.FloatTensor[1]:
                The loss estimate.
        """
        # 1. Sample time step
        t = torch.randint(0, self.time_steps - 1, (x_0.shape[0],)).to(x_0.device)

        # 2. Compute L_t
        loss = self._compute_loss(t=t, x_0=x_0, **kwargs).mean()

        # 3. Calculate prior KL term
        log_x_0 = torch.log(one_hot(x_0, num_classes=len(self.base_model.residue_set)))
        final_log_probs = self.base_model.mixture_categorical(
            log_x=log_x_0,
            log_alpha=self.base_model.cumulative_schedule[self.time_steps - 1]
            .unsqueeze(-1)
            .unsqueeze(-1),
            log_alpha_complement=self.base_model.cumulative_schedule_complement[self.time_steps - 1]
            .unsqueeze(-1)
            .unsqueeze(-1),
        )
        uniform_log_probs = torch.log(
            torch.ones_like(final_log_probs) / len(self.base_model.residue_set)
        )
        kl_loss = self.kl_divergence(final_log_probs, uniform_log_probs).mean()
        return loss + kl_loss

    def _compute_loss(
        self,
        x_0: Integer[Peptide, "batch token"],
        t: Integer[TimeStep, " batch"],
        **kwargs: dict,
    ) -> Float[torch.Tensor, " batch"]:
        # 1. Sample x_{t+1}
        log_x_0 = torch.log(one_hot(x_0, num_classes=len(self.base_model.residue_set)))
        log_probs = self.base_model.mixture_categorical(
            log_x=log_x_0,
            log_alpha=self.base_model.cumulative_schedule[t].unsqueeze(-1).unsqueeze(-1),
            log_alpha_complement=self.base_model.cumulative_schedule_complement[t]
            .unsqueeze(-1)
            .unsqueeze(-1),
        )
        x_next = Categorical(logits=log_probs).sample()

        # 2. Calculate loss
        log_dist = self.base_model.reverse_distribution(x_t=x_next, time=t, **kwargs)

        nll_loss = (
            -(one_hot(x_0, num_classes=len(self.base_model.residue_set)) * log_dist)
            .sum(-1)
            .sum(-1)
        )

        log_posterior = self.model(
            log_x_0=log_x_0, log_x_t=torch.log(one_hot(x_next, log_probs.size(-1))), t=t
        )
        denoising_loss = self.kl_divergence(log_posterior, log_dist)
        loss = torch.where(t == 0, nll_loss, denoising_loss)
        return loss
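

# A hedged end-to-end training-step sketch. All names below are illustrative;
# in particular, the conditioning kwargs forwarded to the transition model
# depend on MassSpectrumTransFusion's signature.
#
#     model, config = InstaNovoPlus.load("checkpoints/instanovo_plus.ckpt")
#     loss_fn = DiffusionLoss(model)
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#
#     loss = loss_fn(peptides, **conditioning_kwargs)  # peptides: LongTensor[batch, length]
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()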