NequIP¶
- class mlip.models.nequip.config.NequipConfig(*, add_atomic_energies: bool = True, num_layers: Annotated[int, ~annotated_types.Gt(gt=0)] = 2, target_irreps: Annotated[str, ~pydantic.functional_validators.AfterValidator(func=~mlip.typing.fields._check_irreps)] = '128x0e + 128x0o + 64x1o + 64x1e + 4x2e + 4x2o', l_max: Annotated[int, ~annotated_types.Ge(ge=0)] = 3, num_rbf: Annotated[int, ~annotated_types.Gt(gt=0)] = 8, radial_envelope: RadialEnvelope = RadialEnvelope.POLYNOMIAL, radial_mlp_hidden: list[~typing.Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Gt(gt=0)])]] = <factory>, radial_mlp_activation: Activation = Activation.SWISH, radial_mlp_variance_scale: Annotated[float, ~annotated_types.Gt(gt=0)] = 4.0, num_readout_heads: Annotated[int, ~annotated_types.Gt(gt=0)] = 1, avg_num_neighbors: float | None = None, num_species: Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Gt(gt=0)])] | None = None, predict_partial_charges: bool = False, use_coulomb_term: bool = False, use_total_charge_embedding: bool = False, embed_activation: Activation = Activation.SILU, use_residual_connection: bool = True, gate_nonlinearities: dict[str, ~mlip.models.options.Activation]={'e': Activation.SWISH, 'o': Activation.TANH}, deterministic_scatter_ops: bool = False)¶
The configuration / hyperparameters of the NequIP model.
- num_layers¶
Number of NequIP layers. Default is 2.
- Type:
int
- target_irreps¶
Target O(3) representation space for node features at each layer, with a number of channels that may depend on the degree l. Each layer attempts to produce these irreps, filtered to what is reachable via the tensor product. Default is "128x0e + 128x0o + 64x1o + 64x1e + 4x2e + 4x2o".
- Type:
str
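To make the irreps string notation concrete, here is a minimal sketch in plain Python that parses such a string and computes the resulting feature width. It does not use the library's own irreps handling; `parse_irreps` and `irreps_dim` are hypothetical helper names, and the sketch only assumes the usual convention that `MULxLp` means `MUL` copies of a degree-`L` irrep with parity `p`, each having `2L + 1` components.

```python
import re


def parse_irreps(irreps: str) -> list[tuple[int, int, str]]:
    """Parse 'MULxLp + ...' into (multiplicity, degree l, parity) tuples."""
    terms = []
    for term in irreps.split("+"):
        match = re.fullmatch(r"\s*(\d+)x(\d+)([eo])\s*", term)
        if match is None:
            raise ValueError(f"Cannot parse irrep term: {term!r}")
        mul, l, parity = match.groups()
        terms.append((int(mul), int(l), parity))
    return terms


def irreps_dim(irreps: str) -> int:
    """Total feature width: each degree-l irrep has 2l + 1 components."""
    return sum(mul * (2 * l + 1) for mul, l, _ in parse_irreps(irreps))


default_irreps = "128x0e + 128x0o + 64x1o + 64x1e + 4x2e + 4x2o"
print(irreps_dim(default_irreps))  # 680
```

With the documented default, the scalar channels contribute 256 components, the vectors 384, and the degree-2 irreps 40.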
- l_max¶
Maximal degree of spherical harmonics used for the angular encoding of edge vectors. Default is 3.
- Type:
int
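As a rough illustration of what `l_max` controls, the sketch below lists the spherical harmonic degrees up to `l_max` and the size of the resulting angular embedding. The helper names are hypothetical, and the parity assignment assumes the standard convention that the degree-`l` spherical harmonics have parity `(-1)^l`.

```python
def spherical_harmonics_irreps(l_max: int) -> list[tuple[int, str]]:
    """Degrees 0..l_max, each with the parity (-1)^l written as 'e' or 'o'."""
    return [(l, "e" if l % 2 == 0 else "o") for l in range(l_max + 1)]


def spherical_embedding_dim(l_max: int) -> int:
    """Sum of 2l + 1 over l = 0..l_max, which equals (l_max + 1) ** 2."""
    return sum(2 * l + 1 for l in range(l_max + 1))


print(spherical_harmonics_irreps(3))  # [(0, 'e'), (1, 'o'), (2, 'e'), (3, 'o')]
print(spherical_embedding_dim(3))     # 16
```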
- num_rbf¶
The number of Bessel basis functions to use (default is 8).
- Type:
int
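For orientation, the following sketch evaluates a Bessel radial basis in the form used in the NequIP paper, sqrt(2 / r_max) * sin(n * pi * r / r_max) / r for n = 1..num_rbf. This is an assumption about the functional form only; the library's implementation may differ in details such as normalization or a distance shift, and `bessel_basis` is a hypothetical name.

```python
import math


def bessel_basis(r: float, r_max: float, num_rbf: int = 8) -> list[float]:
    """Evaluate sqrt(2 / r_max) * sin(n * pi * r / r_max) / r for n = 1..num_rbf."""
    if not 0.0 < r <= r_max:
        raise ValueError("r must lie in (0, r_max]")
    prefactor = math.sqrt(2.0 / r_max)
    return [
        prefactor * math.sin(n * math.pi * r / r_max) / r
        for n in range(1, num_rbf + 1)
    ]


# One distance is expanded into num_rbf radial features.
features = bessel_basis(r=1.5, r_max=5.0, num_rbf=8)
print(len(features))  # 8
```

Every basis function in this form vanishes at `r = r_max`, up to floating-point error.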
- radial_envelope¶
The radial envelope function. The default is "polynomial_envelope"; the only other option is "soft_envelope".
- radial_mlp_hidden¶
Sizes of the radial MLP hidden layers. Default is [64, 64].
- Type:
list[int]
- radial_mlp_activation¶
Activation function for the radial MLP. Default is swish.
- radial_mlp_variance_scale¶
Variance scaling parameter passed to the fan-in normal initializer of the MLP internal layers. See jax.nn.initializers.variance_scaling. Default is 4.0.
- Type:
float
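To show what a variance scale of 4.0 means in practice, the sketch below mimics fan-in variance scaling with an untruncated normal distribution, where the standard deviation is sqrt(scale / fan_in). It uses only the standard library rather than JAX, `fan_in_normal` is a hypothetical name, and whether the library uses a truncated or untruncated normal is not stated here.

```python
import math
import random


def fan_in_normal(fan_in: int, fan_out: int, scale: float = 4.0, seed: int = 0):
    """Sample a (fan_in, fan_out) weight matrix with std = sqrt(scale / fan_in)."""
    rng = random.Random(seed)
    std = math.sqrt(scale / fan_in)
    return [[rng.gauss(0.0, std) for _ in range(fan_out)] for _ in range(fan_in)]


# For a 64-wide hidden layer and scale 4.0 the target std is sqrt(4 / 64) = 0.25.
weights = fan_in_normal(fan_in=64, fan_out=64)
flat = [w for row in weights for w in row]
mean = sum(flat) / len(flat)
empirical_std = math.sqrt(sum((w - mean) ** 2 for w in flat) / len(flat))
```

A scale above 1 widens the initial weight distribution relative to the plain fan-in choice, which has scale 1.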
- num_readout_heads¶
Number of readout heads. The default is 1.
- Type:
int
- add_atomic_energies¶
Whether to add atomic energies to the final energies. Default is True.
- Type:
bool
- avg_num_neighbors¶
The mean number of neighbors per atom, used to rescale messages. If None (default), the value from the dataset info is used.
- Type:
float | None
- num_species¶
The number of elements (atomic species descriptors) allowed. If None (default), the value is inferred from the atomic energies map in the dataset info.
- Type:
int | None
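The fallback behavior of `avg_num_neighbors` and `num_species` can be pictured with the sketch below. `DatasetInfoSketch` is a hypothetical stand-in, not the library's `DatasetInfo`, and its field names are assumptions. The sketch also assumes that "inferring from the atomic energies map" means counting its entries.

```python
from dataclasses import dataclass


@dataclass
class DatasetInfoSketch:
    """Hypothetical stand-in for dataset-derived hyperparameters."""

    avg_num_neighbors: float
    atomic_energies_map: dict[int, float]  # atomic number -> atomic energy


def resolve_defaults(avg_num_neighbors, num_species, info: DatasetInfoSketch):
    """Replace None values by values taken from the dataset info."""
    if avg_num_neighbors is None:
        avg_num_neighbors = info.avg_num_neighbors
    if num_species is None:
        # Assumption: one species per entry in the atomic energies map.
        num_species = len(info.atomic_energies_map)
    return avg_num_neighbors, num_species


info = DatasetInfoSketch(
    avg_num_neighbors=12.5,
    atomic_energies_map={1: -0.5, 6: -37.8, 8: -75.0},  # illustrative numbers only
)
print(resolve_defaults(None, None, info))  # (12.5, 3)
print(resolve_defaults(20.0, 10, info))    # (20.0, 10)
```

Explicit values in the config take precedence, so the dataset info is only consulted for fields left at None.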
- predict_partial_charges¶
Whether the model will be trained to predict partial charges. Default is False.
- Type:
bool
- use_coulomb_term¶
Whether to use the Coulomb term in the model for long range interactions. Default is False.
- Type:
bool
- use_total_charge_embedding¶
Whether to use the total charge embedding. Default is False.
- Type:
bool
- embed_activation¶
Activation function for the embedding block. Default is “silu”.
- use_residual_connection¶
Whether to add the species-linear residual connection in each NequIP layer. Default is True.
- Type:
bool
- gate_nonlinearities¶
Per-parity activations used by the gate nonlinearity in each NequIP layer. Keys are "e" (even) and "o" (odd). Default is {"e": Activation.SWISH, "o": Activation.TANH}.
- Type:
dict[str, mlip.models.options.Activation]
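The role of per-parity activations can be sketched as follows. This is a simplified picture of a gate nonlinearity, not the library's implementation: scalars are activated according to their parity, and each higher-degree feature (here a 3-vector) is multiplied by an activated gate scalar. Treating the gates as extra even scalars, one per vector, is an assumption, and all names are hypothetical.

```python
import math


def swish(x: float) -> float:
    """swish(x) = x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


ACTIVATIONS = {"e": swish, "o": math.tanh}  # mirrors the documented default


def gate(even_scalars, odd_scalars, gates, vectors):
    """Activate scalars by parity and scale each vector by one activated gate.

    `gates` are assumed to be extra even scalars, one per vector.
    """
    if len(gates) != len(vectors):
        raise ValueError("need exactly one gate scalar per vector")
    out_even = [ACTIVATIONS["e"](x) for x in even_scalars]
    out_odd = [ACTIVATIONS["o"](x) for x in odd_scalars]
    out_vectors = [
        [ACTIVATIONS["e"](g) * c for c in vec] for g, vec in zip(gates, vectors)
    ]
    return out_even, out_odd, out_vectors


even, odd, vecs = gate([1.0], [0.5], [2.0], [[1.0, 0.0, -1.0]])
```

Odd scalars change sign under inversion, and an odd function such as tanh preserves that behavior, whereas swish is not odd and is therefore only suitable for even scalars.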
- deterministic_scatter_ops¶
Whether to use deterministic scatter operations in the forward pass, ensuring deterministic energy outputs. Setting to True makes prediction slower. Default is False.
- Type:
bool
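One general reason why scatter operations can be non-deterministic is that floating-point addition is not associative, so the result depends on the order in which messages are accumulated. The sketch below demonstrates this with a hypothetical pure-Python `scatter_sum`. It illustrates the numerical effect only and makes no claim about how the library implements its deterministic option.

```python
def scatter_sum(values, receivers, num_nodes):
    """Sum each value into the slot of its receiver node, in the given order."""
    out = [0.0] * num_nodes
    for value, index in zip(values, receivers):
        out[index] += value
    return out


messages = [0.1, 0.2, 0.3]
receivers = [0, 0, 0]  # all three messages arrive at node 0

forward = scatter_sum(messages, receivers, num_nodes=1)
backward = scatter_sum(messages[::-1], receivers[::-1], num_nodes=1)

# Same messages, different accumulation order, slightly different sums.
print(forward[0] == backward[0])  # False
```

Fixing the accumulation order removes this source of run-to-run variation, typically at the cost of parallelism.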
- class mlip.models.nequip.network.Nequip(config: NequipConfig, dataset_info: DatasetInfo, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)¶
The NequIP model flax module. It is derived from the MLIPNetwork class.
References
Simon Batzner, Albert Musaelian, Lixin Sun, Mario Geiger, Jonathan P. Mailoa, Mordechai Kornbluth, Nicola Molinari, Tess E. Smidt, and Boris Kozinsky. E(3)-equivariant graph neural networks for data-efficient and accurate interatomic potentials. Nature Communications, 13(1), May 2022. ISSN: 2041-1723. URL: https://dx.doi.org/10.1038/s41467-022-29939-5.
- config¶
Hyperparameters / configuration for the NequIP model, see NequipConfig.
- dataset_info¶
Hyperparameters dictated by the dataset (e.g., cutoff radius or average number of neighbors).
- class mlip.models.nequip.layer.NequipLayer(target_irreps: ~e3nn_jax._src.irreps.Irreps, source_node_irreps: ~e3nn_jax._src.irreps.Irreps, l_max: int, num_species: int, num_rbf: int, use_residual_connection: bool, nonlinearities: dict[str, ~mlip.models.options.Activation], radial_mlp_activation: ~mlip.models.options.Activation | ~typing.Literal['beta_swish'], radial_mlp_hidden: list[int], radial_mlp_variance_scale: float, avg_num_neighbors: float, deterministic_scatter_ops: bool = False, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)¶
NequIP Layer, consisting of a convolution block and a gate nonlinearity.
Adapted from Google DeepMind materials discovery: https://github.com/google-deepmind/materials_discovery/blob/main/model/nequip.py
Implementation follows the original paper by Batzner et al. (2022): nature.com/articles/s41467-022-29939-5 and partially https://github.com/mir-group/nequip.
- Parameters:
target_irreps – Target irreps for the output node features. Acts as an upper bound — only paths reachable via the tensor product are included, so earlier layers may not fully achieve these irreps.
source_node_irreps – Expected irreps of the input node features. Used to validate the graph at call time.
l_max – Maximum degree of spherical harmonics. Used to validate the edge spherical embedding shape.
num_rbf – Number of radial basis functions. Used to validate the radial embedding shape in the convolution block.
use_residual_connection – Whether to use the residual connection in the network (recommended).
nonlinearities – Nonlinearities to use for even/odd irreps.
radial_mlp_activation – Activation for the radial MLP.
radial_mlp_hidden – Dimensions of hidden layers of the radial MLP.
radial_mlp_variance_scale – Variance scaling for all-but-last layers of the radial MLP.
avg_num_neighbors – Constant number of per-atom neighbors, used for internal normalization.
deterministic_scatter_ops – Whether to use deterministic scatter operations.
- Returns:
Graph containing updated node features after the convolution and gating.
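The statement that `target_irreps` acts as an upper bound can be made concrete with the tensor product selection rules: combining degrees l1 and l2 yields degrees |l1 - l2| through l1 + l2, and parities multiply. The sketch below applies these rules layer by layer. It assumes the first layer sees purely even scalar node features and that edge harmonics of degree l have parity (-1)^l. All function names are hypothetical.

```python
def tensor_product_outputs(l1, p1, l2, p2):
    """Irreps (l, parity) allowed by the selection rules for two inputs."""
    parity = p1 * p2
    return {(l, parity) for l in range(abs(l1 - l2), l1 + l2 + 1)}


def reachable(input_irreps, l_max, target_irreps):
    """Target irreps that some (input x spherical harmonic) path can produce."""
    harmonics = [(l, (-1) ** l) for l in range(l_max + 1)]
    produced = set()
    for l1, p1 in input_irreps:
        for l2, p2 in harmonics:
            produced |= tensor_product_outputs(l1, p1, l2, p2)
    return [ir for ir in target_irreps if ir in produced]


# Default target irreps as (degree, parity), with +1 for "e" and -1 for "o".
target = [(0, 1), (0, -1), (1, -1), (1, 1), (2, 1), (2, -1)]

layer1 = reachable([(0, 1)], l_max=3, target_irreps=target)
layer2 = reachable(layer1, l_max=3, target_irreps=target)
layer3 = reachable(layer2, l_max=3, target_irreps=target)
print(layer1)  # [(0, 1), (1, -1), (2, 1)]
```

Under these assumptions the first layer produces only 0e, 1o and 2e, the second layer adds 1e and 2o, and the odd scalars 0o first appear in a third layer.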
- __call__(graph: Graph) Graph¶
Apply the Nequip layer to update the node features of the input graph.
- Parameters:
graph – Graph containing precomputed node and edge embeddings, and current node features graph.nodes.features["latent"], or, if this is the first layer, graph.nodes.features["embedding"].
- Returns:
Graph with updated node features graph.nodes.features["latent"].
- class mlip.models.nequip.blocks.NequipEmbeddingBlock(num_species: int, num_charges: int | None, target_irreps: str, l_max: int, num_rbf: int, r_max: float, avg_r_min: float | None, radial_envelope: RadialEnvelope, activation_fn: Callable | None, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)¶
Initial embedding block for the NequIP model.
Encodes atomic species as one-hot vectors and projects them to the hidden node feature space. Encodes edge geometry into radial basis functions and spherical harmonics, which together form the edge embedding used in the interaction layers.
- num_species¶
Number of distinct atomic species in the dataset.
- Type:
int
- target_irreps¶
Target irreps for the initial node feature projection.
- Type:
str
- l_max¶
Maximum degree of spherical harmonics for edge angular encoding.
- Type:
int
- num_rbf¶
Number of Bessel radial basis functions.
- Type:
int
- r_max¶
Cutoff distance in Angstroms.
- Type:
float
- avg_r_min¶
Average minimum interatomic distance, used to shift the radial basis. If None, no shift is applied.
- Type:
float | None
- radial_envelope¶
Envelope function to smoothly decay the basis at r_max (e.g. polynomial or soft envelope).
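As an illustration of what such an envelope does, the sketch below implements the polynomial envelope from the DimeNet paper, which is also used in the NequIP paper. Whether the library uses exactly this form, and which exponent `p` it uses, is an assumption; `polynomial_envelope` and the default `p = 5` are chosen here for illustration.

```python
def polynomial_envelope(r: float, r_max: float, p: int = 5) -> float:
    """Smooth cutoff that equals 1 at r = 0 and 0 for r >= r_max."""
    d = r / r_max
    if d >= 1.0:
        return 0.0
    return (
        1.0
        - (p + 1) * (p + 2) / 2.0 * d**p
        + p * (p + 2) * d ** (p + 1)
        - p * (p + 1) / 2.0 * d ** (p + 2)
    )


for r in (0.0, 2.5, 4.9, 5.0):
    print(r, polynomial_envelope(r, r_max=5.0))
```

Multiplying each radial basis function by this factor makes the edge features go to zero smoothly as a neighbor approaches the cutoff, instead of dropping abruptly.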
- __call__(graph: Graph) Graph¶
Apply the NequIP embedding block to compute initial node and edge embeddings.
- Parameters:
graph – Graph containing node features with “species” and edge vectors.
- Returns:
Updated graph containing node features (“species_one_hot”, “embedding”), and edge features (“radial_embedding”, “spherical_embedding”).
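The species encoding step can be sketched in a few lines. The example assumes species are given as integer indices in `range(num_species)` and uses a bias-free linear projection; both are assumptions, and `one_hot` and `project` are hypothetical names.

```python
def one_hot(species: list[int], num_species: int) -> list[list[float]]:
    """Encode each species index as a one-hot row of length num_species."""
    rows = []
    for s in species:
        if not 0 <= s < num_species:
            raise ValueError(f"species index {s} out of range")
        rows.append([1.0 if i == s else 0.0 for i in range(num_species)])
    return rows


def project(rows, weights):
    """Bias-free linear projection; weights has shape (num_species, width)."""
    return [
        [sum(x * w for x, w in zip(row, column)) for column in zip(*weights)]
        for row in rows
    ]


weights = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # 3 species, width 2
embedding = project(one_hot([2, 0], num_species=3), weights)
print(embedding)  # [[0.5, 0.6], [0.1, 0.2]]
```

For one-hot input the projection simply selects one row of the weight matrix per atom, so it acts as a learned per-species embedding.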
- class mlip.models.nequip.blocks.NequipMultiHeadReadoutBlock(source_node_irreps: Irreps, num_heads: int, predict_partial_charges: bool, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)¶
Readout block that maps final node features to per-atom energies.
Performs input shape assertions, then prepares the correct feature sizes for the linear layers and calls the ReadoutBlock with them.
Extracts the scalar (0e) channels from the final node feature irreps and applies two successive linear projections: first halving the scalar multiplicity, then reducing to a single scalar per atom. The result is stored as “outputs” in the node features.
- source_node_irreps¶
Expected irreps of the input node features, used to validate the graph at call time.
- Type:
e3nn_jax._src.irreps.Irreps
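The two-step projection described for the readout can be sketched for a single atom and a single head as follows. The weights are random placeholders, the layers are assumed to be bias-free, and all names are hypothetical; the sketch only shows the shape flow from the scalar multiplicity to half of it and then to one value.

```python
import random


def linear(features, weights):
    """Apply a bias-free linear map; weights has shape (len(features), out_dim)."""
    return [sum(x * w for x, w in zip(features, column)) for column in zip(*weights)]


def random_weights(n_in, n_out, rng):
    """Placeholder weights standing in for learned parameters."""
    return [[rng.gauss(0.0, 1.0 / n_in**0.5) for _ in range(n_out)] for _ in range(n_in)]


def readout(scalars, w_half, w_out):
    """Two successive linear projections: mul -> mul // 2 -> 1."""
    hidden = linear(scalars, w_half)
    return linear(hidden, w_out)[0]


rng = random.Random(0)
mul = 128  # multiplicity of 0e channels in the documented default target_irreps
w_half = random_weights(mul, mul // 2, rng)
w_out = random_weights(mul // 2, 1, rng)
atom_scalars = [rng.uniform(-1.0, 1.0) for _ in range(mul)]
energy = readout(atom_scalars, w_half, w_out)
```

Only the 0e channels enter the readout, so the predicted per-atom value is invariant under rotations and inversion of the input structure.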