ESEN

class mlip.models.esen.network.Esen(config: EsenConfig, dataset_info: DatasetInfo, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

The Esen model flax module. It is derived from the MLIPNetwork class.

References

  • Saro Passaro, C. Lawrence Zitnick. Reducing SO(3) Convolutions to SO(2) for Efficient Equivariant GNNs. URL: https://arxiv.org/pdf/2302.03655

  • Xiang Fu, Brandon M. Wood, Luis Barroso-Luque, Daniel S. Levine, Meng Gao, Misko Dzamba, C. Lawrence Zitnick. Learning Smooth and Expressive Interatomic Potentials for Physical Property Prediction. URL: https://arxiv.org/abs/2502.12147

  • Brandon M. Wood et al. UMA: A Family of Universal Models for Atoms. URL: https://arxiv.org/pdf/2506.23971

config

Hyperparameters / configuration for the Esen model, see EsenConfig.

Type:

mlip.models.esen.config.EsenConfig

dataset_info

Hyperparameters dictated by the dataset (e.g., cutoff radius or average number of neighbors).

Type:

mlip.data.dataset_info.DatasetInfo

available_properties

The properties available from the model; see Properties.

setup() → None

Initializes the model layers.

__call__(graph: Graph) → Graph

Runs the Esen model forward pass on an input Graph and returns an updated Graph object with node-wise contributions.

Features in output graph:

  • node-wise energy: graph.nodes.features["energy"]

Parameters:

graph – Input Graph containing atomic positions and topology.

Returns:

Updated Graph with node-wise features.
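The forward pass returns per-node energy contributions; a total energy per structure is their sum. A minimal sketch of that reduction, assuming a jraph-style batch in which nodes are stored contiguously and `n_node` lists the atom count of each structure. The helper name `sum_node_energies` and the `n_node` convention are hypothetical, not part of the documented API.

```python
from typing import Sequence


def sum_node_energies(node_energies: Sequence[float], n_node: Sequence[int]) -> list[float]:
    """Sum per-node energy contributions into one total energy per structure.

    Assumption (hypothetical helper): nodes are stored contiguously, structure
    by structure, and n_node[k] is the atom count of the k-th structure.
    """
    if sum(n_node) != len(node_energies):
        raise ValueError("n_node must sum to the number of nodes")
    totals = []
    start = 0
    for count in n_node:
        totals.append(sum(node_energies[start:start + count]))
        start += count
    return totals


# A batch holding two structures with 3 and 2 atoms.
print(sum_node_energies([-1.0, -2.0, -0.5, -3.0, -1.5], [3, 2]))  # [-3.5, -4.5]
```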

class mlip.models.esen.config.EsenConfig(*, add_atomic_energies: bool = True, num_species: int | None = None, num_layers: int = 2, sphere_channels: int = 16, hidden_channels: int = 16, edge_channels: int = 16, l_max: int = 2, m_max: int = 2, radial_envelope: RadialEnvelope = RadialEnvelope.POLYNOMIAL, radial_basis: str | RadialBasis = RadialBasis.GAUSS, num_rbf: int = 16, basis_width_scalar: float = 2.0, cosine_cutoff: bool = False, norm_type: str = 'rms_norm_sh', act_type: str = 'gate', trainable_rbf: bool = False, num_readout_heads: Annotated[int, Gt(gt=0)] = 1, moe: EsenMoEConfig | None = None, predict_partial_charges: bool = False, use_coulomb_term: bool = False, use_total_charge_embedding: bool = False, embed_activation: Activation = Activation.SILU, deterministic_scatter_ops: bool = False)

The configuration / hyperparameters of the eSEN model.

num_species

The number of elements (atomic species descriptors) allowed. If None (default), the value is inferred from the atomic energies map in the dataset info.

Type:

int | None

num_layers

Number of eSEN layers. Default is 2.

Type:

int

sphere_channels

The number of channels for the node embedding. Default is 16.

Type:

int

hidden_channels

The number of output channels for the convolution layers and MLPs in the network. Default is 16.

Type:

int

edge_channels

The number of channels for the edge embedding. Default is 16.

Type:

int

l_max

Highest degree of the spherical harmonics used for the directional encoding of edge vectors and in the convolution block. Default is 2.

Type:

int

m_max

Maximum order m kept in the convolution layers; features with |m| greater than this value are removed. Default is 2.

Type:

int
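Together, l_max and m_max fix how many spherical-harmonic coefficients each channel carries in the convolution. A minimal sketch that counts the real coefficients (l, m) with l <= l_max and |m| <= m_max; this follows from the definitions above, but that mlip stores exactly this set is an assumption, and `num_coefficients` is a hypothetical helper.

```python
def num_coefficients(l_max: int, m_max: int) -> int:
    """Number of real SH coefficients (l, m) with l <= l_max and |m| <= m_max.

    Assumption: this is the coefficient set kept after m truncation.
    """
    if l_max < 0 or m_max < 0:
        raise ValueError("l_max and m_max must be non-negative")
    # Degree l has 2l + 1 orders; truncation caps that at 2 * m_max + 1.
    return sum(2 * min(l, m_max) + 1 for l in range(l_max + 1))


print(num_coefficients(2, 2))  # 9 == (2 + 1) ** 2, the default configuration
print(num_coefficients(6, 2))  # 29, versus 49 without truncation
```

With the defaults (l_max = m_max = 2) nothing is truncated; a larger l_max combined with a small m_max shrinks the representation considerably.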

add_atomic_energies

Whether to add atomic energies to the final energies. Default is True.

Type:

bool
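With add_atomic_energies enabled, the learned output can be read as a correction on top of per-species reference energies (an atomic energies map in the dataset info is mentioned under num_species). A minimal sketch, assuming the shift is one reference energy per atom, summed over the structure, with the map keyed by atomic number. `total_energy` and the reference values are hypothetical and purely illustrative.

```python
from typing import Mapping, Sequence


def total_energy(
    node_energies: Sequence[float],
    species: Sequence[int],
    atomic_energies: Mapping[int, float],
    add_atomic_energies: bool = True,
) -> float:
    """Total energy of one structure: learned contributions plus optional references."""
    energy = sum(node_energies)
    if add_atomic_energies:
        # Assumption: one reference energy per atom, looked up by its species key.
        energy += sum(atomic_energies[z] for z in species)
    return energy


# Illustrative numbers only: a water-like structure keyed by atomic number (O=8, H=1).
refs = {1: -0.5, 8: -75.0}
print(total_energy([-0.25, -0.125, -0.125], [8, 1, 1], refs))         # -76.5
print(total_energy([-0.25, -0.125, -0.125], [8, 1, 1], refs, False))  # -0.5
```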

radial_envelope

The radial envelope function. Defaults to “polynomial_envelope”; the only other option is “soft_envelope”.

Type:

mlip.models.options.RadialEnvelope
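The exact polynomial behind “polynomial_envelope” is not given here. A minimal sketch of a widely used smooth polynomial envelope with exponent p, assumed here to be p = 5. It equals 1 at r = 0 and, together with its first and second derivatives, goes to zero at the cutoff, so edges entering or leaving the neighbourhood do not create discontinuities.

```python
def polynomial_envelope(r: float, cutoff: float, p: int = 5) -> float:
    """Smooth envelope: 1 at r = 0; zero value, slope and curvature at r = cutoff.

    Assumption: this standard form and the default exponent p = 5.
    """
    d = r / cutoff
    if d >= 1.0:
        return 0.0
    # Coefficients chosen so that u(1) = u'(1) = u''(1) = 0.
    a = -(p + 1) * (p + 2) / 2
    b = p * (p + 2)
    c = -p * (p + 1) / 2
    return 1.0 + a * d**p + b * d ** (p + 1) + c * d ** (p + 2)


print(polynomial_envelope(0.0, 5.0))  # 1.0
print(polynomial_envelope(2.5, 5.0))  # 0.7734375
print(polynomial_envelope(6.0, 5.0))  # 0.0
```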

radial_basis

Type of radial basis function used. Three options are available: “bessel”, “gauss”, and “expnorm”. Default is “gauss”.

Type:

str | mlip.models.options.RadialBasis

num_rbf

Number of radial basis functions used in the edge embedding. Default is 16 (UMA small uses 512).

Type:

int
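radial_basis, num_rbf and basis_width_scalar (the last appears in the EsenConfig signature with default 2.0 but is not documented above) together define how an edge length is expanded into features. A minimal sketch of a Gaussian basis, assuming evenly spaced centres on [0, cutoff] and a width equal to basis_width_scalar times the centre spacing, as in the fairchem GaussianSmearing layer; whether mlip's “gauss” option matches this exactly is an assumption.

```python
import math


def gaussian_rbf(
    r: float, cutoff: float, num_rbf: int = 16, basis_width_scalar: float = 2.0
) -> list[float]:
    """Expand a distance r into num_rbf Gaussians centred evenly on [0, cutoff].

    Assumption: centre spacing and width convention follow GaussianSmearing.
    """
    if num_rbf < 2:
        raise ValueError("num_rbf must be at least 2")
    spacing = cutoff / (num_rbf - 1)
    centres = [k * spacing for k in range(num_rbf)]
    width = basis_width_scalar * spacing  # wider Gaussians overlap more
    coeff = -0.5 / width**2
    return [math.exp(coeff * (r - mu) ** 2) for mu in centres]


features = gaussian_rbf(2.0, cutoff=5.0)
print(len(features))                                        # 16
print(max(range(len(features)), key=features.__getitem__))  # 6, the centre nearest 2.0
```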

cosine_cutoff

Whether to use the cosine cutoff envelope function in the radial embedding block. Defaults to False.

Type:

bool
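The standard cosine cutoff is f(r) = 0.5 * (cos(pi * r / r_c) + 1) for r < r_c and 0 otherwise. A minimal sketch, assuming mlip uses this standard form and multiplies it into the radial features of each edge when cosine_cutoff is True; `cosine_cutoff` here is a hypothetical stand-alone function.

```python
import math


def cosine_cutoff(r: float, cutoff: float) -> float:
    """Decays smoothly from 1 at r = 0 to 0 at r = cutoff; zero beyond.

    Assumption: the standard cosine cutoff form.
    """
    if r >= cutoff:
        return 0.0
    return 0.5 * (math.cos(math.pi * r / cutoff) + 1.0)


print(cosine_cutoff(0.0, 5.0))  # 1.0
print(cosine_cutoff(6.0, 5.0))  # 0.0
```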

norm_type

Specifies the type of normalisation used. Three options are available: “layer_norm”, “layer_norm_sh”, and “rms_norm_sh”. Default is “rms_norm_sh”.

Type:

str
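The “_sh” variants normalise spherical-harmonic features without breaking equivariance. A minimal sketch of the idea behind an RMS norm over SH coefficients, assuming the statistic is the root mean square over all coefficients and channels of one node. A rotation only mixes coefficients within each degree and preserves their norm, so this single scale factor is rotation invariant and the output stays equivariant; a naive layer norm that subtracted a common mean from every coefficient would lose this property. Learnable scales and any per-degree weighting in the real layer are omitted.

```python
import math


def rms_norm_sh(x: list[list[float]], eps: float = 1e-6) -> list[list[float]]:
    """Scale one node's SH features x[coefficient][channel] by a single RMS factor.

    Assumption: simplified form with no learnable weights or degree balancing.
    """
    n = sum(len(row) for row in x)
    mean_sq = sum(v * v for row in x for v in row) / n
    scale = 1.0 / math.sqrt(mean_sq + eps)  # rotation-invariant scalar
    return [[v * scale for v in row] for row in x]


out = rms_norm_sh([[3.0, 4.0], [0.0, 0.0]], eps=0.0)
rms = math.sqrt(sum(v * v for row in out for v in row) / 4)
print(round(rms, 6))  # 1.0
```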

act_type

Activation type for the Edgewise convolution block. Only one option is available, “gate”, which is the default.

Type:

str
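A gate activation keeps features equivariant by treating degrees differently. Scalar (l = 0) channels pass through a pointwise nonlinearity, while each channel of degree l > 0 is multiplied by a sigmoid gate shared across its 2l + 1 components, so only its magnitude changes. A minimal sketch; the SiLU on scalars, one gate per (degree, channel), and gate logits supplied as a separate input (in practice they would be produced by the network from scalar features) are all assumptions.

```python
import math
from typing import Sequence


def silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def gate_activation(
    x: Sequence[Sequence[float]],
    degrees: Sequence[int],
    gate_logits: Sequence[Sequence[float]],
) -> list[list[float]]:
    """x[i][c]: coefficient i, channel c; degrees[i]: degree l of coefficient i.

    Assumption: gate_logits[l - 1][c] is one gate logit per degree l >= 1 and channel c.
    """
    out = []
    for row, l in zip(x, degrees):
        if l == 0:
            out.append([silu(v) for v in row])  # scalars: pointwise SiLU (assumed choice)
        else:
            gates = gate_logits[l - 1]
            # Every m component of (l, c) gets the same invariant factor.
            out.append([v * sigmoid(g) for v, g in zip(row, gates)])
    return out


# l_max = 1, one channel: one scalar row and three l = 1 rows; a zero logit gives a 0.5 gate.
y = gate_activation([[1.0], [2.0], [3.0], [4.0]], [0, 1, 1, 1], [[0.0]])
print(y[1:])  # [[1.0], [1.5], [2.0]]
```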

num_readout_heads

Number of readout heads. The default is 1.

Type:

int

moe

Optional MoE configuration. When None (default), eSEN behaves as a standard model with no expert routing.

Type:

mlip.models.esen.config.EsenMoEConfig | None
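The UMA paper cited above describes a mixture of linear experts, in which the expert weight matrices are combined into a single matrix per system, so a forward pass costs one ordinary linear layer. A minimal sketch of such a merge, assuming softmax mixing coefficients computed from per-system logits. How EsenMoEConfig sets the number of experts and computes the coefficients is not documented here, and `merge_experts` is a hypothetical name.

```python
import math
from typing import Sequence

Matrix = list[list[float]]


def merge_experts(expert_weights: Sequence[Matrix], logits: Sequence[float]) -> Matrix:
    """Combine expert weight matrices into one matrix using softmax coefficients.

    Assumption: softmax over per-system logits gives the mixing coefficients.
    """
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]  # subtract the max for numerical stability
    total = sum(exps)
    coeffs = [e / total for e in exps]
    rows, cols = len(expert_weights[0]), len(expert_weights[0][0])
    return [
        [sum(c * w[i][j] for c, w in zip(coeffs, expert_weights)) for j in range(cols)]
        for i in range(rows)
    ]


w_a = [[1.0, 0.0], [0.0, 1.0]]
w_b = [[3.0, 2.0], [0.0, 1.0]]
print(merge_experts([w_a, w_b], [0.0, 0.0]))  # [[2.0, 1.0], [0.0, 1.0]]
```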

predict_partial_charges

Whether the model is trained to predict partial charges. Default is False.

Type:

bool

use_coulomb_term

Whether to use the Coulomb term in the model for long range interactions. Default is False.

Type:

bool
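When partial charges are predicted, they can presumably enter an explicit long-range electrostatic energy. A minimal sketch of the bare pairwise Coulomb sum, E = k * sum_{i<j} q_i q_j / r_ij with k = e^2 / (4 pi eps0) ≈ 14.399645 eV·Å. Whether mlip's term adds short-range damping, a cutoff, or periodic (Ewald-style) summation is not stated here, so treat this as the textbook form only; `coulomb_energy` is a hypothetical helper.

```python
import math
from typing import Sequence

# e^2 / (4 * pi * eps0) expressed in eV * Angstrom.
COULOMB_CONSTANT = 14.399645


def coulomb_energy(
    charges: Sequence[float], positions: Sequence[Sequence[float]]
) -> float:
    """Bare pairwise Coulomb energy in eV for charges in e and positions in Angstrom.

    Assumption: no damping, cutoff, or periodic images.
    """
    energy = 0.0
    for i in range(len(charges)):
        for j in range(i + 1, len(charges)):
            r = math.dist(positions[i], positions[j])
            energy += charges[i] * charges[j] / r
    return COULOMB_CONSTANT * energy


# Opposite unit charges 1 Angstrom apart.
print(coulomb_energy([1.0, -1.0], [(0.0, 0.0, 0.0), (0.0, 0.0, 1.0)]))  # -14.399645
```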

use_total_charge_embedding

Whether to use the total charge embedding. Default is False.

Type:

bool

embed_activation

Activation function for the embedding block. Default is “silu”.

Type:

mlip.models.options.Activation

deterministic_scatter_ops

Whether to use deterministic scatter operations in the forward pass, which makes energy outputs reproducible from run to run. Setting this to True makes prediction slower. Default is False.

Type:

bool
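Accumulation order matters because, on parallel hardware, scatter-adds may accumulate contributions in an order that varies between runs, and floating-point addition is not associative. A minimal pure-Python illustration; it does not show how mlip implements its deterministic variant, which is not documented here, and `scatter_add` is a hypothetical helper.

```python
import random
from typing import Iterable, Optional, Sequence


def scatter_add(
    values: Sequence[float],
    segment_ids: Sequence[int],
    num_segments: int,
    order: Optional[Iterable[int]] = None,
) -> list[float]:
    """Sum values into segments, accumulating in the given order (default: input order)."""
    out = [0.0] * num_segments
    for k in (order if order is not None else range(len(values))):
        out[segment_ids[k]] += values[k]
    return out


# Floating-point addition is not associative, so accumulation order changes the bits.
print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))  # False

values, ids = [0.1, 0.2, 0.3], [0, 0, 0]
fixed = scatter_add(values, ids, 1)      # same order every call -> same result
perm = random.sample(range(3), k=3)      # stand-in for an unspecified parallel order
print(scatter_add(values, ids, 1, order=perm) == fixed)  # may print True or False
```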

class mlip.models.esen.blocks.EsenEmbeddingBlock(graph_cutoff_angstrom: float, l_max: int, num_species: int, num_charges: int | None, sphere_channels: int, radial_envelope: ~mlip.models.options.RadialEnvelope | str, radial_basis: ~mlip.models.options.RadialBasis | str, num_rbf: int, basis_width_scalar: float, cosine_cutoff: bool, trainable_rbf: bool, edge_channels: int, m_max: int, mapping_reduced: ~mlip.models.esen.coefficient_mapping.CoefficientMapping, edge_channels_list: list[int], activation_fn: ~typing.Callable = <PjitFunction of <function silu>>, deterministic_scatter_ops: bool = False, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Embeds input node and edge features for the Esen model.

Initializes and applies the embedding layers for node species, radial functions, and spherical harmonics. Updates the input Graph with the embedded features and pre-processes edges and neighbors for subsequent network layers.

graph_cutoff_angstrom

The cutoff radius for the graph.

Type:

float

l_max

Highest degree of the spherical harmonics included in the series.

Type:

int

num_species

The number of elements (atomic species descriptors) allowed.

Type:

int

sphere_channels

The number of channels for the node embedding.

Type:

int

radial_envelope

The radial envelope function.

Type:

mlip.models.options.RadialEnvelope | str

radial_basis

The type of radial basis function used.

Type:

mlip.models.options.RadialBasis | str

num_rbf

Number of radial basis functions used in the embedding block.

Type:

int

trainable_rbf

Whether to add learnable weights to each of the radial embedding basis functions.

Type:

bool

edge_channels

The number of channels for the edge embedding.

Type:

int

m_max

The maximum order of the spherical harmonics to include in the embedding.

Type:

int

mapping_reduced

The mapping of the spherical harmonics to the reduced set of coefficients.

Type:

mlip.models.esen.coefficient_mapping.CoefficientMapping
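In the edge-aligned frame, the SO(2) convolution of Passaro and Zitnick (cited above) only mixes coefficients that share the same |m|, which is why a reduced, m-grouped ordering is convenient. A minimal sketch of such an ordering for given l_max and m_max; the exact layout used by CoefficientMapping, including the relative order of the +m and -m blocks, is an assumption, and `reduced_coefficient_order` is a hypothetical name.

```python
def reduced_coefficient_order(l_max: int, m_max: int) -> list[tuple[int, int]]:
    """(l, m) pairs grouped by |m|: all m = 0 first, then a +m and a -m block per m.

    Assumption: this grouping stands in for CoefficientMapping's real layout.
    """
    order = [(l, 0) for l in range(l_max + 1)]
    for m in range(1, min(m_max, l_max) + 1):
        order += [(l, m) for l in range(m, l_max + 1)]   # +m block, only l >= m exists
        order += [(l, -m) for l in range(m, l_max + 1)]  # -m block
    return order


print(reduced_coefficient_order(2, 1))
# [(0, 0), (1, 0), (2, 0), (1, 1), (2, 1), (1, -1), (2, -1)]
```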

edge_channels_list

The list of channels for the edge embedding.

Type:

list[int]

setup() → None

Initializes the embedding layers for node species, radial functions, and spherical harmonics.

__call__(graph: Graph) → Graph

Applies the Esen embedding block to compute initial node and edge embeddings.

Embedded features in the final graph can be accessed as:

  • node features: graph.nodes.features["embedding"]

  • edge features: graph.edges.features["embedding"]

  • edge envelope: graph.edges.features["envelope"]

  • Wigner and m mapping: graph.nodes.features["wigner_and_m_mapping"]

  • inverse Wigner and m mapping: graph.nodes.features["wigner_and_m_mapping_inv"]

Parameters:

graph – Graph containing node features with “species” and edge vectors.

Returns:

Updated Graph with embedded node and edge features ready for Esen processing.
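A rotation-invariant input such as the species can only populate the l = 0 coefficient of equivariant node features, since higher degrees have no rotation-invariant component. A minimal sketch of that initial node embedding, assuming the learned species vector fills the l = 0 row and all higher-degree rows start at zero. The edge-dependent terms this block also computes are omitted, and `initial_node_embedding` and the toy table are hypothetical.

```python
from typing import Sequence


def initial_node_embedding(
    species_index: int, embedding_table: Sequence[Sequence[float]], num_coefficients: int
) -> list[list[float]]:
    """Node features [coefficient][channel]: species embedding in the l = 0 row.

    Assumption: all higher-degree rows are initialised to zero.
    """
    scalar = [float(v) for v in embedding_table[species_index]]
    channels = len(scalar)
    return [scalar] + [[0.0] * channels for _ in range(num_coefficients - 1)]


table = [[0.1, 0.2], [0.3, 0.4]]  # toy table: 2 species, sphere_channels = 2
emb = initial_node_embedding(1, table, num_coefficients=9)  # l_max = 2 -> 9 coefficients
print(emb[0], emb[1])  # [0.3, 0.4] [0.0, 0.0]
```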

class mlip.models.esen.layer.ESENLayer(sphere_channels: int, hidden_channels: int, l_max: int, m_max: int, mapping_reduced: mlip.models.esen.coefficient_mapping.CoefficientMapping, edge_channels_list: Sequence[int], graph_cutoff_angstrom: float, norm_type: Literal['layer_norm', 'layer_norm_sh', 'rms_norm_sh'], act_type: Literal['gate'], num_experts: int | None = None, deterministic_scatter_ops: bool = False, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

setup() → None

Initializes the Esen layer.

__call__(graph: Graph) → Graph

Applies the Esen layer to update node features using neighboring nodes’ species and edge features.

Updated features in this function:

  • node-wise features: graph.nodes.features["latent"]

  • edge-wise features: graph.edges.features["latent"]

Parameters:

graph – Graph containing node features under “latent”, or under “embedding” if this is the first layer; edge features follow the same convention.

Returns:

Updated Graph with updated node and edge features.
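ESENLayer takes the parameters of both the Edgewise and SpectralAtomwise blocks plus a norm_type, which suggests it chains a message-passing update and a per-atom update with normalisation. A minimal sketch of one plausible arrangement, assuming a pre-norm residual layout (normalise, apply the block, add the input back) and using plain floats in place of SH feature arrays. The callables stand in for the real sub-modules, and `esen_layer` is a hypothetical name.

```python
from typing import Callable

Fn = Callable[[float], float]


def esen_layer(x: float, norm_1: Fn, edgewise: Fn, norm_2: Fn, atomwise: Fn) -> float:
    """Two pre-norm residual sub-blocks: message passing, then a per-atom update.

    Assumption: pre-norm residual layout; floats stand in for SH feature arrays.
    """
    x = x + edgewise(norm_1(x))  # Edgewise branch with residual connection
    x = x + atomwise(norm_2(x))  # SpectralAtomwise branch with residual connection
    return x


def identity(v: float) -> float:
    return v


print(esen_layer(1.0, identity, lambda v: 2 * v, identity, lambda v: v + 1))  # 7.0
```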

class mlip.models.esen.layer.Edgewise(sphere_channels: int, hidden_channels: int, l_max: int, m_max: int, edge_channels_list: ~typing.Sequence[int], mapping_reduced: object, graph_cutoff_angstrom: float, act_type: ~typing.Literal['gate'] = 'gate', num_experts: int | None = None, deterministic_scatter_ops: bool = False, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Applies the edge-wise convolution to update node features using neighboring nodes’ species and edge features.

sphere_channels

The number of channels for the node embedding.

Type:

int

hidden_channels

The number of channels for the hidden layer.

Type:

int

l_max

Highest degree of the spherical harmonics included in the series.

Type:

int

m_max

Maximum order of the spherical harmonics to include in the embedding.

Type:

int

edge_channels_list

The list of channels for the edge embedding.

Type:

Sequence[int]

mapping_reduced

The mapping of the spherical harmonics to the reduced set of coefficients.

Type:

object

graph_cutoff_angstrom

The cutoff radius for the graph.

Type:

float

act_type

The type of activation function to apply.

Type:

Literal['gate']

setup() → None

Initializes the edge-wise convolution layers.

__call__(graph: Graph) → Graph

Applies the edge-wise convolution to update node features using neighboring nodes’ species and edge features.

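Steps:
  1. Gather the source and target node features and concatenate them along the channel axis.

  2. Rotate the features into a frame aligned with the edge.

  3. Apply the first SO(2) convolution, followed by the activation (gate / s2).

  4. Apply the second SO(2) convolution.

  5. Apply the envelope.

  6. Rotate the features back.

  7. Scatter-add the messages to the destination nodes.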

Updated features in this function:

  • node-wise features: graph.nodes.features["latent"]

Parameters:

graph – Graph containing node features with “latent” and edge features with “latent”.

Returns:

Updated Graph with updated node features.
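The seven Edgewise steps can be traced on plain Python lists. This skeleton is an illustration under assumptions, not mlip's implementation: the rotations, SO(2) convolutions and gate are collapsed into one placeholder `message_fn` that maps the concatenated edge input back to the node channel count, messages flow from sender to receiver, and the envelope is one scalar per edge. Applying that scalar before rather than after the rotation back is equivalent, since the rotation is linear. `edgewise_update` is a hypothetical name.

```python
from typing import Callable, Sequence


def edgewise_update(
    node_feats: Sequence[list[float]],
    senders: Sequence[int],
    receivers: Sequence[int],
    envelopes: Sequence[float],
    message_fn: Callable[[list[float]], list[float]],
) -> list[list[float]]:
    """Skeleton of the Edgewise steps; returns the aggregated message per node.

    Assumption: message_fn replaces the rotations, SO(2) convolutions and gate.
    """
    channels = len(node_feats[0])
    aggregated = [[0.0] * channels for _ in node_feats]
    for s, r, env in zip(senders, receivers, envelopes):
        # 1. Gather source and target features and concatenate along channels.
        edge_input = node_feats[s] + node_feats[r]
        # 2-4 and 6. Placeholder for rotate, SO(2) conv, gate, SO(2) conv, rotate back.
        message = message_fn(edge_input)
        # 5. Scale by the envelope so contributions vanish smoothly at the cutoff.
        message = [env * v for v in message]
        # 7. Scatter-add into the destination (receiver) node.
        for c in range(channels):
            aggregated[r][c] += message[c]
    return aggregated


feats = [[1.0], [2.0]]
print(edgewise_update(feats, [0, 1], [1, 0], [1.0, 0.5], lambda v: [v[0] + v[1]]))
# [[1.5], [3.0]]
```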

class mlip.models.esen.layer.SpectralAtomwise(sphere_channels: int, hidden_channels: int, l_max: int, m_max: int, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Applies the spectral atom-wise convolution to update node features.

sphere_channels

The number of channels for the node embedding.

Type:

int

hidden_channels

The number of channels for the hidden layer.

Type:

int

l_max

Highest degree of the spherical harmonics included in the series.

Type:

int

m_max

Maximum order of the spherical harmonics to include in the embedding.

Type:

int

setup() → None

Initializes the spectral atom-wise layers.

__call__(graph: Graph) → Graph

Applies the spectral atom-wise convolution to update node features.

Updated features in this function:

  • node-wise features: graph.nodes.features["latent"]

Parameters:

graph – Graph containing node features with “latent”.

Returns:

Updated Graph with updated node features.
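SpectralAtomwise acts on each node independently, so to stay equivariant it may mix channels but not the 2l + 1 components within a degree. A minimal sketch of a per-degree linear layer that satisfies this by sharing one channel-mixing matrix across all components of each degree. That SpectralAtomwise is built from such layers (for example expanding to hidden_channels and back around a nonlinearity) is an assumption, as are the names below; biases, which would only be allowed on l = 0, are omitted.

```python
from typing import Sequence


def so3_linear(
    x: Sequence[Sequence[float]],
    degrees: Sequence[int],
    weights: Sequence[Sequence[Sequence[float]]],
) -> list[list[float]]:
    """Per-degree channel mixing: weights[l][c_out][c_in] is shared by all m of degree l.

    Assumption: hypothetical stand-in for the linear layers inside SpectralAtomwise.
    """
    out = []
    for row, l in zip(x, degrees):
        w = weights[l]
        out.append([sum(w_o[c] * row[c] for c in range(len(row))) for w_o in w])
    return out


# l_max = 1: one l = 0 row and three l = 1 rows, two input channels each.
x = [[1.0, 2.0], [1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
weights = [
    [[1.0, 1.0]],  # l = 0: 2 -> 1 channel
    [[2.0, 0.0]],  # l = 1: 2 -> 1 channel
]
print(so3_linear(x, [0, 1, 1, 1], weights))  # [[3.0], [2.0], [0.0], [4.0]]
```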