Model

InstaNovo(i2s, residues, dim_model=768, n_head=16, dim_feedforward=2048, n_layers=9, dropout=0.1, max_length=30, max_charge=5, bos_id=1, eos_id=2, use_depthcharge=True, enc_type='depthcharge', dec_type='depthcharge', dec_precursor_sos=False)

Bases: Module

The InstaNovo model.

Source code in instanovo/transformer/model.py
def __init__(
    self,
    i2s: dict[int, str],
    residues: dict[str, float],
    dim_model: int = 768,
    n_head: int = 16,
    dim_feedforward: int = 2048,
    n_layers: int = 9,
    dropout: float = 0.1,
    max_length: int = 30,
    max_charge: int = 5,
    bos_id: int = 1,
    eos_id: int = 2,
    use_depthcharge: bool = True,
    enc_type: str = "depthcharge",
    dec_type: str = "depthcharge",
    dec_precursor_sos: bool = False,
) -> None:
    super().__init__()
    self.i2s = i2s
    self.n_vocab = len(self.i2s)
    self.residues = residues
    self.bos_id = bos_id  # beginning of sentence ID, prepend to y
    self.eos_id = eos_id  # stop token
    self.pad_id = 0
    self.use_depthcharge = use_depthcharge

    self.enc_type = enc_type
    self.dec_type = dec_type
    self.dec_precursor_sos = dec_precursor_sos
    self.peptide_mass_calculator = depthcharge.masses.PeptideMass(self.residues)

    self.latent_spectrum = nn.Parameter(torch.randn(1, 1, dim_model))

    # Encoder
    if self.enc_type == "depthcharge":
        self.encoder = SpectrumEncoder(
            dim_model=dim_model,
            n_head=n_head,
            dim_feedforward=dim_feedforward,
            n_layers=n_layers,
            dropout=dropout,
            dim_intensity=None,
        )
        if not self.dec_precursor_sos:
            self.mass_encoder = MassEncoder(dim_model)
            self.charge_encoder = nn.Embedding(max_charge, dim_model)

    else:
        if not self.use_depthcharge:
            self.peak_encoder = MultiScalePeakEmbedding(dim_model, dropout=dropout)
            self.mass_encoder = self.peak_encoder.encode_mass
        else:
            self.mass_encoder = MassEncoder(dim_model)
            self.peak_encoder = PeakEncoder(dim_model, dim_intensity=None)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim_model,
            nhead=n_head,
            dim_feedforward=dim_feedforward,
            batch_first=True,
            dropout=dropout,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=n_layers,
            enable_nested_tensor=False,
        )

    # Decoder
    if dec_type == "depthcharge":
        self.decoder = PeptideDecoder(
            dim_model=dim_model,
            n_head=n_head,
            dim_feedforward=dim_feedforward,
            n_layers=n_layers,
            dropout=dropout,
            residues=residues,
            max_charge=max_charge,
        )

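        # With dec_precursor_sos=False, strip precursor conditioning from the decoder:
        # charge contributes zeros and the mass slot becomes a learned SOS embedding.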
        if not dec_precursor_sos:
            del self.decoder.charge_encoder
            self.decoder.charge_encoder = lambda x: torch.zeros(
                x.shape[0], dim_model, device=x.device
            )
            self.sos_embedding = nn.Parameter(torch.randn(1, 1, dim_model))
            del self.decoder.mass_encoder
            self.decoder.mass_encoder = lambda x: self.sos_embedding.expand(x.shape[0], -1, -1)
    else:
        self.aa_embed = nn.Embedding(self.n_vocab, dim_model, padding_idx=0)
        if not self.use_depthcharge:
            self.aa_pos_embed = PositionalEncoding(dim_model, dropout, max_len=200)
            assert max_length <= 200  # update value if necessary
        else:
            self.aa_pos_embed = PositionalEncoder(dim_model)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=dim_model,
            nhead=n_head,
            dim_feedforward=dim_feedforward,
            batch_first=True,
            dropout=dropout,
            # norm_first=True,
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer,
            num_layers=n_layers,
        )

        self.head = nn.Linear(dim_model, self.n_vocab)
        self.charge_encoder = nn.Embedding(max_charge, dim_model)

    if self.dec_type == "depthcharge":
        self.eos_id = self.decoder._aa2idx["$"]
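
A quick construction sketch; the toy vocabulary and residue masses below are illustrative placeholders, not the values shipped with InstaNovo (later examples on this page reuse model and torch from here):

import torch
from instanovo.transformer.model import InstaNovo

# Hypothetical toy vocabulary; index 0 is the padding ID,
# 1 and 2 match the default bos_id/eos_id.
vocab = ["<pad>", "<s>", "</s>", "G", "A", "S"]
i2s = dict(enumerate(vocab))

# Monoisotopic residue masses in Daltons.
residues = {"G": 57.02146, "A": 71.03711, "S": 87.03203}

model = InstaNovo(i2s=i2s, residues=residues)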

batch_idx_to_aa(idx)

Decode a batch of indices to aa lists.

Source code in instanovo/transformer/model.py
def batch_idx_to_aa(self, idx: Tensor) -> list[list[str]]:
    """Decode a batch of indices to aa lists."""
    return [self.idx_to_aa(i) for i in idx]

decode(sequence)

Decode a single sequence of AA IDs.

Source code in instanovo/transformer/model.py
def decode(self, sequence: torch.LongTensor) -> list[str]:
    """Decode a single sequence of AA IDs."""
    return self.decoder.detokenize(sequence)  # type: ignore

forward(x, p, y, x_mask=None, y_mask=None, add_bos=True)

Model forward pass.

Parameters:

    x (Tensor): Spectra, float Tensor (batch, n_peaks, 2). Required.
    p (Tensor): Precursors, float Tensor (batch, 3). Required.
    y (Tensor): Peptide token IDs, long Tensor (batch, seq_len). Required.
    x_mask (Tensor): Spectra padding mask, True for padded indices, bool Tensor (batch, n_peaks). Default: None.
    y_mask (Tensor): Peptide padding mask, bool Tensor (batch, seq_len). Default: None.
    add_bos (bool): Force add a <s> prefix to y. Default: True.

Returns:

    logits (Tensor): float Tensor (batch, n, vocab_size), or (batch, n+1, vocab_size) if add_bos is True.

Source code in instanovo/transformer/model.py
def forward(
    self,
    x: Tensor,
    p: Tensor,
    y: Tensor,
    x_mask: Tensor = None,
    y_mask: Tensor = None,
    add_bos: bool = True,
) -> Tensor:
    """Model forward pass.

    Args:
        x: Spectra, float Tensor (batch, n_peaks, 2)
        p: Precursors, float Tensor (batch, 3)
        y: Peptide token IDs, long Tensor (batch, seq_len)
        x_mask: Spectra padding mask, True for padded indices, bool Tensor (batch, n_peaks)
        y_mask: Peptide padding mask, bool Tensor (batch, seq_len)
        add_bos: Force add a <s> prefix to y, bool

    Returns:
        logits: float Tensor (batch, n, vocab_size),
        (batch, n+1, vocab_size) if add_bos==True.
    """
    x, x_mask = self._encoder(x, p, x_mask)
    return self._decoder(x, p, y, x_mask, y_mask, add_bos)
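
A hedged forward-pass sketch with dummy inputs; the precursor column layout (mass, charge, m/z) and the residue token IDs in y are assumptions for illustration, not values taken from the library:

batch, n_peaks, seq_len = 4, 100, 12

x = torch.rand(batch, n_peaks, 2)  # (m/z, intensity) pairs per peak
p = torch.tensor([[800.4, 2.0, 401.2]]).repeat(batch, 1)  # assumed (mass, charge, m/z) layout
y = torch.randint(3, 6, (batch, seq_len))  # assumed valid residue token IDs
x_mask = torch.zeros(batch, n_peaks, dtype=torch.bool)  # all False: no padded peaks

logits = model(x, p, y, x_mask=x_mask, add_bos=True)  # (batch, seq_len + 1, vocab_size)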

get_empty_index()

Get the PAD token ID.

Source code in instanovo/transformer/model.py
def get_empty_index(self) -> int:
    """Get the PAD token ID."""
    return 0

get_eos_index()

Get the EOS token ID.

Source code in instanovo/transformer/model.py
def get_eos_index(self) -> int:
    """Get the EOS token ID."""
    return self.eos_id

get_residue_masses(mass_scale)

Get the scaled masses of all residues.

Source code in instanovo/transformer/model.py
def get_residue_masses(self, mass_scale: int) -> torch.LongTensor:
    """Get the scaled masses of all residues."""
    residue_masses = torch.zeros(max(self.decoder._idx2aa.keys()) + 1).type(torch.int64)
    for index, residue in self.decoder._idx2aa.items():
        if residue in self.peptide_mass_calculator.masses:
            residue_masses[index] = round(
                mass_scale * self.peptide_mass_calculator.masses[residue]
            )
    return residue_masses
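
Since each entry is round(mass_scale * mass), mass_scale=10000 stores residue masses as integer counts of 1e-4 Da, e.g. glycine at 57.02146 Da becomes 570215:

residue_masses = model.get_residue_masses(mass_scale=10000)  # LongTensor indexed by decoder token ID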

idx_to_aa(idx)

Decode a single sample of indices to an aa list.

Source code in instanovo/transformer/model.py
def idx_to_aa(self, idx: Tensor) -> list[str]:
    """Decode a single sample of indices to aa list."""
    idx = idx.cpu().numpy()
    t = []
    for i in idx:
        if i == self.eos_id:
            break
        if i == self.bos_id or i == self.pad_id:
            continue
        t.append(i)
    return [self.i2s[x.item()] for x in t]
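
A decoding sketch, assuming the toy vocabulary from the construction example above (with dec_type="depthcharge" the effective eos_id is taken from the decoder vocabulary, so the raw IDs here are purely illustrative):

pred = torch.tensor([1, 3, 4, 5, 2, 0, 0])  # <s> G A S </s> <pad> <pad>
model.idx_to_aa(pred)  # ['G', 'A', 'S']: skips BOS/PAD, stops at EOS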

init(x, p, x_mask=None)

Initialise model encoder.

Source code in instanovo/transformer/model.py
def init(
    self, x: Tensor, p: Tensor, x_mask: Tensor = None
) -> tuple[tuple[Tensor, Tensor], Tensor]:
    """Initialise model encoder."""
    x, x_mask = self._encoder(x, p, x_mask)
    # y = torch.ones((x.shape[0], 1), dtype=torch.long, device=x.device) * self.bos_id
    logits, _ = self._decoder(x, p, None, x_mask, None, add_bos=False)
    return (x, x_mask), torch.log_softmax(logits[:, -1, :], -1)
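
init encodes the spectra once and returns the log-probabilities of the first residue, so a decoding loop can reuse the encoded batch at every step:

(xs, xs_mask), first_log_probs = model.init(x, p, x_mask)
# first_log_probs: (batch, vocab_size) scores for the first residue;
# (xs, xs_mask) is the encoded spectrum batch, passed back in at each decode step.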

load(path) classmethod

Load model from checkpoint.

Source code in instanovo/transformer/model.py
@classmethod
def load(cls, path: str) -> tuple[nn.Module, dict]:
    """Load model and config from a checkpoint."""
    ckpt = torch.load(path, map_location="cpu")

    config = ckpt["config"]

    # check if PTL checkpoint
    if all([x.startswith("model") for x in ckpt["state_dict"].keys()]):
        ckpt["state_dict"] = {k.replace("model.", ""): v for k, v in ckpt["state_dict"].items()}

    i2s = {i: v for i, v in enumerate(config["vocab"])}

    model = cls(
        i2s=i2s,
        residues=config["residues"],
        dim_model=config["dim_model"],
        n_head=config["n_head"],
        dim_feedforward=config["dim_feedforward"],
        n_layers=config["n_layers"],
        dropout=config["dropout"],
        max_length=config["max_length"],
        max_charge=config["max_charge"],
        use_depthcharge=config["use_depthcharge"],
        enc_type=config["enc_type"],
        dec_type=config["dec_type"],
        dec_precursor_sos=config["dec_precursor_sos"],
    )
    model.load_state_dict(ckpt["state_dict"])

    return model, config
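
Loading returns both the model and its config (the checkpoint path below is hypothetical):

model, config = InstaNovo.load("checkpoints/instanovo.ckpt")
model.eval()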

score_candidates(y, p, x, x_mask)

Score a set of candidate sequences.

Source code in instanovo/transformer/model.py
def score_candidates(
    self,
    y: torch.LongTensor,
    p: torch.FloatTensor,
    x: torch.FloatTensor,
    x_mask: torch.BoolTensor,
) -> torch.FloatTensor:
    """Score a set of candidate sequences."""
    logits, _ = self._decoder(x, p, y, x_mask, None, add_bos=True)

    return torch.log_softmax(logits[:, -1, :], -1)
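
A rescoring sketch: encode the spectra once with init, then score candidate prefixes y against them; the returned tensor holds next-token log-probabilities for each candidate:

(xs, xs_mask), _ = model.init(x, p, x_mask)
next_token_log_probs = model.score_candidates(y, p, xs, xs_mask)  # (batch, vocab_size)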