ViSNet

class mlip.models.visnet.network.Visnet(config: VisnetConfig, dataset_info: DatasetInfo, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

The ViSNet model flax module. It is derived from the MLIPNetwork class.

References

  • Yusong Wang, Tong Wang, Shaoning Li, Xinheng He, Mingyu Li, Zun Wang, Nanning Zheng, Bin Shao, and Tie-Yan Liu. Enhancing geometric representations for molecules with equivariant vector-scalar interactive message passing. Nature Communications, 15(1), January 2024. ISSN: 2041-1723. URL: https://dx.doi.org/10.1038/s41467-023-43720-2.

config

Hyperparameters / configuration for the ViSNet model, see VisnetConfig.

Type:

mlip.models.visnet.config.VisnetConfig

dataset_info

Hyperparameters dictated by the dataset (e.g., cutoff radius or average number of neighbors).

Type:

mlip.data.dataset_info.DatasetInfo

available_properties

Properties the model can predict, see Properties.

setup() None

Initializes model layers and checks that the number of hidden channels is evenly divisible by the number of attention heads.
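The divisibility check described above can be sketched in plain Python. This is only an illustration of the constraint: the helper name `channels_per_head` is hypothetical and is not part of the mlip API.

```python
def channels_per_head(num_channels: int, num_heads: int) -> int:
    """Return the per-head width, or raise if the split is uneven.

    Hypothetical helper: it mirrors the documented constraint that the
    number of hidden channels must be divisible by the number of heads.
    """
    if num_channels % num_heads != 0:
        raise ValueError(
            f"num_channels ({num_channels}) must be divisible by "
            f"num_heads ({num_heads})"
        )
    return num_channels // num_heads


# With the documented defaults (256 channels, 8 heads) each head gets 32 channels.
default_width = channels_per_head(256, 8)
```

A configuration such as `num_channels=256, num_heads=7` would be rejected by a check of this kind.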

__call__(graph: Graph) Graph

Runs the ViSNet model forward pass on an input Graph and returns an updated Graph object with node-wise contributions.

Features in output graph:

  • node-wise energy: graph.nodes.features[“energy”]

Parameters:

graph – Input Graph containing atomic positions, species, and topology.

Returns:

Updated Graph with node-wise features.
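Because the forward pass returns node-wise energy contributions, a per-structure total has to be obtained by summing over the atoms of each graph. The sketch below shows that reduction in plain Python. The function `total_energies` and the `node_graph_index` array are assumptions made for illustration and are not taken from the mlip API.

```python
def total_energies(node_energies, node_graph_index, num_graphs):
    """Sum node-wise energies into one total per graph.

    Assumption: node_graph_index[i] is the index of the graph that
    node i belongs to in a batch of num_graphs graphs.
    """
    totals = [0.0] * num_graphs
    for energy, graph_id in zip(node_energies, node_graph_index):
        totals[graph_id] += energy
    return totals


# A batch of two structures: three atoms in the first, two in the second.
node_energies = [-1.0, -2.0, -0.5, -4.0, -1.0]
node_graph_index = [0, 0, 0, 1, 1]
totals = total_energies(node_energies, node_graph_index, num_graphs=2)
```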

class mlip.models.visnet.config.VisnetConfig(*, add_atomic_energies: bool = True, num_layers: Annotated[int, Gt(gt=0)] = 4, num_channels: Annotated[int, Gt(gt=0)] = 256, l_max: Annotated[int, Ge(ge=0)] = 2, num_heads: Annotated[int, Gt(gt=0)] = 8, num_rbf: Annotated[int, Gt(gt=0)] = 32, trainable_rbf: bool = False, activation: Activation = Activation.SILU, attn_activation: Activation = Activation.SILU, vecnorm_type: VecNormType = VecNormType.NONE, num_readout_heads: Annotated[int, Gt(gt=0)] = 1, radial_basis: RadialBasis | str = RadialBasis.EXPNORM, num_species: Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Gt(gt=0)])] | None = None, predict_partial_charges: bool = False, use_coulomb_term: bool = False, use_total_charge_embedding: bool = False, embed_activation: Activation = Activation.SILU, deterministic_scatter_ops: bool = False)

Hyperparameters for the ViSNet model.

num_layers

Number of ViSNet layers. Default is 4.

Type:

int

num_channels

The number of channels. Default is 256.

Type:

int

l_max

Highest harmonic order included in the Spherical Harmonics series. Default is 2.

Type:

int

num_heads

Number of heads in the attention block. Default is 8.

Type:

int

num_rbf

Number of basis functions used in the embedding block. Default is 32.

Type:

int

trainable_rbf

Whether to add learnable weights to each of the radial embedding basis functions. Default is False.

Type:

bool

activation

Activation function for the output block. Options are “silu” (default), “ssp” (which is shifted softplus), “tanh”, “sigmoid”, and “swish”.

Type:

mlip.models.options.Activation

attn_activation

Activation function for the attention block. Options are “silu” (default), “ssp” (which is shifted softplus), “tanh”, “sigmoid”, and “swish”.

Type:

mlip.models.options.Activation

vecnorm_type

The type of the vector norm. The options are “none” (default), “max_min”, and “rms”.

Type:

mlip.models.visnet.visnet_helpers.VecNormType

num_readout_heads

Number of readout heads. The default is 1.

Type:

int

radial_basis

The type of radial embedding. Options are “bessel”, “gauss” and “expnorm” (default).

Type:

mlip.models.options.RadialBasis | str

add_atomic_energies

Whether to add atomic energies to the final energies. Default is True.

Type:

bool

num_species

The number of elements (atomic species descriptors) allowed. If None (default), infer the value from the atomic energies map in the dataset info.

Type:

int | None

predict_partial_charges

Whether the model will be trained to predict partial charges. Default is False.

Type:

bool

use_coulomb_term

Whether to use the Coulomb term in the model for long range interactions. Default is False.

Type:

bool

use_total_charge_embedding

Whether to use the total charge embedding. Default is False.

Type:

bool

embed_activation

Activation function for the embedding block. Default is “silu”.

Type:

mlip.models.options.Activation

deterministic_scatter_ops

Whether to use deterministic scatter operations in the forward pass, ensuring deterministic energy outputs. Setting to True makes prediction slower. Default is False.

Type:

bool
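Why deterministic scatter operations matter can be illustrated without any library: floating-point addition is not associative, so a scatter-add that accumulates messages in a varying order can return slightly different sums from run to run. The sketch below is a plain-Python illustration of that effect, not the implementation used by mlip. The function `scatter_add` and its `order` argument are hypothetical.

```python
def scatter_add(messages, receivers, num_nodes, order):
    """Accumulate per-edge messages into their receiver nodes.

    The `order` argument fixes the sequence in which edges are visited,
    standing in for the scheduling freedom of a parallel scatter.
    """
    out = [0.0] * num_nodes
    for edge in order:
        out[receivers[edge]] += messages[edge]
    return out


# Three messages sent to the same node, with very different magnitudes.
messages = [1e16, 1.0, -1e16]
receivers = [0, 0, 0]

# Visiting the edges in two different orders changes the rounded result.
sum_a = scatter_add(messages, receivers, 1, order=[0, 1, 2])
sum_b = scatter_add(messages, receivers, 1, order=[0, 2, 1])
```

Here the two orders give different results, which is the kind of run-to-run variation that a fixed accumulation order removes, at the cost of speed.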

class mlip.models.visnet.blocks.VisnetEmbeddingBlock(l_max: int, num_channels: int, num_rbf: int, trainable_rbf: bool, graph_cutoff_angstrom: float, num_species: int, num_charges: int | None, radial_basis: str | RadialBasis, activation_fn: Callable | None, deterministic_scatter_ops: bool = False, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Embeds input node and edge features for the ViSNet model.

Initializes and applies the embedding layers for node species, radial functions, and spherical harmonics. Updates the input Graph with the embedded features and pre-processes edges and neighbors for subsequent network layers.

l_max

Highest harmonic order included in the Spherical Harmonics series.

Type:

int

num_channels

The number of channels.

Type:

int

num_rbf

Number of basis functions used in the embedding block.

Type:

int

trainable_rbf

Whether to add learnable weights to each of the radial embedding basis functions.

Type:

bool

graph_cutoff_angstrom

The cutoff radius for the graph.

Type:

float

num_species

The number of elements (atomic species descriptors) allowed.

Type:

int

deterministic_scatter_ops

Whether to use deterministic scatter operations.

Type:

bool

setup() None

Initializes the embedding layers for node species, radial functions, and spherical harmonics, as well as the neighbor and edge embedding blocks.

__call__(graph: Graph) Graph

Applies embedding transformations to the input graph.

Updates the graph with node, edge, and spherical harmonic features, processes neighbor and edge information for ViSNet layers, and initializes vector features. Returns the updated graph.

Embedded features in the final graph can be accessed as:

  • scalar node features: graph.nodes.features[“embedding_scalars”]
  • vector features: graph.nodes.features[“embedding_vectors”]
  • edge features: graph.edges.features[“embedding”]
  • edge distances: graph.edges.features[“distances”]
  • spherical harmonic features: graph.edges.features[“spherical_embedding”]

Parameters:

graph – Input Graph object with atomic positions, node species, and edge vectors.

Returns:

Updated Graph with embedded node and edge features ready for ViSNet processing.
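To make the radial part of the embedding concrete, here is a minimal sketch of an exponential-normal ("expnorm") radial basis with a cosine cutoff, following the form commonly used in ViSNet-style models. It assumes a lower cutoff of zero, and the function names and the exact constants are illustrative. The implementation in mlip may differ.

```python
import math


def cosine_cutoff(distance: float, cutoff: float) -> float:
    """Smooth envelope that decays to zero at the cutoff radius."""
    if distance >= cutoff:
        return 0.0
    return 0.5 * (math.cos(math.pi * distance / cutoff) + 1.0)


def expnorm_rbf(distance: float, cutoff: float, num_rbf: int) -> list[float]:
    """Expand one edge distance into num_rbf radial features (num_rbf > 1)."""
    alpha = 5.0 / cutoff
    start = math.exp(-cutoff)
    # Centers are evenly spaced between exp(-cutoff) and 1.
    step = (1.0 - start) / (num_rbf - 1)
    means = [start + i * step for i in range(num_rbf)]
    # A single shared width for all basis functions.
    beta = (2.0 / num_rbf * (1.0 - start)) ** -2
    envelope = cosine_cutoff(distance, cutoff)
    decayed = math.exp(-alpha * distance)
    return [envelope * math.exp(-beta * (decayed - m) ** 2) for m in means]


# 32 basis functions (the documented default for num_rbf), 5 Angstrom cutoff.
features_at_zero = expnorm_rbf(0.0, cutoff=5.0, num_rbf=32)
features_at_cutoff = expnorm_rbf(5.0, cutoff=5.0, num_rbf=32)
```

With `trainable_rbf` enabled, the centers and width would be learnable parameters rather than fixed values as in this sketch.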

class mlip.models.visnet.blocks.VisnetNeighborEmbeddingBlock(num_channels: int, graph_cutoff_angstrom: float, num_species: int, num_rbf: int, deterministic_scatter_ops: bool = False, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Applies the neighbor embedding to update node features using neighboring nodes’ species and edge features.

num_channels

The number of channels.

Type:

int

graph_cutoff_angstrom

The cutoff radius for the graph.

Type:

float

num_species

The number of elements (atomic species descriptors) allowed.

Type:

int

num_rbf

Number of basis functions used in the embedding block.

Type:

int

deterministic_scatter_ops

Whether to use deterministic scatter operations.

Type:

bool

setup() None

Initializes the neighbor embedding layers.

__call__(graph: Graph) Graph

Applies the neighbor embedding to update node features using neighboring nodes’ species and edge features.

Parameters:

graph – Input Graph containing node features (species, node_feats), edge features (edge_feats, distances), and graph topology (senders, receivers).

Returns:

Updated Graph object with new node features (field name: “embedding_scalars”).
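The neighbor embedding idea can be sketched as a weighted sum over incoming edges: each receiver collects the species embeddings of its senders, scaled channel-wise by a weight derived from the edge features. The plain-Python version below is a simplified illustration under stated assumptions. The function name and argument layout are hypothetical, and the learned combination of the aggregated and original node features is replaced here by a plain sum.

```python
def neighbor_embedding(node_feats, species_emb, edge_weights, senders, receivers):
    """Update node features with edge-weighted neighbor species embeddings.

    node_feats, species_emb: one list of channel values per node.
    edge_weights: one list of channel values per edge.
    senders, receivers: node indices for each edge.
    """
    num_channels = len(node_feats[0])
    aggregated = [[0.0] * num_channels for _ in node_feats]
    for edge, (sender, receiver) in enumerate(zip(senders, receivers)):
        for channel in range(num_channels):
            aggregated[receiver][channel] += (
                edge_weights[edge][channel] * species_emb[sender][channel]
            )
    # Simplification: add the aggregate instead of a learned combination.
    return [
        [x + a for x, a in zip(row, agg)]
        for row, agg in zip(node_feats, aggregated)
    ]


# Two nodes and a single edge from node 0 to node 1, with two channels.
updated = neighbor_embedding(
    node_feats=[[1.0, 1.0], [2.0, 2.0]],
    species_emb=[[0.5, 0.5], [3.0, 3.0]],
    edge_weights=[[2.0, 4.0]],
    senders=[0],
    receivers=[1],
)
```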

class mlip.models.visnet.blocks.VisnetEdgeEmbeddingBlock(num_channels: int, num_rbf: int, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Applies the edge embedding to update edge features using node features.

num_channels

The number of channels.

Type:

int

num_rbf

Number of basis functions used in the embedding block.

Type:

int

setup() None

Initializes the edge embedding layers.

__call__(graph: Graph) Graph

Applies the edge embedding to update edge features using node features.

Parameters:

graph – Input Graph containing node features, edge features, and graph topology (senders, receivers).

Returns:

Updated Graph object with new edge features (field name: “embedding”).

class mlip.models.visnet.blocks.VisnetMultiHeadReadoutBlock(num_heads: int, num_channels: int, activation: str, vecnorm_type: str, l_max: int, predict_partial_charges: bool, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Applies the final readout processing network to node and edge features.

num_heads

The number of readout heads.

Type:

int

num_channels

The number of channels.

Type:

int

activation

The activation function.

Type:

str

vecnorm_type

The type of vector normalization to apply.

Type:

str

l_max

Highest harmonic order included in the Spherical Harmonics series.

Type:

int

predict_partial_charges

Whether to predict partial charges.

Type:

bool

setup() None

Initializes the output processing network.

__call__(graph: Graph) Graph

Applies the final output processing network to node and edge features.

Passes the node features and edge vector features through a stack of GatedEquivariantBlock layers to produce final per-node outputs.

Parameters:

graph – Input Graph object containing scalar node features (“latent_scalars”) and vector node features (“latent_vectors”).

Returns:

Updated Graph object with processed node features (“latent_scalars”) of shape [num_nodes, num_heads, Nx0e].

class mlip.models.visnet.layer.VisnetLayer(num_heads: int, num_channels: int, activation: str, attn_activation: str, graph_cutoff_angstrom: float, vecnorm_type: str, last_layer: bool, l_max: int, deterministic_scatter_ops: bool = False, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

VisnetLayer module representing a single vector-scalar interactive self-attention layer used in ViSNet.

This layer performs equivariant message passing with multiple attention heads, supporting both scalar and vector features and including various normalization and activation options as configured.

num_heads

Number of attention heads.

Type:

int

num_channels

Number of channels in the input and output features.

Type:

int

activation

Activation function.

Type:

str

attn_activation

Activation function for the attention heads.

Type:

str

graph_cutoff_angstrom

Cutoff radius.

Type:

float

vecnorm_type

Type of vector normalization to apply.

Type:

str

last_layer

Whether this is the last layer of the network.

Type:

bool

l_max

Highest harmonic order included in the Spherical Harmonics series.

Type:

int

deterministic_scatter_ops

Whether to use deterministic scatter operations.

Type:

bool

setup() None

Initializes the VisnetLayer module.

__call__(graph: Graph) Graph

Applies the VisnetLayer module to an input Graph and returns a Graph.

Runs the forward pass of the VisnetLayer module on an input Graph and returns an updated Graph with node and edge features processed using vector-scalar attention, normalization, and projection mechanisms. Residual connections are applied to the updated features.

Updated features in this function:

  • scalar node-wise features: graph.nodes.features[“latent_scalars”]
  • vector node-wise features: graph.nodes.features[“latent_vectors”]
  • edge-wise features: graph.edges.features[“latent”]

Parameters:

graph – Input Graph object containing node features (“latent_scalars”, “latent_vectors”), edge features (“latent”, “distances”, “spherical_embedding”), and topology (“senders”, “receivers”). If this is the first layer in a model, the input features can also be named “embedding_*” instead of “latent_*”.

Returns:

Updated Graph object with new node and edge features after message passing and attention updates.
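The naming rule for first-layer inputs can be sketched as a lookup that prefers the “latent_*” name and falls back to “embedding_*”. The helper below works on a plain dictionary standing in for graph.nodes.features. The function `read_node_feature` is hypothetical and only illustrates the fallback, not how mlip resolves the names internally.

```python
def read_node_feature(features: dict, name: str):
    """Return features["latent_<name>"], else features["embedding_<name>"]."""
    for prefix in ("latent_", "embedding_"):
        key = prefix + name
        if key in features:
            return features[key]
    raise KeyError(f"neither latent_{name} nor embedding_{name} is present")


# First layer: only the output of the embedding block is available.
first_layer_input = {"embedding_scalars": [0.1, 0.2]}
# Later layers: the latent features written by the previous layer take priority.
later_layer_input = {"embedding_scalars": [0.1, 0.2], "latent_scalars": [0.3, 0.4]}

first_scalars = read_node_feature(first_layer_input, "scalars")
later_scalars = read_node_feature(later_layer_input, "scalars")
```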