Blocks
- class mlip.models.blocks.SpeciesAssignmentBlock(dataset_info: DatasetInfo)
Map atomic numbers to contiguous species indices.
Builds a compile-time lookup table from the allowed atomic numbers in dataset_info and uses it to populate graph.nodes.features["species"] with zero-based species indices. Atomic numbers not present in the dataset are mapped to SPECIES_PLACEHOLDER.
- dataset_info
The model’s dataset_info, whose allowed_atomic_numbers attribute defines the supported species.
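As a rough illustration of the lookup-table idea, the NumPy sketch below maps atomic numbers to zero-based species indices with a single gather. It is not mlip's implementation: the placeholder value -1, the ordering of species by sorted atomic number, and the helper names build_species_lookup and assign_species are all hypothetical.

```python
import numpy as np

# Hypothetical stand-in for mlip's SPECIES_PLACEHOLDER; the real value may differ.
PLACEHOLDER = -1

def build_species_lookup(allowed_atomic_numbers, max_z=118):
    """Return a table where table[Z] is the zero-based species index of element Z."""
    table = np.full(max_z + 1, PLACEHOLDER, dtype=np.int32)
    # Assumption: species indices follow ascending atomic number.
    for idx, z in enumerate(sorted(allowed_atomic_numbers)):
        table[z] = idx
    return table

def assign_species(atomic_numbers, table):
    # A single gather; elements absent from the table resolve to PLACEHOLDER.
    return table[np.asarray(atomic_numbers)]

table = build_species_lookup({1, 6, 8})
print(assign_species([8, 1, 1, 26], table).tolist())  # [2, 0, 0, -1]
```

Because the table is built once from dataset_info, the per-graph work reduces to one indexing operation, which is what makes a compile-time table attractive under JIT.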
- class mlip.models.blocks.AtomicEnergiesBlock(dataset_info: DatasetInfo, learnable: bool = False, skip_atomic_energies_addition: bool = False, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Add atomic energies to latent node energy summands.
Typical atomic contributions (usually, the energy of core electrons) are initialized from dataset_info and are not learnable by default. If the scaling_mean and scaling_stdev attributes are set in dataset_info, latent node features are also shifted and rescaled before the atomic energies are added.
- dataset_info
The model’s dataset_info containing the dictionary of atomic energies.
- learnable
Whether the atomic energies are learnable. Default is False.
- Type:
bool
- skip_atomic_energies_addition
Whether to skip adding the atomic energies, so that only the shifting/scaling is applied. Default is False.
- Type:
bool
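The order of operations described above (rescale and shift, then add per-species energies) can be sketched in NumPy as follows. This is an illustration under stated assumptions, not mlip's code: the affine form latent * scaling_stdev + scaling_mean, the representation of atomic energies as an array indexed by zero-based species index (dataset_info stores a dictionary), and the helper name add_atomic_energies are all assumptions.

```python
import numpy as np

def add_atomic_energies(node_energies, species, atomic_energies,
                        scaling_mean=None, scaling_stdev=None,
                        skip_atomic_energies_addition=False):
    """Sketch: rescale/shift latent node energies, then add per-species energies."""
    out = np.asarray(node_energies, dtype=float)
    # Assumed affine form, applied only when both statistics are available.
    if scaling_mean is not None and scaling_stdev is not None:
        out = out * scaling_stdev + scaling_mean
    if skip_atomic_energies_addition:
        return out
    # Gather each node's reference energy by its zero-based species index.
    return out + np.asarray(atomic_energies, dtype=float)[np.asarray(species)]

e = add_atomic_energies([0.5, -0.5], species=[0, 1], atomic_energies=[-10.0, -20.0],
                        scaling_mean=1.0, scaling_stdev=2.0)
print(e.tolist())  # [-8.0, -20.0]
```

Setting skip_atomic_energies_addition=True in this sketch returns only the rescaled latent energies, mirroring the attribute described above.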
- class mlip.models.blocks.RadialEmbeddingBlock(radial_basis: RadialBasis, num_rbf: int, graph_cutoff_angstrom: float, learnable: bool, radial_envelope: RadialEnvelope | None = None, avg_r_min: float | None = None, return_as_irreps: bool = False, basis_width_scalar: float = 1.0, cosine_cutoff: bool = True, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Transforms distances into feature vectors using radial basis functions (RBFs).
- radial_basis
Type of radial basis functions to use (e.g., Gaussian smearing, ExpNormal smearing, or Bessel).
- num_rbf
Number of radial basis functions.
- Type:
int
- graph_cutoff_angstrom
Cutoff distance, in Å, beyond which interactions are ignored or smoothly suppressed.
- Type:
float
- learnable
If True, the parameters of the radial basis functions are learnable. Note that the Bessel RBF type is not learnable.
- Type:
bool
- radial_envelope
Optional envelope function applied to the radial embeddings to enforce smooth cutoff behavior. Default is None, in which case no additional envelope is applied.
- Type:
RadialEnvelope | None
- avg_r_min
Optional minimum average distance used for normalization or scaling of the radial features.
- Type:
float | None
- return_as_irreps
If True, returns the embedding as an e3nn_jax.IrrepsArray. Default is False.
- Type:
bool
- basis_width_scalar
Scaling factor applied to the width of the radial basis functions. Only used with Gaussian smearing.
- Type:
float
- __call__(distances: Array) → Array
Compute the radial embeddings for the given edge distances.
- Parameters:
distances – The distances for all edges, i.e., the lengths of the edge vectors.
- Returns:
The radial embeddings for each edge.
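To make the roles of num_rbf, basis_width_scalar and the cosine cutoff concrete, here is a NumPy sketch of one common Gaussian-smearing convention: centers evenly spaced on [0, cutoff], a width equal to basis_width_scalar times the center spacing, and a cosine cutoff 0.5 * (cos(pi * d / cutoff) + 1) that is zero beyond the cutoff. Whether mlip uses exactly these formulas is an assumption, and gaussian_rbf is a hypothetical helper.

```python
import numpy as np

def gaussian_rbf(distances, num_rbf, cutoff, basis_width_scalar=1.0, cosine_cutoff=True):
    """Sketch of Gaussian smearing: one row of num_rbf features per edge."""
    d = np.asarray(distances, dtype=float)[:, None]          # (num_edges, 1)
    centers = np.linspace(0.0, cutoff, num_rbf)              # (num_rbf,)
    width = basis_width_scalar * (centers[1] - centers[0])   # assumed width rule
    emb = np.exp(-0.5 * ((d - centers) / width) ** 2)        # (num_edges, num_rbf)
    if cosine_cutoff:
        # Smoothly decays to zero at the cutoff and is exactly zero beyond it.
        c = 0.5 * (np.cos(np.pi * d / cutoff) + 1.0)
        emb = emb * np.where(d < cutoff, c, 0.0)
    return emb

emb = gaussian_rbf([0.0, 2.5, 6.0], num_rbf=8, cutoff=5.0)
print(emb.shape)  # (3, 8)
```

Increasing basis_width_scalar in this sketch makes neighboring Gaussians overlap more, trading resolution for smoother features.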
- class mlip.models.blocks.SphericalHarmonicsBlock(l_max: int, normalize: bool = True, normalization: str | None = None)
Spherical harmonics encoding of edge vectors.
- __call__(edge_vectors: Array) → IrrepsArray
Return the spherical harmonics encoding of edge_vectors, up to degree l_max.
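The size of the output follows from standard facts about spherical harmonics: degree l has 2l + 1 components and parity (-1)^l, so degrees 0 through l_max give (l_max + 1)^2 channels per edge. The pure-Python helpers below (hypothetical names) compute the corresponding e3nn-style irreps string and dimension, assuming the block emits every degree from 0 to l_max.

```python
def sh_irreps(l_max: int) -> str:
    # Degree l has parity (-1)**l: even ("e") for even l, odd ("o") for odd l.
    return " + ".join(f"1x{l}{'e' if l % 2 == 0 else 'o'}" for l in range(l_max + 1))

def sh_dim(l_max: int) -> int:
    # Degree l contributes 2l + 1 components, so the total is (l_max + 1)**2.
    return sum(2 * l + 1 for l in range(l_max + 1))

print(sh_irreps(2))  # 1x0e + 1x1o + 1x2e
print(sh_dim(2))     # 9
```

For edge vectors of shape (num_edges, 3), this predicts an output array of shape (num_edges, (l_max + 1)**2).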
- class mlip.models.blocks.NodeEmbeddingBlock(num_species: int, num_channels: int, kernel_init: Initializer | GradientScaledKernelInit = <function variance_scaling.<locals>.init>, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Species embedding using a learned lookup table.
Maps integer species indices to dense embedding vectors via nn.Embed and returns a plain JAX array. Models that require an e3nn.IrrepsArray should wrap the output themselves.
- num_species
Number of distinct atomic species.
- Type:
int
- num_channels
Embedding dimension.
- Type:
int
- kernel_init
Initializer for the embedding table. Can be a GradientScaledKernelInit or a callable initializer. Defaults to the initializer used by nn.Embed.
- Type:
jax.nn.initializers.Initializer | mlip.models.options.GradientScaledKernelInit
- __call__(node_species: Array) → Array
Look up the embedding vector for each node’s species index.
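Functionally, the block is a row lookup into a (num_species, num_channels) table. The NumPy sketch below shows that gather with a randomly filled stand-in table; it does not reproduce nn.Embed's initializer or parameter handling.

```python
import numpy as np

num_species, num_channels = 3, 4
rng = np.random.default_rng(0)
# Stand-in for the learned table; nn.Embed's actual initializer is not reproduced.
table = rng.normal(size=(num_species, num_channels))

# Zero-based species indices, e.g. as produced by SpeciesAssignmentBlock.
node_species = np.array([2, 0, 0, 1])
embeddings = table[node_species]  # one row of num_channels features per node
print(embeddings.shape)  # (4, 4)
```

Nodes of the same species receive identical rows. If a downstream layer needs an IrrepsArray, one option is to wrap the result as scalars, for example e3nn_jax.IrrepsArray(f"{num_channels}x0e", embeddings); treating the channels as even scalars is an assumption about how a given model interprets them.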
- class mlip.models.blocks.MaskPaddedNodeOutputsBlock(feature_names: tuple[str, ...], parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Zero out the listed node features on padded (dummy) nodes.
Padded graphs carry dummy nodes that must not contribute to downstream quantities (energy, partial charges, …). For each name in feature_names, this block applies jnp.where(graph.node_mask(), feature, 0.0), centralizing a pattern that was previously inlined at the end of each model’s __call__.
- feature_names
Tuple of node feature keys to mask.
- Type:
tuple[str, …]
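The masking pattern can be sketched in NumPy on a plain dictionary of node features. The helper name mask_padded_nodes and the dictionary layout are hypothetical; the reshape of the (num_nodes,) mask is an addition so that the same call also handles multi-channel features, since a 1-D mask would not broadcast correctly against a (num_nodes, channels) array.

```python
import numpy as np

def mask_padded_nodes(node_features, node_mask, feature_names):
    """Sketch: zero the listed features wherever node_mask is False."""
    node_mask = np.asarray(node_mask, dtype=bool)
    out = dict(node_features)  # leave the input dictionary untouched
    for name in feature_names:
        feature = np.asarray(out[name])
        # Reshape the (num_nodes,) mask so it broadcasts over trailing feature axes.
        mask = node_mask.reshape(node_mask.shape + (1,) * (feature.ndim - 1))
        out[name] = np.where(mask, feature, 0.0)
    return out

features = {"energy": np.array([1.0, 2.0, 3.0]), "vec": np.ones((3, 2))}
masked = mask_padded_nodes(features, [True, True, False], ("energy", "vec"))
print(masked["energy"].tolist())  # [1.0, 2.0, 0.0]
```

Features not listed in feature_names pass through unchanged, so only the quantities that feed into sums such as the total energy need to be named.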
- class mlip.models.blocks.MLP(layer_sizes: list[int], activation: Activation | Callable[[Array], Array] | str, kernel_init: Initializer | GradientScaledKernelInit, use_bias: bool = False, use_layer_norm: bool = False, normalize_activation: bool = False, output_kernel_init: Initializer | GradientScaledKernelInit | None = None, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Multi-layer perceptron for scalar features, e.g., radial embeddings.
Note that the activation function is applied only between layers, not after the output layer.
- layer_sizes
List of layer dimensions, including the input and output dimensions. For example, [4, 16, 32] describes 2 layers (4 → 16 and 16 → 32) with input dimension 4.
- Type:
list[int]
- activation
Activation function to apply between layers.
- Type:
mlip.models.options.Activation | Callable[[jax.Array], jax.Array] | str
- kernel_init
Initializer for the hidden layers. Can be a GradientScaledKernelInit or a callable initializer. Default is lecun_normal, to match the default nn.Dense initializer.
- Type:
jax.nn.initializers.Initializer | mlip.models.options.GradientScaledKernelInit
- use_bias
Whether to include bias parameters in all layers.
- Type:
bool
- use_layer_norm
If True, use LayerNorm in each hidden layer.
- Type:
bool
- normalize_activation
If True, wrap the activation with e3nn.normalize_function.
- Type:
bool
- output_kernel_init
Initializer for the output layer. Defaults to kernel_init.
- Type:
jax.nn.initializers.Initializer | mlip.models.options.GradientScaledKernelInit | None
- __call__(x: Array) → Array
Apply the MLP to x.
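The layer_sizes convention and the "no activation after the output layer" rule can be illustrated with a small NumPy forward pass. This is a sketch, not the Flax module: init_mlp and mlp_apply are hypothetical helpers, weights are drawn from a plain normal with standard deviation 1/sqrt(fan_in) (JAX's lecun_normal uses a truncated normal), and use_layer_norm and normalize_activation are not modeled.

```python
import numpy as np

def init_mlp(layer_sizes, rng, use_bias=False):
    """Create one (weights, bias) pair per consecutive pair of layer sizes."""
    params = []
    for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        # LeCun-style scaling (variance 1 / fan_in), plain rather than truncated normal.
        w = rng.normal(0.0, 1.0 / np.sqrt(fan_in), size=(fan_in, fan_out))
        b = np.zeros(fan_out) if use_bias else None
        params.append((w, b))
    return params

def mlp_apply(params, x, activation=np.tanh):
    for i, (w, b) in enumerate(params):
        x = x @ w
        if b is not None:
            x = x + b
        # Activation only between layers: skipped after the final (output) layer.
        if i < len(params) - 1:
            x = activation(x)
    return x

rng = np.random.default_rng(0)
params = init_mlp([4, 16, 32], rng)
out = mlp_apply(params, np.ones((5, 4)))
print(len(params), out.shape)  # 2 (5, 32)
```

Leaving the output layer linear lets the MLP produce unbounded values, which matters when its outputs are used directly as features or weights downstream.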