ESEN¶
- class mlip.models.esen.network.Esen(config: EsenConfig, dataset_info: DatasetInfo, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)¶
The Esen model flax module. It is derived from the MLIPNetwork class.
References
Saro Passaro, Lawrence Zitnick. Reducing SO(3) Convolutions to SO(2) for Efficient Equivariant GNNs. URL: https://arxiv.org/pdf/2302.03655
Xiang Fu, Brandon M. Wood, Luis Barroso-Luque, Daniel S. Levine, Meng Gao, Misko Dzamba, C. Lawrence Zitnick. Learning Smooth and Expressive Interatomic Potentials for Physical Property Prediction. URL: https://arxiv.org/abs/2502.12147
Brandon M. Wood et al. UMA: A Family of Universal Models for Atoms. URL: https://arxiv.org/pdf/2506.23971
- config¶
Hyperparameters / configuration for the Esen model, see
EsenConfig.
- dataset_info¶
Hyperparameters dictated by the dataset (e.g., cutoff radius or average number of neighbors).
- available_properties¶
Model available properties, see
Properties.
- setup() None¶
Initializes the model layers.
- __call__(graph: Graph) Graph¶
Runs the Esen model forward pass on an input Graph and returns an updated Graph object with node-wise contributions.
Features in output graph:
- node-wise energy: graph.nodes.features["energy"]
- Parameters:
graph – Input Graph containing atomic positions and topology.
- Returns:
Updated Graph with node-wise features.
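Node-wise energy contributions are commonly reduced to one energy per structure by summing over the atoms of each structure. The snippet below is a minimal sketch of that reduction in plain Python. The helper `total_energies` and the `node_graph_index` argument are hypothetical names used for illustration and are not part of the mlip API; how the library itself performs this reduction is not described here.

```python
def total_energies(node_energies, node_graph_index, num_graphs):
    """Sum node-wise energy contributions into one energy per graph.

    Hypothetical helper for illustration only (not an mlip function).
    """
    totals = [0.0] * num_graphs
    for energy, graph_id in zip(node_energies, node_graph_index):
        totals[graph_id] += energy
    return totals


# Two structures batched together:
# atoms 0-2 belong to graph 0, atoms 3-4 belong to graph 1.
node_energies = [-1.0, -2.0, -0.5, -3.0, -1.5]
node_graph_index = [0, 0, 0, 1, 1]
energies = total_energies(node_energies, node_graph_index, num_graphs=2)
```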
- class mlip.models.esen.config.EsenConfig(*, add_atomic_energies: bool = True, num_species: int | None = None, num_layers: int = 2, sphere_channels: int = 16, hidden_channels: int = 16, edge_channels: int = 16, l_max: int = 2, m_max: int = 2, radial_envelope: RadialEnvelope = RadialEnvelope.POLYNOMIAL, radial_basis: str | RadialBasis = RadialBasis.GAUSS, num_rbf: int = 16, basis_width_scalar: float = 2.0, cosine_cutoff: bool = False, norm_type: str = 'rms_norm_sh', act_type: str = 'gate', trainable_rbf: bool = False, num_readout_heads: Annotated[int, Gt(gt=0)] = 1, moe: EsenMoEConfig | None = None, predict_partial_charges: bool = False, use_coulomb_term: bool = False, use_total_charge_embedding: bool = False, embed_activation: Activation = Activation.SILU, deterministic_scatter_ops: bool = False)¶
The configuration / hyperparameters of the eSEN model.
- num_species¶
The number of elements (atomic species descriptors) allowed. If None (default), the value is inferred from the atomic energies map in the dataset info.
- Type:
int | None
- num_layers¶
Number of eSEN layers. Default is 2.
- Type:
int
- sphere_channels¶
The number of channels for the node embedding. Default is 16.
- Type:
int
- hidden_channels¶
The number of output channels for convolution layers and MLPs in the network. Default is 16.
- Type:
int
- edge_channels¶
The number of channels for the edge embedding. Default is 16.
- Type:
int
- l_max¶
Highest degree of spherical harmonics used for the directional encoding of edge vectors, and during the convolution block. Default is 2.
- Type:
int
- m_max¶
Cap on the order m in the convolution layer; features with an order above this cap are removed. Default is 2.
- Type:
int
- add_atomic_energies¶
Whether to add atomic energies to the final energies. Default is True.
- Type:
bool
- radial_envelope¶
The radial envelope function. Default is "polynomial_envelope"; the only other option is "soft_envelope".
- radial_basis¶
Type of radial basis function used. Three options are available: “bessel”, “gauss”, and “expnorm”. Default is “gauss”.
- Type:
str | RadialBasis
- num_rbf¶
Number of radial basis functions used in the edge embedding. Default is 16 (512 is used in UMA small).
- Type:
int
- cosine_cutoff¶
Whether to use the cosine cutoff envelope function in the radial embedding block. Defaults to False.
- Type:
bool
- norm_type¶
Specifies the type of normalisation used. Three options are available: “layer_norm”, “layer_norm_sh”, and “rms_norm_sh”. Default is “rms_norm_sh”.
- Type:
str
- act_type¶
Activation type for Edgewise (convolution). Only one option is available, “gate”, which is used as the default.
- Type:
str
- num_readout_heads¶
Number of readout heads. The default is 1.
- Type:
int
- moe¶
Optional MoE configuration. When None (default), eSEN behaves as a standard model with no expert routing.
- Type:
mlip.models.esen.config.EsenMoEConfig | None
- predict_partial_charges¶
Whether the model will be trained to predict partial charges. Default is False.
- Type:
bool
- use_coulomb_term¶
Whether to use the Coulomb term in the model for long range interactions. Default is False.
- Type:
bool
- use_total_charge_embedding¶
Whether to use the total charge embedding. Default is False.
- Type:
bool
- embed_activation¶
Activation function for the embedding block. Default is “silu”.
- deterministic_scatter_ops¶
Whether to use deterministic scatter operations in the forward pass, ensuring deterministic energy outputs. Setting to True makes prediction slower. Default is False.
- Type:
bool
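The effect of l_max and m_max on feature size can be illustrated by counting spherical-harmonic coefficients per channel. The sketch below uses the standard counting (each degree l has 2l + 1 orders, and an order cap keeps only |m| <= m_max). The function `num_coefficients` is a hypothetical helper, and it is an assumption that the model's internal coefficient layout follows exactly this count.

```python
def num_coefficients(l_max, m_max=None):
    """Count spherical-harmonic coefficients per channel.

    Without a cap, degree l contributes 2l + 1 orders (m = -l, ..., l).
    With a cap m_max, only orders with |m| <= m_max are kept.
    Hypothetical helper for illustration only (not an mlip function).
    """
    if m_max is None:
        m_max = l_max
    return sum(2 * min(l, m_max) + 1 for l in range(l_max + 1))


default_size = num_coefficients(l_max=2, m_max=2)  # 1 + 3 + 5
capped_size = num_coefficients(l_max=4, m_max=2)   # 1 + 3 + 5 + 5 + 5
uncapped_size = num_coefficients(l_max=4)          # (4 + 1) ** 2
```

With the default l_max = m_max = 2 the cap removes nothing; it only reduces the feature size once l_max exceeds m_max.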
- class mlip.models.esen.blocks.EsenEmbeddingBlock(graph_cutoff_angstrom: float, l_max: int, num_species: int, num_charges: int | None, sphere_channels: int, radial_envelope: ~mlip.models.options.RadialEnvelope | str, radial_basis: ~mlip.models.options.RadialBasis | str, num_rbf: int, basis_width_scalar: float, cosine_cutoff: bool, trainable_rbf: bool, edge_channels: int, m_max: int, mapping_reduced: ~mlip.models.esen.coefficient_mapping.CoefficientMapping, edge_channels_list: list[int], activation_fn: ~typing.Callable = <PjitFunction of <function silu>>, deterministic_scatter_ops: bool = False, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)¶
Embeds input node and edge features for the Esen model.
Initializes and applies the embedding layers for node species, radial functions, and spherical harmonics. Updates the input Graph with the embedded features and pre-processes edges and neighbors for subsequent network layers.
- graph_cutoff_angstrom¶
The cutoff radius for the graph.
- Type:
float
- l_max¶
Highest harmonic order included in the Spherical Harmonics series.
- Type:
int
- num_species¶
The number of elements (atomic species descriptors) allowed.
- Type:
int
- sphere_channels¶
The number of channels for the node embedding.
- Type:
int
- radial_envelope¶
The radial envelope function.
- Type:
mlip.models.options.RadialEnvelope | str
- radial_basis¶
The type of radial basis function used.
- Type:
mlip.models.options.RadialBasis | str
- num_rbf¶
Number of radial basis functions used in the embedding block.
- Type:
int
- trainable_rbf¶
Whether to add learnable weights to each of the radial embedding basis functions.
- Type:
bool
- edge_channels¶
The number of channels for the edge embedding.
- Type:
int
- m_max¶
The maximum order of the spherical harmonics to include in the embedding.
- Type:
int
- mapping_reduced¶
The mapping of the spherical harmonics to the reduced set of coefficients.
- Type:
mlip.models.esen.coefficient_mapping.CoefficientMapping
- edge_channels_list¶
The list of channels for the edge embedding.
- Type:
list[int]
- setup() None¶
Initializes the embedding layers for node species, radial functions, and spherical harmonics.
- __call__(graph: Graph) Graph¶
Applies the Esen embedding block to compute initial node and edge embeddings.
Embedded features in the final graph can be accessed as:
- node features: graph.nodes.features["embedding"]
- edge features: graph.edges.features["embedding"]
- edge envelope: graph.edges.features["envelope"]
- wigner and m mapping: graph.nodes.features["wigner_and_m_mapping"]
- wigner and m mapping inverse: graph.nodes.features["wigner_and_m_mapping_inv"]
- Parameters:
graph – Graph containing node features with “species” and edge vectors.
- Returns:
Updated Graph with embedded node and edge features ready for Esen processing.
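To make the radial part of the embedding more concrete, the following sketch expands an edge length on a Gaussian radial basis and evaluates a cosine cutoff envelope separately. It is only an illustration under stated assumptions: the functions `gaussian_rbf` and `cosine_cutoff` are hypothetical, the centers are assumed to be evenly spaced on [0, cutoff], and the width is assumed to be `basis_width_scalar` times the center spacing. The exact functional forms used by the library may differ.

```python
import math


def cosine_cutoff(r, cutoff):
    """Decay smoothly from 1 at r = 0 to 0 at r = cutoff; zero beyond."""
    if r >= cutoff:
        return 0.0
    return 0.5 * (math.cos(math.pi * r / cutoff) + 1.0)


def gaussian_rbf(r, cutoff, num_rbf, basis_width_scalar=2.0):
    """Expand a distance on num_rbf Gaussians.

    Assumptions (not taken from the library): centers are evenly spaced on
    [0, cutoff] and the width is basis_width_scalar times the spacing.
    """
    spacing = cutoff / (num_rbf - 1)
    sigma = basis_width_scalar * spacing
    centers = [i * spacing for i in range(num_rbf)]
    return [math.exp(-0.5 * ((r - c) / sigma) ** 2) for c in centers]


cutoff = 5.0
edge_length = 2.5
radial_features = gaussian_rbf(edge_length, cutoff, num_rbf=5)
envelope = cosine_cutoff(edge_length, cutoff)
```

Here the radial features and the envelope are kept as two separate quantities, mirroring the separate "embedding" and "envelope" edge features listed for this block.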
- class mlip.models.esen.layer.ESENLayer(sphere_channels: int, hidden_channels: int, l_max: int, m_max: int, mapping_reduced: mlip.models.esen.coefficient_mapping.CoefficientMapping, edge_channels_list: Sequence[int], graph_cutoff_angstrom: float, norm_type: Literal['layer_norm', 'layer_norm_sh', 'rms_norm_sh'], act_type: Literal['gate'], num_experts: int | None = None, deterministic_scatter_ops: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object at 0x7f135c5bb050>, name: Optional[str] = None)¶
- setup() None¶
Initializes the Esen layer.
- __call__(graph: Graph) Graph¶
Applies the Esen layer to update node features using neighboring nodes’ species and edge features.
Updated features in this function:
- node-wise features: graph.nodes.features["latent"]
- edge-wise features: graph.edges.features["latent"]
- Parameters:
graph – Graph containing node features with “latent”, or “embedding” if this is the first layer. The same applies to edge features.
- Returns:
Updated Graph with updated node features.
- class mlip.models.esen.layer.Edgewise(sphere_channels: int, hidden_channels: int, l_max: int, m_max: int, edge_channels_list: ~typing.Sequence[int], mapping_reduced: object, graph_cutoff_angstrom: float, act_type: ~typing.Literal['gate'] = 'gate', num_experts: int | None = None, deterministic_scatter_ops: bool = False, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)¶
Applies the edge-wise convolution to update node features using neighboring nodes’ species and edge features.
- sphere_channels¶
The number of channels for the node embedding.
- Type:
int
- hidden_channels¶
The number of channels for the hidden layer.
- Type:
int
- l_max¶
Highest harmonic order included in the Spherical Harmonics series.
- Type:
int
- m_max¶
Maximum order of the spherical harmonics to include in the embedding.
- Type:
int
- edge_channels_list¶
The list of channels for the edge embedding.
- Type:
Sequence[int]
- mapping_reduced¶
The mapping of the spherical harmonics to the reduced set of coefficients.
- Type:
object
- graph_cutoff_angstrom¶
The cutoff radius for the graph.
- Type:
float
- act_type¶
The type of activation function to apply.
- Type:
Literal['gate']
- setup() None¶
Initializes the edge-wise convolution layers.
- __call__(graph: Graph) Graph¶
Applies the edge-wise convolution to update node features using neighboring nodes’ species and edge features.
- Steps:
gather source/target node features -> concat along channels
rotate (align with edge)
SO2 conv 1 -> activation (gate / s2)
SO2 conv 2
apply envelope
rotate back
scatter-add to destination nodes
Updated features in this function:
- node-wise features: graph.nodes.features["latent"]
- Parameters:
graph – Graph containing node features with “latent” and edge features with “latent”.
- Returns:
Updated Graph with updated node features.
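The steps listed for Edgewise.__call__ can be followed on a toy example. The sketch below is not the library's implementation: features are flat lists of floats, the rotations into and out of the edge-aligned frame are omitted (treated as identity), and `conv1`, `act`, and `conv2` are hypothetical stand-in callables rather than real SO2 convolutions or a gate activation. It only shows the order of the gather, transform, envelope, and scatter-add operations.

```python
def edgewise_sketch(node_feats, senders, receivers, envelope,
                    conv1, act, conv2, out_channels):
    """Toy message passing that mirrors the listed steps on plain lists.

    Hypothetical helper for illustration only (not an mlip function).
    """
    out = [[0.0] * out_channels for _ in node_feats]
    for edge, (src, dst) in enumerate(zip(senders, receivers)):
        # Gather source/target node features, concatenate along channels.
        message = node_feats[src] + node_feats[dst]
        # Rotate to align with the edge: omitted in this sketch.
        # First conv -> activation -> second conv (stand-in callables).
        message = conv2(act(conv1(message)))
        # Apply the envelope value of this edge.
        message = [envelope[edge] * value for value in message]
        # Rotate back: omitted in this sketch.
        # Scatter-add to the destination node.
        for channel, value in enumerate(message):
            out[dst][channel] += value
    return out


node_feats = [[1.0], [2.0], [3.0]]
senders = [0, 1, 2]
receivers = [1, 2, 1]
envelope = [1.0, 0.5, 0.0]  # the last edge sits at the cutoff

updated = edgewise_sketch(
    node_feats, senders, receivers, envelope,
    conv1=lambda x: [sum(x)],                   # stand-in for SO2 conv 1
    act=lambda x: [max(v, 0.0) for v in x],     # stand-in for the gate
    conv2=lambda x: [2.0 * v for v in x],       # stand-in for SO2 conv 2
    out_channels=1,
)
```

Node 0 receives no messages and stays at zero, and the edge with envelope 0.0 contributes nothing to node 1, which is the purpose of applying the envelope before the scatter-add.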
- class mlip.models.esen.layer.SpectralAtomwise(sphere_channels: int, hidden_channels: int, l_max: int, m_max: int, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)¶
Applies the spectral atom-wise convolution to update node features.
- sphere_channels¶
The number of channels for the node embedding.
- Type:
int
- hidden_channels¶
The number of channels for the hidden layer.
- Type:
int
- l_max¶
Highest harmonic order included in the Spherical Harmonics series.
- Type:
int
- m_max¶
Maximum order of the spherical harmonics to include in the embedding.
- Type:
int
- setup() None¶
Initializes the spectral atom-wise layers.
- __call__(graph: Graph) Graph¶
Applies the spectral atom-wise convolution to update node features.
Updated features in this function:
- node-wise features: graph.nodes.features["latent"]
- Parameters:
graph – Graph containing node features with “latent”.
- Returns:
Updated Graph with updated node features.