Coverage for instanovo/diffusion/multinomial_diffusion.py: 50%
230 statements
« prev ^ index » next coverage.py v7.11.0, created at 2026-06-08 23:00 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2026-06-08 23:00 +0000
1from __future__ import annotations
3import json
4import math
5import os
6import shutil
7from importlib import resources
8from pathlib import Path
9from typing import Tuple
10from urllib.parse import urlsplit
12import torch
13from jaxtyping import Float, Integer
14from omegaconf import DictConfig, OmegaConf, open_dict
15from torch import nn
16from torch.distributions import Categorical
17from torch.nn.functional import log_softmax, one_hot
19from instanovo.__init__ import console
20from instanovo.diffusion.model import MassSpectrumTransFusion
21from instanovo.types import Peptide, ResidueLogProbabilities, TimeStep
22from instanovo.utils.colorlogging import ColorLog
23from instanovo.utils.file_downloader import download_file
24from instanovo.utils.residues import ResidueSet
25from instanovo.utils.s3 import S3FileHandler
# Key of the packaged ``models.json`` under which this module looks up its
# pretrained checkpoints (see ``get_pretrained`` / ``from_pretrained``).
MODEL_TYPE = "diffusion"

# Module logger obtained from ColorLog, built from the package console and this module's name.
logger = ColorLog(console, __name__).logger
def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> Float[torch.Tensor, " time"]:
    """Build the cosine noise schedule proposed in https://arxiv.org/abs/2102.09672 .

    Despite the function name, the result is expressed in terms of alpha, NOT
    beta: the step-to-step ratios of the normalised cumulative cosine curve are
    clamped to ``[0.001, 1.0]`` and their square root is returned.

    Args:
        timesteps (int):
            Number of diffusion steps; this is also the length of the result.

        s (float, optional):
            Offset added to the normalised time before the cosine is taken.
            Defaults to 0.008.

    Returns:
        torch.FloatTensor[timesteps]:
            Square roots of the clamped per-step alpha ratios.
    """
    # One more grid point than steps, so that consecutive ratios yield `timesteps` values.
    grid = torch.linspace(0, timesteps, timesteps + 1)
    cumulative = torch.cos(((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    # Normalise so the curve starts at exactly 1.
    cumulative = cumulative / cumulative[0]
    step_ratio = torch.clamp(cumulative[1:] / cumulative[:-1], 0.001, 1.0)
    return torch.sqrt(step_ratio)
46class InstaNovoPlus(nn.Module):
47 r"""This class implements Multinomial Diffusion as described in Hoogeboom et al. 2021.
49 Args:
50 config (omegaconf.DictConfig):
51 The model configuration. This should have keys:
52 - 'name': the model name identifier.
53 - 'time_steps': the number of time steps in the diffusion process
54 - 'max_length': the maximum sequence length for the model
55 - 'device': the device where the `Pytorch` model should be
56 loaded e.g. `cpu`, `cuda:0` etc.
57 - 'vocab_size': the number of residues in the vocabulary
58 - 'transition_model': the `DictConfig` for the transition model
60 This information is necessary for saving and loading the model.
62 transition_model (nn.Module):
63 The model that predicts the initial sequence given
64 the sequence sampled at the current time step and the
65 sequence sampled at the previous time step. This is
66 just a sequence tagging model.
68 diffusion_schedule (torch.FloatTensor[time_steps]):
69 The sequence of diffusion probabilities. Note
70 that `diffusion_schedule[t]` is \alpha_t in
71 the paper's terminology, not \beta_t.
73 residue_set (ResidueSet):
74 The residue vocabulary. This holds a mapping between
75 residues and indices and residue masses.
76 """
78 config_path: str
79 schedule_path: str
80 checkpoint_path: str
82 def __init__(
83 self,
84 config: DictConfig,
85 transition_model: nn.Module,
86 diffusion_schedule: Float[torch.Tensor, " time"],
87 residue_set: ResidueSet,
88 ) -> None:
89 super().__init__()
90 self.config = config
91 self.time_steps = config.time_steps
92 self.residue_set = residue_set
93 self.transition_model = transition_model
94 self.register_buffer("diffusion_schedule", torch.log(diffusion_schedule))
95 self.register_buffer("diffusion_schedule_complement", torch.log(1 - diffusion_schedule))
96 self.register_buffer("cumulative_schedule", torch.cumsum(self.diffusion_schedule, -1))
97 self.register_buffer(
98 "cumulative_schedule_complement",
99 torch.log(1 - torch.exp(self.cumulative_schedule)),
100 )
def save(
    self,
    path: str,
    ckpt_details: str,
    overwrite: bool = False,
    temp_dir: str | None = None,
    use_legacy_format: bool = False,
) -> None:
    """Save the model as a single checkpoint file or as a legacy directory.

    The target location is ``os.path.join(path, ckpt_details)``. In the default
    single-file format that location is the checkpoint file itself; in the
    legacy format it is a directory holding ``config.yaml``,
    ``diffusion_schedule.pt`` and ``transition_model.ckpt``.

    When ``temp_dir`` is given, files are written into ``temp_dir`` and then
    handed to ``S3FileHandler.upload`` with a destination derived from the
    target location; nothing is created at the target location locally.

    Saving does not move the live model: transition-model weights are copied
    to CPU for serialisation only, in both formats.

    Args:
        path (str):
            Path to the base directory where the model is saved.

        ckpt_details (str):
            Checkpoint name appended to ``path`` to form the target location.

        overwrite (bool, optional):
            Whether to remove an existing directory at the target location.
            Only consulted when ``temp_dir`` is None. Defaults to False.

        temp_dir (str | None, optional):
            Temporary directory to save intermediate files to before upload.
            Defaults to None (save locally, no upload).

        use_legacy_format (bool, optional):
            Whether to save the model in the legacy folder format.
            If False, saves as a single file. Defaults to False.

    Raises:
        FileExistsError: If ``temp_dir`` is None, ``overwrite`` is False and a
            directory already exists at the target location. An existing
            regular file at that location is not checked.
    """
    model_dir = os.path.join(path, ckpt_details)
    # A single predicate decides local-vs-upload mode everywhere below, so an
    # empty-string ``temp_dir`` is treated consistently.
    upload_to_s3 = temp_dir is not None

    def save_file_local(filename: str, content: str) -> None:
        """No-op: in local mode the file has already been written in place."""
        return

    def save_file_s3(filename: str, content: str) -> None:
        """Upload the file at local path ``content`` to S3 under ``model_dir``."""
        # TODO: fix this
        s3 = S3FileHandler()
        return s3.upload(  # type: ignore
            content, s3.convert_to_s3_output(model_dir + "/" + filename)
        )

    if upload_to_s3:
        save_path = temp_dir
        save_file = save_file_s3
    else:
        if os.path.isdir(model_dir):
            if not overwrite:
                raise FileExistsError(f"Model directory '{model_dir}' already exists and `overwrite` is False.")
            shutil.rmtree(model_dir)

        if use_legacy_format:
            os.makedirs(model_dir, exist_ok=True)
        elif os.path.dirname(model_dir):
            # Single-file mode: only the parent directory has to exist.
            os.makedirs(os.path.dirname(model_dir), exist_ok=True)

        save_path = model_dir
        save_file = save_file_local

    if use_legacy_format:
        # Save model as a folder
        # Save config
        config_path = os.path.join(save_path, "config.yaml")
        OmegaConf.save(config=self.config, f=config_path)
        save_file("config.yaml", config_path)

        # Save schedule (stored in probability space; the buffer holds logs)
        diff_schedule_path = os.path.join(save_path, "diffusion_schedule.pt")
        torch.save(torch.exp(self.diffusion_schedule), diff_schedule_path)
        save_file("diffusion_schedule.pt", diff_schedule_path)

        # Save transition model. The values of the dictionary returned by
        # ``state_dict()`` are replaced with CPU copies instead of calling
        # ``self.transition_model.to("cpu")``, which would move the live
        # model off its device as a side effect of saving.
        transition_model_state = self.transition_model.state_dict()
        for key in list(transition_model_state):
            transition_model_state[key] = transition_model_state[key].cpu()
        transition_model_path = os.path.join(save_path, "transition_model.ckpt")
        torch.save(transition_model_state, transition_model_path)
        save_file("transition_model.ckpt", transition_model_path)
    else:
        # Save model as a single file
        transition_model_state = {k: v.cpu() for k, v in self.transition_model.state_dict().items()}

        model_data = {
            "config": OmegaConf.to_container(self.config),
            "diffusion_schedule": torch.exp(self.diffusion_schedule).tolist(),
            "transition_model": transition_model_state,
        }

        if upload_to_s3:
            # In upload mode ``save_path`` is a directory; name the file inside it.
            save_path = os.path.join(save_path, "instanovo_plus.ckpt")
        torch.save(model_data, save_path)
        save_file("instanovo_plus.ckpt", save_path)
@classmethod
def load(cls, path: str, override_config: DictConfig | dict | None = None) -> Tuple[InstaNovoPlus, DictConfig]:
    """Load a saved model from a legacy directory or a single checkpoint file.

    If ``path`` is a directory, ``config.yaml``, ``diffusion_schedule.pt`` and
    ``transition_model.ckpt`` are read from it and their locations are recorded
    on the class as ``config_path``, ``schedule_path`` and ``checkpoint_path``.
    Otherwise ``path`` is read with ``torch.load(..., weights_only=True)`` and
    must provide ``"config"`` and ``"diffusion_schedule"`` plus either

    - ``"transition_model"``: weights of the transition model only; residues are
      then taken from ``config["residues"]`` / ``config["residue_remapping"]``, or
    - ``"state_dict"`` and ``"residues"``: a state dict for the whole model and
      the residue masses.

    Args:
        path (str):
            Path to model checkpoint file or directory where model is saved.
        override_config (DictConfig | dict | None):
            Optional values merged into the loaded config, defaults to None.

    Returns:
        (InstaNovoPlus, DictConfig): The loaded model and config.

    Raises:
        ValueError: If the single-file checkpoint cannot be read.
    """

    def with_overrides(base: DictConfig) -> DictConfig:
        """Merge ``override_config`` into ``base`` in place, if one was given."""
        if override_config is not None:
            with open_dict(base):
                base.update(override_config)
        return base

    cpu = torch.device("cpu")

    if os.path.isdir(path):
        # Legacy folder layout: three separate files.
        cls.config_path = os.path.join(path, "config.yaml")
        config = with_overrides(OmegaConf.load(cls.config_path))

        cls.schedule_path = os.path.join(path, "diffusion_schedule.pt")
        diffusion_schedule = torch.load(cls.schedule_path, map_location=cpu, weights_only=True)

        cls.checkpoint_path = os.path.join(path, "transition_model.ckpt")
        transition_model_state = torch.load(cls.checkpoint_path, map_location=cpu, weights_only=True)

        is_legacy_format = True
    else:
        # Load model from checkpoint file
        try:
            _whitelist_torch_omegaconf()
            model_data = torch.load(path, map_location=cpu, weights_only=True)
        except Exception as e:
            raise ValueError(f"Failed to load model from {path}: {str(e)}") from e

        config = with_overrides(OmegaConf.create(model_data["config"]))
        diffusion_schedule = torch.tensor(model_data["diffusion_schedule"])

        # A "transition_model" entry means the file carries the same pieces as
        # the legacy folder, so it is reassembled the same way.
        is_legacy_format = "transition_model" in model_data
        transition_model_state = model_data["transition_model" if is_legacy_format else "state_dict"]

    # Load residues
    if is_legacy_format:
        residue_set = ResidueSet(
            residue_masses=config["residues"],
            residue_remapping=config["residue_remapping"],
        )
    else:
        residue_set = ResidueSet(residue_masses=model_data["residues"])

    # Build the transition model; legacy weights go straight into it.
    transition_model = MassSpectrumTransFusion(config, config.max_length)
    if is_legacy_format:
        transition_model.load_state_dict(transition_model_state)

    model = cls(
        config=config,
        transition_model=transition_model,
        diffusion_schedule=diffusion_schedule,
        residue_set=residue_set,
    )
    if not is_legacy_format:
        # Here the stored state dict covers the whole model.
        model.load_state_dict(transition_model_state)

    return model, config
@staticmethod
def get_pretrained() -> list[str]:
    """Return the ids of the pretrained models listed for this model type.

    The ids are the keys found under ``MODEL_TYPE`` in the ``models.json``
    resource of the ``instanovo`` package; an empty list is returned when
    that section is absent.
    """
    registry_file = resources.files("instanovo").joinpath("models.json")
    with registry_file.open("r", encoding="utf-8") as handle:
        registry = json.load(handle)

    return list(registry[MODEL_TYPE].keys()) if MODEL_TYPE in registry else []
@classmethod
def from_pretrained(cls, model_id: str, override_config: DictConfig | dict | None = None) -> Tuple["InstaNovoPlus", "DictConfig"]:
    """Load a model from a local path or from an id listed in ``models.json``.

    Resolution order:

    1. ``model_id`` is an existing directory: it must contain ``config.yaml``,
       ``diffusion_schedule.pt`` and ``transition_model.ckpt`` and is passed
       to ``load``.
    2. ``model_id`` is any other existing path: it is passed to ``load``.
    3. Otherwise ``model_id`` is looked up under ``MODEL_TYPE`` in the
       ``models.json`` resource of the ``instanovo`` package. A ``'remote'``
       entry is downloaded into ``~/.cache/instanovo`` unless already cached;
       if loading that file fails it is deleted, downloaded again and loaded
       once more. A ``'local'`` entry is resolved like steps 1 and 2.

    Args:
        model_id (str): A local path or a model id from ``models.json``.
        override_config (DictConfig | dict | None): Optional values merged
            into the loaded config, defaults to None.

    Returns:
        (InstaNovoPlus, DictConfig): The loaded model and config.

    Raises:
        FileNotFoundError: If a model directory lacks one of the expected files.
        ValueError: If ``model_id`` is not listed in ``models.json``, if its
            entry has neither ``'remote'`` nor ``'local'``, or if its
            ``'local'`` path does not exist.
    """
    expected_files = ["config.yaml", "diffusion_schedule.pt", "transition_model.ckpt"]

    def load_from_path(location: str, log_message: str | None = None) -> Tuple[InstaNovoPlus, DictConfig] | None:
        """Load from an existing directory or file; return None if ``location`` does not exist."""
        if os.path.isdir(location):
            missing_files = [fn for fn in expected_files if not os.path.exists(os.path.join(location, fn))]
            if missing_files:
                raise FileNotFoundError(f"InstaNovo+ model directory {location} is missing the expected file(s): {', '.join(missing_files)}.")
            if log_message is not None:
                logger.info(log_message)
            return cls.load(location, override_config=override_config)
        if os.path.exists(location):
            return cls.load(location, override_config=override_config)
        return None

    # Check if model_id is a local directory or file
    loaded = load_from_path(model_id)
    if loaded is not None:
        return loaded

    # Load the models.json file
    with resources.files("instanovo").joinpath("models.json").open("r", encoding="utf-8") as f:
        models_config = json.load(f)

    # Find the model in the config. Falling back to an empty mapping keeps the
    # error below a ValueError even when the whole MODEL_TYPE section is
    # missing (indexing it while formatting the message raised KeyError).
    available_models = models_config.get(MODEL_TYPE, {})
    if model_id not in available_models:
        raise ValueError(f"Model {model_id} not found in models.json, options are [{', '.join(available_models.keys())}]")

    # Create cache directory if it doesn't exist
    cache_dir = Path.home() / ".cache" / "instanovo"
    cache_dir.mkdir(parents=True, exist_ok=True)

    model_info = available_models[model_id]

    if "remote" in model_info:
        url = model_info["remote"]

        # Generate a filename for the cached model
        file_name = urlsplit(url).path.split("/")[-1]
        cached_file = cache_dir / file_name

        # Check if the file is already cached
        if not cached_file.exists():
            download_file(url, cached_file, model_id, file_name)
        else:
            logger.info(f"Model {model_id} already cached at {cached_file}")

        try:
            # Load and return the model
            logger.info(f"Loading model {model_id} (remote)")
            return cls.load(str(cached_file), override_config=override_config)
        except Exception as e:
            # Deliberately broad: any load failure is treated as a corrupted
            # cache entry and answered with exactly one fresh download.
            logger.warning(f"Failed to load cached model {model_id}, it may be corrupted. Deleting and re-downloading. Error: {e}")
            if cached_file.exists():
                cached_file.unlink()

            download_file(url, cached_file, model_id, file_name)
            logger.info(f"Loading newly downloaded model {model_id}")
            return cls.load(str(cached_file), override_config=override_config)

    elif "local" in model_info:
        instanovo_plus_model = model_info["local"]
        loaded = load_from_path(instanovo_plus_model, log_message=f"Loading model {model_id} (local)")
        if loaded is None:
            raise ValueError(
                f"Local model path '{instanovo_plus_model}' must exist, be a directory and containing the files {', '.join(expected_files)}."
            )
        return loaded
    else:
        raise ValueError(f"Model {model_id} does not have a valid 'remote', 'local' entry in models.json")
381 def prepare_fine_tuning(self, residue_set: ResidueSet) -> None:
382 """Prepare a model for fine-tuning on a dataset with a new residue vocabulary.
384 Args:
385 residue_set (ResidueSet): The residue vocabulary for the new dataset.
386 """
387 # 1. Update residue set
388 self.residue_set = residue_set
390 num_residues = len(self.residue_set)
391 model_dim = self.config.dim
393 # 2. Update config
394 self.config.vocab_size = num_residues
396 # 3. Update modules
397 self.transition_model.char_embedding = nn.Embedding(num_embeddings=num_residues, embedding_dim=model_dim)
398 self.transition_model.head[1] = nn.Linear(model_dim, num_residues)
400 def mixture_categorical(
401 self,
402 log_x: Float[ResidueLogProbabilities, "batch token"],
403 log_alpha: float,
404 log_alpha_complement: float,
405 ) -> Float[ResidueLogProbabilities, "batch token"]:
406 """A categorical mixture between a base distribution and a uniform distribution.
408 Args:
409 log_x (torch.FloatTensor[..., num_classes]):
410 The base distribution.
412 log_alpha (float):
413 The log of the mixture weight.
415 log_alpha_complement (float):
416 The log of 1 minus the mixture weight.
418 Returns:
419 torch.FloatTensor[..., num_classes]:
420 The log-probabilities of the mixture.
421 """
422 return torch.logaddexp(
423 log_x + log_alpha,
424 log_alpha_complement - math.log(len(self.residue_set)),
425 )
427 def forward(
428 self,
429 log_x_t: Float[ResidueLogProbabilities, "batch token"],
430 log_x_0: Float[ResidueLogProbabilities, "batch token"],
431 t: Integer[TimeStep, " batch"],
432 ) -> Float[ResidueLogProbabilities, "batch token"]:
433 """Calculate the log-posterior of `t-1`-th process values given the 0-th and t-th values.
435 Args:
436 log_x_t (torch.FloatTensor[batch_size, sequence_length, num_classes]):
437 The log one-hot representation of the process values at the `t`-th time step.
439 log_x_0 (torch.FloatTensor[batch_size, sequence_length, num_classes]):
440 The log one-hot representation of the process values at the `t`-th time step.
441 t (int):
442 The time step.
444 Returns:
445 torch.FloatTensor[batch_size, sequence_length, num_classes]:
446 The log-posterior probabilities of the process values at the `t-1`-th
447 time step given the values at the 0-th and `t`-th time step
448 i.e. q( x_{t-1} | x_{t}, x_0 ).
449 """
450 log_prior = self.mixture_categorical(
451 log_x=log_x_0,
452 log_alpha=self.cumulative_schedule[t - 1].unsqueeze(-1).unsqueeze(-1),
453 log_alpha_complement=self.cumulative_schedule_complement[t - 1].unsqueeze(-1).unsqueeze(-1),
454 )
455 log_likelihood = self.mixture_categorical(
456 log_x=log_x_t,
457 log_alpha=self.diffusion_schedule[t].unsqueeze(-1).unsqueeze(-1),
458 log_alpha_complement=self.diffusion_schedule_complement[t].unsqueeze(-1).unsqueeze(-1),
459 )
460 t_mask = (t == 0).unsqueeze(-1).unsqueeze(-1).expand_as(log_x_0)
461 prior_term = torch.where(t_mask, log_x_0, log_prior)
462 logits = log_likelihood + prior_term
463 return torch.log_softmax(logits, -1)
465 def reverse_distribution(
466 self,
467 x_t: Integer[Peptide, "batch token"],
468 time: Integer[TimeStep, " batch"],
469 **kwargs: dict,
470 ) -> Float[ResidueLogProbabilities, "batch token"]:
471 """Calculate the reverse transition distribution of the diffusion process.
473 Args:
474 x_t (torch.LongTensor[batch_size, sequence_length]):
475 The values at the `t`-th time step of the reverse process.
477 time (int):
478 The time step.
480 Returns:
481 torch.FloatTensor[batch_size, sequence_length, num_classes]:
482 The log-probabilities of values for the `t-1`-th time step given
483 values at the `t`-th time step i.e. `log p( x_{t-1} | x_{t} )`.
484 """
485 log_x_0 = log_softmax(self.transition_model(x_t, t=time, **kwargs), -1)
486 return self.forward(log_x_t=torch.log(one_hot(x_t, len(self.residue_set))), log_x_0=log_x_0, t=time)
489class DiffusionLoss(nn.Module):
490 """Holds logic for calculating the diffusion loss.
492 Args:
493 model (InstaNovoPlus):
494 The multinomial diffusion class.
495 """
497 def __init__(self, model: InstaNovoPlus) -> None:
498 super().__init__()
499 self.model = model
501 self.base_model = model.module if hasattr(model, "module") else model
503 self.time_steps = self.base_model.time_steps
505 @staticmethod
506 def kl_divergence(
507 log_probs_first: Float[ResidueLogProbabilities, "..."],
508 log_probs_second: Float[ResidueLogProbabilities, "..."],
509 ) -> Float[torch.Tensor, "..."]:
510 """Calculate the Kullback-Liebler divergence between two multinomial distributions.
512 Args:
513 log_probs_first (torch.FloatTensor[..., num_classes]):
514 The log-probabilities of the base distribution.
516 log_probs_second (torch.FloatTensor[..., num_classes]):
517 The log-probabilities of the comparison distribution.
519 Returns:
520 torch.FloatTensor[1]:
521 The KL-divergence averaged over all but the final dimension.
522 """
523 return (torch.exp(log_probs_first) * (log_probs_first - log_probs_second)).sum(-1).sum(-1)
def forward(self, x_0: Integer[Peptide, "batch token"], **kwargs: dict) -> Float[torch.Tensor, "1"]:
    """Calculate a single Monte Carlo estimate of the multinomial diffusion loss (-ELBO).

    One time step is drawn per batch element, the corresponding loss term is
    averaged over the batch, and the batch-averaged KL divergence between the
    fully diffused distribution and the uniform distribution is added.

    Args:
        x_0 (torch.LongTensor[batch_size, sequence_length]):
            A batch of padded sequences.

        **kwargs:
            Extra keyword arguments forwarded to ``_compute_loss``.

    Returns:
        torch.FloatTensor[1]:
            The loss estimate.
    """
    diffusion = self.base_model
    num_classes = len(diffusion.residue_set)
    final_step = self.time_steps - 1

    # 1. Sample time step
    # NOTE(review): torch.randint excludes its upper bound, so the largest
    # sampled value is time_steps - 2 and index time_steps - 1 is never drawn.
    # Confirm whether that is intended.
    t = torch.randint(0, self.time_steps - 1, (x_0.shape[0],)).to(x_0.device)

    # 2. Compute L_t
    step_loss = self._compute_loss(t=t, x_0=x_0, **kwargs).mean()

    # 3. Calculate prior KL term
    log_x_0 = torch.log(one_hot(x_0, num_classes=num_classes))
    final_log_probs = diffusion.mixture_categorical(
        log_x=log_x_0,
        log_alpha=diffusion.cumulative_schedule[final_step].unsqueeze(-1).unsqueeze(-1),
        log_alpha_complement=diffusion.cumulative_schedule_complement[final_step].unsqueeze(-1).unsqueeze(-1),
    )
    uniform_log_probs = torch.log(torch.ones_like(final_log_probs) / num_classes)
    prior_kl = self.kl_divergence(final_log_probs, uniform_log_probs).mean()

    return step_loss + prior_kl
def _compute_loss(
    self,
    x_0: Integer[Peptide, "batch token"],
    t: Integer[TimeStep, " batch"],
    **kwargs: dict,
) -> Float[torch.Tensor, " batch"]:
    """Compute the per-sequence loss term for the sampled time steps.

    A noised sequence is sampled from the forward process at index ``t``; the
    model's reverse distribution for it is then scored with the negative
    log-likelihood of ``x_0`` where ``t == 0`` and with the KL divergence from
    the true posterior everywhere else.

    Args:
        x_0 (torch.LongTensor[batch_size, sequence_length]):
            A batch of padded sequences.

        t (torch.LongTensor[batch_size]):
            The sampled time step of each batch element.

        **kwargs:
            Extra keyword arguments forwarded to ``reverse_distribution``.

    Returns:
        torch.FloatTensor[batch_size]:
            One loss value per batch element.
    """
    diffusion = self.base_model
    num_classes = len(diffusion.residue_set)
    x_0_one_hot = one_hot(x_0, num_classes=num_classes)

    # 1. sample x_{t+1}
    log_x_0 = torch.log(x_0_one_hot)
    log_probs = diffusion.mixture_categorical(
        log_x=log_x_0,
        log_alpha=diffusion.cumulative_schedule[t].unsqueeze(-1).unsqueeze(-1),
        log_alpha_complement=diffusion.cumulative_schedule_complement[t].unsqueeze(-1).unsqueeze(-1),
    )
    x_next = Categorical(logits=log_probs).sample()

    # 2. Calculate loss
    log_dist = diffusion.reverse_distribution(x_t=x_next, time=t, **kwargs)

    # Reconstruction term, selected only where t == 0.
    nll_loss = -(x_0_one_hot * log_dist).sum(-1).sum(-1)

    # Denoising term; called through self.model (not base_model), as in the original flow.
    log_x_next = torch.log(one_hot(x_next, log_probs.size(-1)))
    log_posterior = self.model(log_x_0=log_x_0, log_x_t=log_x_next, t=t)
    denoising_loss = self.kl_divergence(log_posterior, log_dist)

    return torch.where(t == 0, nll_loss, denoising_loss)
def _whitelist_torch_omegaconf() -> None:
    """Allow-list the classes needed to load checkpoints with ``weights_only=True``.

    ``torch.load(..., weights_only=True)`` uses a restricted unpickler that
    refuses globals it does not know. A checkpoint whose pickled payload
    contains OmegaConf objects therefore needs those classes, and the builtins
    they reference, registered as safe first. Only this fixed set is
    registered, so the protection against arbitrary code execution from
    untrusted checkpoints is kept.
    """
    from collections import defaultdict
    from typing import Any

    from omegaconf.base import ContainerMetadata, Metadata
    from omegaconf.listconfig import ListConfig
    from omegaconf.nodes import AnyNode

    omegaconf_globals = [
        DictConfig,
        ContainerMetadata,
        Metadata,
        ListConfig,
        AnyNode,
        Any,  # Only used for type hinting in omegaconf.
    ]
    builtin_globals = [defaultdict, dict, list, int]

    torch.serialization.add_safe_globals(omegaconf_globals + builtin_globals)