Coverage for instanovo/inference/interfaces.py: 81%
43 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
1from __future__ import annotations
3from abc import ABCMeta, abstractmethod
4from dataclasses import dataclass
5from typing import Any
7import torch
8from jaxtyping import Float, Integer
10from instanovo.types import Peptide, PrecursorFeatures, Spectrum
11from instanovo.utils.residues import ResidueSet
@dataclass
class ScoredSequence:
    """A residue sequence together with its scores.

    Attributes:
        sequence: The residue sequence, as a list of strings.
        mass_error: The mass error recorded for the sequence.
        sequence_log_probability: The log probability of the sequence.
        token_log_probabilities: Log probabilities of the individual
            tokens (presumably one per entry of `sequence` -- confirm
            against the code that builds these objects).
    """

    sequence: list[str]
    mass_error: float
    sequence_log_probability: float
    token_log_probabilities: list[float]
class Decodable(metaclass=ABCMeta):
    """An interface for models that can be decoded.

    Algorithms should conform to the search interface.

    Every member below is abstract: a concrete model has to provide all
    of them before it can be instantiated.
    """

    @property
    @abstractmethod
    def residue_set(self) -> ResidueSet:
        """Every model must have a `residue_set` attribute."""

    @abstractmethod
    def init(  # type:ignore
        self,
        spectra: Float[Spectrum, " batch"],
        precursors: Float[PrecursorFeatures, " batch"],
        *args,
        **kwargs,
    ) -> Any:
        """Initialize the search state.

        Args:
            spectra (torch.FloatTensor):
                The spectra to be sequenced.

            precursors (torch.FloatTensor[batch size, 3]):
                The precursor mass, charge and mass-to-charge ratio.

        Returns:
            Any:
                Implementation-defined; presumably the initial search
                state -- confirm against the concrete models.
        """

    @abstractmethod
    def score_candidates(  # type:ignore
        self,
        sequences: Integer[Peptide, "..."],
        precursor_mass_charge: Float[PrecursorFeatures, "..."],
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        """Generate and score the next set of candidates.

        Args:
            sequences (torch.LongTensor):
                Partial residue sequences generated in
                the course of decoding.

            precursor_mass_charge (torch.FloatTensor[batch size, 3]):
                The precursor mass, charge and mass-to-charge ratio.

        Returns:
            torch.FloatTensor:
                The scores of the next candidates (shape and score
                scale are implementation-defined -- confirm against
                the concrete models).
        """

    @abstractmethod
    def get_residue_masses(self, mass_scale: int) -> torch.LongTensor:
        """Get residue masses for the model's residue vocabulary.

        Args:
            mass_scale (int):
                The scale in Daltons at which masses are
                calculated and rounded off. For example,
                a scale of 10000 would represent masses
                at a scale of 1e4 Da.

        Returns:
            torch.LongTensor:
                The residue masses at the requested scale.
        """

    @abstractmethod
    def decode(self, sequence: Integer[Peptide, "..."]) -> list[str]:
        """Map sequences of indices to residues using the model's residue vocabulary.

        Args:
            sequence (torch.LongTensor):
                The sequence of residue indices to be mapped
                to the corresponding residue strings.

        Returns:
            list[str]:
                The residue strings for the given indices.
        """

    @abstractmethod
    def get_eos_index(self) -> int:
        """Get the end of sequence token's index in the model's residue vocabulary."""

    @abstractmethod
    def get_empty_index(self) -> int:
        """Get the empty token's index in the model's residue vocabulary."""
class Decoder(metaclass=ABCMeta):
    """A class that implements some search algorithm for decoding.

    Model should conform to the `Decodable` interface.

    Args:
        model (Decodable):
            The model to predict residue sequences
            from using the implemented search
            algorithm.
    """

    def __init__(self, model: Decodable):
        # Keep a reference to the model; it is exposed as `self.model`.
        self.model = model

    @abstractmethod
    def decode(  # type:ignore
        self,
        spectra: Float[Spectrum, "..."],
        precursors: Float[PrecursorFeatures, "..."],
        *args,
        **kwargs,
    ) -> dict[str, Any]:
        """Generate the predicted residue sequence using the decoder's search algorithm.

        Args:
            spectra (torch.FloatTensor):
                The spectra to be sequenced.

            precursors (torch.FloatTensor):
                The precursor mass, charge and mass-to-charge ratio.

        Returns:
            dict[str, Any]:
                Required keys:
                    - "sequence": list[str]
                    - "mass_error": float
                    - "sequence_log_probability": float
                    - "token_log_probabilities": list[float]
                Optional keys:
                    - "encoder_output": list[float]
                Example additional keys:
                    - "sequence_beam_0": list[str]
        """