NequIP

class mlip.models.nequip.config.NequipConfig(*, add_atomic_energies: bool = True, num_layers: Annotated[int, ~annotated_types.Gt(gt=0)] = 2, target_irreps: Annotated[str, ~pydantic.functional_validators.AfterValidator(func=~mlip.typing.fields._check_irreps)] = '128x0e + 128x0o + 64x1o + 64x1e + 4x2e + 4x2o', l_max: Annotated[int, ~annotated_types.Ge(ge=0)] = 3, num_rbf: Annotated[int, ~annotated_types.Gt(gt=0)] = 8, radial_envelope: RadialEnvelope = RadialEnvelope.POLYNOMIAL, radial_mlp_hidden: list[~typing.Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Gt(gt=0)])]] = <factory>, radial_mlp_activation: Activation = Activation.SWISH, radial_mlp_variance_scale: Annotated[float, ~annotated_types.Gt(gt=0)] = 4.0, num_readout_heads: Annotated[int, ~annotated_types.Gt(gt=0)] = 1, avg_num_neighbors: float | None = None, num_species: Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Gt(gt=0)])] | None = None, predict_partial_charges: bool = False, use_coulomb_term: bool = False, use_total_charge_embedding: bool = False, embed_activation: Activation = Activation.SILU, use_residual_connection: bool = True, gate_nonlinearities: dict[str, ~mlip.models.options.Activation]={'e': Activation.SWISH, 'o': Activation.TANH}, deterministic_scatter_ops: bool = False)

The configuration / hyperparameters of the NequIP model.
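A minimal sketch of overriding a few hyperparameters. The keyword names come from the signature above; the values are illustrative, not recommendations. The fallback branch is only there so the snippet also runs where the mlip package is not installed.

```python
# Illustrative overrides; every key is a keyword argument from the signature above.
overrides = {
    "num_layers": 3,                    # validated as > 0
    "l_max": 2,                         # validated as >= 0
    "num_rbf": 8,                       # validated as > 0
    "radial_mlp_hidden": [64, 64],      # each entry validated as > 0
    "deterministic_scatter_ops": True,  # slower, but deterministic energies
}

try:
    from mlip.models.nequip.config import NequipConfig

    # Fields that are not listed keep the defaults documented in this section.
    config = NequipConfig(**overrides)
except ImportError:
    config = None  # mlip is not installed; the dict still documents the intent

print(config if config is not None else overrides)
```

All arguments are keyword-only, so positional construction is not an option.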

num_layers

Number of NequIP layers. Default is 2.

Type:

int

target_irreps

Target O(3) representation space for node features at each layer; the number of channels may depend on the degree l. Each layer attempts to produce these irreps, filtered to those reachable via the tensor product. Default is "128x0e + 128x0o + 64x1o + 64x1e + 4x2e + 4x2o".

Type:

str
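To make the irreps string concrete, the sketch below parses it by hand and counts the resulting feature components, using the fact that a degree-l irrep has 2l + 1 components. The helper names are hypothetical and not part of mlip; the library validates and parses the string through its own machinery.

```python
def parse_irreps(irreps: str) -> list[tuple[int, int, str]]:
    """Parse 'MULxLp + ...' into (multiplicity, degree l, parity) tuples."""
    parsed = []
    for term in irreps.split("+"):
        mul, irrep = term.strip().split("x")
        parsed.append((int(mul), int(irrep[:-1]), irrep[-1]))
    return parsed


def feature_dim(irreps: str) -> int:
    """Total number of components: each degree-l irrep has 2l + 1 of them."""
    return sum(mul * (2 * l + 1) for mul, l, _ in parse_irreps(irreps))


default_irreps = "128x0e + 128x0o + 64x1o + 64x1e + 4x2e + 4x2o"
print(parse_irreps(default_irreps)[:2])  # [(128, 0, 'e'), (128, 0, 'o')]
print(feature_dim(default_irreps))       # 680
```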

l_max

Maximal degree of spherical harmonics used for the angular encoding of edge vectors. Default is 3.

Type:

int
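As a quick sanity check on how l_max affects the size of the angular encoding, the sketch below counts spherical-harmonic components for degrees 0 through l_max. The function name is hypothetical; the count itself is the standard (l_max + 1)^2.

```python
def num_sh_components(l_max: int) -> int:
    """Count spherical-harmonic components for degrees 0..l_max (2l + 1 each)."""
    return sum(2 * l + 1 for l in range(l_max + 1))


print([num_sh_components(l_max) for l_max in range(4)])  # [1, 4, 9, 16]
```

With the default l_max of 3, each edge vector is therefore encoded by 16 angular components.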

num_rbf

The number of Bessel basis functions to use. Default is 8.

Type:

int
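For orientation, a Bessel radial basis can be sketched as follows. This assumes the common form sqrt(2 / r_max) * sin(n * pi * r / r_max) / r for n = 1..num_rbf; the normalization, the cutoff value of 5.0, and the function name are assumptions and may differ from what the library implements.

```python
import math


def bessel_basis(r: float, r_max: float, num_rbf: int = 8) -> list[float]:
    """Evaluate num_rbf Bessel radial basis functions at a distance r > 0."""
    prefactor = math.sqrt(2.0 / r_max)  # assumed normalization
    return [
        prefactor * math.sin(n * math.pi * r / r_max) / r
        for n in range(1, num_rbf + 1)
    ]


values = bessel_basis(r=2.5, r_max=5.0)
print(len(values))  # 8
```

Every basis function in this form vanishes at r = r_max, which is why it pairs naturally with a cutoff radius.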

radial_envelope

The radial envelope function. Default is "polynomial_envelope"; the only other option is "soft_envelope".

Type:

mlip.models.options.RadialEnvelope
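To illustrate what a polynomial envelope does, here is a sketch of the DimeNet-style polynomial cutoff, which goes smoothly from 1 at r = 0 to 0 at r = r_max. Whether mlip uses exactly this polynomial is an assumption, as are the exponent p = 6 and the function name.

```python
def polynomial_envelope(r: float, r_max: float, p: int = 6) -> float:
    """Smooth cutoff: 1 at r = 0, 0 (with flat slope) at r = r_max and beyond."""
    x = r / r_max
    if x >= 1.0:
        return 0.0
    return (
        1.0
        - (p + 1) * (p + 2) / 2.0 * x**p
        + p * (p + 2) * x ** (p + 1)
        - p * (p + 1) / 2.0 * x ** (p + 2)
    )


for r in (0.0, 2.5, 4.9, 5.0):
    print(r, polynomial_envelope(r, r_max=5.0))
```

Multiplying the radial basis by such an envelope keeps the predicted energy continuous as atoms cross the cutoff.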

radial_mlp_hidden

Sizes of the MLP hidden layers. Default is [64, 64].

Type:

list[int]

radial_mlp_activation

Activation function for radial MLP. Default is swish.

Type:

mlip.models.options.Activation

radial_mlp_variance_scale

Variance scaling parameter passed to the fan-in normal initializer of the MLP internal layers. See jax.nn.initializers.variance_scaling. Default is 4.0.

Type:

float
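The effect of the scale is easiest to see from the standard deviation it implies. With fan-in mode and a normal distribution, jax.nn.initializers.variance_scaling draws weights with variance scale / fan_in. The sketch below reproduces only that arithmetic with the standard library; the helper names are hypothetical, and it is not the library's initializer.

```python
import math
import random


def fan_in_normal_std(scale: float, fan_in: int) -> float:
    """Standard deviation of a fan-in normal initializer: sqrt(scale / fan_in)."""
    return math.sqrt(scale / fan_in)


def init_weights(fan_in: int, fan_out: int, scale: float = 4.0, seed: int = 0):
    """Draw a fan_in x fan_out weight matrix as nested lists."""
    rng = random.Random(seed)
    std = fan_in_normal_std(scale, fan_in)
    return [[rng.gauss(0.0, std) for _ in range(fan_out)] for _ in range(fan_in)]


print(fan_in_normal_std(4.0, 64))  # 0.25
weights = init_weights(fan_in=64, fan_out=64)
print(len(weights), len(weights[0]))  # 64 64
```

A scale of 4.0 doubles the weight standard deviation relative to a scale of 1.0.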

num_readout_heads

Number of readout heads. The default is 1.

Type:

int

add_atomic_energies

Whether to add atomic energies to the final energies. Default is True.

Type:

bool

avg_num_neighbors

The mean number of neighbors per atom, used to rescale messages. If None (default), the value from the dataset info is used.

Type:

float | None

num_species

The number of elements (atomic species descriptors) allowed. If None (default), infer the value from the atomic energies map in the dataset info.

Type:

int | None

predict_partial_charges

Whether the model will be trained to predict partial charges. Default is False.

Type:

bool

use_coulomb_term

Whether to use the Coulomb term in the model for long range interactions. Default is False.

Type:

bool

use_total_charge_embedding

Whether to use the total charge embedding. Default is False.

Type:

bool

embed_activation

Activation function for the embedding block. Default is “silu”.

Type:

mlip.models.options.Activation

use_residual_connection

Whether to add the species-linear residual connection in each NequIP layer. Default is True.

Type:

bool

gate_nonlinearities

Per-parity activations used by the gate nonlinearity in each NequIP layer. Keys are "e" (even) and "o" (odd). Default is {"e": Activation.SWISH, "o": Activation.TANH}.

Type:

dict[str, mlip.models.options.Activation]
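One way to read the default pairing, offered as an interpretation rather than something the source states: an activation applied to odd-parity scalars has to be an odd function for the output to stay odd, and tanh satisfies that while swish does not. The check below is a numerical illustration with hypothetical helper names.

```python
import math


def swish(x: float) -> float:
    """swish(x) = x * sigmoid(x), also known as SiLU."""
    return x / (1.0 + math.exp(-x))


def is_odd_function(f, samples=(0.5, 1.0, 2.0)) -> bool:
    """Check f(-x) == -f(x) on a few sample points."""
    return all(math.isclose(f(-x), -f(x), abs_tol=1e-12) for x in samples)


print(is_odd_function(math.tanh))  # True
print(is_odd_function(swish))      # False
```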

deterministic_scatter_ops

Whether to use deterministic scatter operations in the forward pass, ensuring deterministic energy outputs. Setting to True makes prediction slower. Default is False.

Type:

bool
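As background on why the order of scatter accumulation can matter at all: floating-point addition is not associative, so summing the same messages in a different order can change the last bits of the result. The snippet below shows this with plain Python floats; it illustrates the numerical effect only and says nothing about how the library implements its scatter operations.

```python
# The same three numbers, summed in two different orders.
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)

print(left_to_right == right_to_left)  # False
print(left_to_right, right_to_left)
```

Differences of this size are usually harmless, but they break bitwise reproducibility, which is what the flag trades speed for.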

class mlip.models.nequip.network.Nequip(config: NequipConfig, dataset_info: DatasetInfo, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

The NequIP model as a Flax module, derived from the MLIPNetwork class.

References

  • Simon Batzner, Albert Musaelian, Lixin Sun, Mario Geiger, Jonathan P. Mailoa, Mordechai Kornbluth, Nicola Molinari, Tess E. Smidt, and Boris Kozinsky. E(3)-equivariant graph neural networks for data-efficient and accurate interatomic potentials. Nature Communications, 13(1), May 2022. ISSN: 2041-1723. URL: https://dx.doi.org/10.1038/s41467-022-29939-5.

config

Hyperparameters / configuration for the NequIP model, see NequipConfig.

Type:

mlip.models.nequip.config.NequipConfig

dataset_info

Hyperparameters dictated by the dataset (e.g., cutoff radius or average number of neighbors).

Type:

mlip.data.dataset_info.DatasetInfo

__call__(graph: Graph) → Graph

Apply the NequIP model to the input graph to compute per-node energies.

Parameters:

graph – Input graph containing node and edge information.

Returns:

Graph with per-node energy predictions stored in graph.nodes.features["energy"].

class mlip.models.nequip.layer.NequipLayer(target_irreps: ~e3nn_jax._src.irreps.Irreps, source_node_irreps: ~e3nn_jax._src.irreps.Irreps, l_max: int, num_species: int, num_rbf: int, use_residual_connection: bool, nonlinearities: dict[str, ~mlip.models.options.Activation], radial_mlp_activation: ~mlip.models.options.Activation | ~typing.Literal['beta_swish'], radial_mlp_hidden: list[int], radial_mlp_variance_scale: float, avg_num_neighbors: float, deterministic_scatter_ops: bool = False, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

NequIP Layer, consisting of a convolution block and a gate nonlinearity.

Adapted from Google DeepMind materials discovery: https://github.com/google-deepmind/materials_discovery/blob/main/model/nequip.py

Implementation follows the original paper by Batzner et al. (2022): nature.com/articles/s41467-022-29939-5 and partially https://github.com/mir-group/nequip.

Parameters:
  • target_irreps – Target irreps for the output node features. Acts as an upper bound: only paths reachable via the tensor product are included, so earlier layers may not fully achieve these irreps.

  • source_node_irreps – Expected irreps of the input node features. Used to validate the graph at call time.

  • l_max – Maximum degree of spherical harmonics. Used to validate the edge spherical embedding shape.

  • num_rbf – Number of radial basis functions. Used to validate the radial embedding shape in the convolution block.

  • use_residual_connection – Whether to use the residual connection in the layer (recommended).

  • nonlinearities – Nonlinearities to use for even/odd irreps.

  • radial_mlp_activation – Activation for the radial MLP.

  • radial_mlp_hidden – Dimensions of hidden layers of the radial MLP.

  • radial_mlp_variance_scale – Variance scaling for all-but-last layers of the radial MLP.

  • avg_num_neighbors – Constant number of per-atom neighbors, used for internal normalization.

  • deterministic_scatter_ops – Whether to use deterministic scatter operations.

Returns:

Graph containing updated node features after the convolution and gating.
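The "upper bound" behaviour of target_irreps can be sketched with the tensor-product selection rules. The sketch assumes that spherical harmonics of degree l carry parity (-1)^l and that the first layer receives purely scalar (0e) node features; both are assumptions about the setup, and the helper names are hypothetical. Irreps are written as (l, parity) with parity +1 for even and -1 for odd.

```python
def reachable_irreps(input_irreps, l_max):
    """Irreps (l, parity) reachable from input x Y_l for all l <= l_max."""
    reachable = set()
    for l_in, p_in in input_irreps:
        for l_sh in range(l_max + 1):
            p_sh = (-1) ** l_sh  # assumed parity of the spherical harmonics
            # Selection rule: |l_in - l_sh| <= l_out <= l_in + l_sh.
            for l_out in range(abs(l_in - l_sh), l_in + l_sh + 1):
                reachable.add((l_out, p_in * p_sh))
    return reachable


def filter_target(target, reachable):
    """Keep only target entries (mul, l, parity) whose irrep is reachable."""
    return [(mul, l, p) for mul, l, p in target if (l, p) in reachable]


# The default target irreps as (multiplicity, l, parity) tuples.
target = [(128, 0, 1), (128, 0, -1), (64, 1, -1), (64, 1, 1), (4, 2, 1), (4, 2, -1)]

first = filter_target(target, reachable_irreps({(0, 1)}, l_max=3))
print(first)  # [(128, 0, 1), (64, 1, -1), (4, 2, 1)]

second_inputs = {(l, p) for _, l, p in first}
second = filter_target(target, reachable_irreps(second_inputs, l_max=3))
print(second)
```

Under these assumptions the first layer produces only half of the target irreps, and later layers fill in more of them as richer inputs become available.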

__call__(graph: Graph) → Graph

Apply the Nequip layer to update the node features of the input graph.

Parameters:

graph – Graph containing precomputed node and edge embeddings, and the current node features in graph.nodes.features["latent"] or, for the first layer, in graph.nodes.features["embedding"].

Returns:

Graph with updated node features stored in graph.nodes.features["latent"].

class mlip.models.nequip.blocks.NequipEmbeddingBlock(num_species: int, num_charges: int | None, target_irreps: str, l_max: int, num_rbf: int, r_max: float, avg_r_min: float | None, radial_envelope: RadialEnvelope, activation_fn: Callable | None, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Initial embedding block for the NequIP model.

Encodes atomic species as one-hot vectors and projects them to the hidden node feature space. Encodes edge geometry into radial basis functions and spherical harmonics, which together form the edge embedding used in the interaction layers.

num_species

Number of distinct atomic species in the dataset.

Type:

int

target_irreps

Target irreps for the initial node feature projection.

Type:

str

l_max

Maximum degree of spherical harmonics for edge angular encoding.

Type:

int

num_rbf

Number of Bessel radial basis functions.

Type:

int

r_max

Cutoff distance in Angstroms.

Type:

float

avg_r_min

Average minimum interatomic distance, used to shift the radial basis. If None, no shift is applied.

Type:

float | None

radial_envelope

Envelope function to smoothly decay the basis at r_max (e.g. polynomial or soft envelope).

Type:

mlip.models.options.RadialEnvelope

__call__(graph: Graph) → Graph

Apply the NequIP embedding block to compute initial node and edge embeddings.

Parameters:

graph – Graph containing node features with “species” and edge vectors.

Returns:

Updated graph containing node features (“species_one_hot”, “embedding”), and edge features (“radial_embedding”, “spherical_embedding”).
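The one-hot species encoding mentioned for the embedding block can be sketched in a few lines. This assumes that "species" holds integer indices in the range 0..num_species - 1; the function is hypothetical and returns nested lists rather than arrays.

```python
def one_hot(species: list[int], num_species: int) -> list[list[float]]:
    """Encode integer species indices as one-hot rows of length num_species."""
    if any(not 0 <= s < num_species for s in species):
        raise ValueError("species indices must lie in [0, num_species)")
    return [[1.0 if s == k else 0.0 for k in range(num_species)] for s in species]


print(one_hot([0, 2, 2], num_species=3))
# [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 1.0]]
```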

class mlip.models.nequip.blocks.NequipMultiHeadReadoutBlock(source_node_irreps: Irreps, num_heads: int, predict_partial_charges: bool, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Readout block that maps final node features to per-atom energies.

Asserts the input shapes, prepares the correct feature sizes for the linear layers, and then calls the ReadoutBlock with them.

Extracts the scalar (0e) channels from the final node feature irreps and applies two successive linear projections: first halving the scalar multiplicity, then reducing to a single scalar per atom. The result is stored as “outputs” in the node features.

source_node_irreps

Expected irreps of the input node features, used to validate the graph at call time.

Type:

e3nn_jax._src.irreps.Irreps

__call__(graph: Graph) → Graph

Call self as a function.
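The two successive projections described for NequipMultiHeadReadoutBlock imply a simple sequence of feature widths. The sketch below assumes a single readout head and integer halving of the scalar multiplicity; the function name is hypothetical.

```python
def readout_widths(scalar_multiplicity: int) -> tuple[int, int, int]:
    """Feature widths: input 0e channels -> half as many -> one scalar per atom."""
    return (scalar_multiplicity, scalar_multiplicity // 2, 1)


print(readout_widths(128))  # (128, 64, 1)
```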