Readout blocks

class mlip.models.readout.MultiHeadReadoutBlock(num_heads: int, features: typing.Sequence[int | e3nn_jax._src.irreps.Irreps], activation: typing.Callable | None = None, mlp_kernel_init: jax.nn.initializers.Initializer | mlip.models.options.GradientScaledKernelInit | None = None, use_equiv: bool = False, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Multi-head readout block that can be used in multiple models.

This block shares most of its attributes with the single-head ReadoutBlock; in addition, it has the num_heads attribute.

Applies the readout block independently across multiple heads.
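To make "independently across multiple heads" concrete, here is a minimal pure-JAX sketch of the idea for the non-equivariant case. It is not the library's implementation: the helper names (init_multi_head_mlp, mlp, multi_head_readout) are hypothetical, and the sketch works on a plain [num_nodes, latent_dim] array rather than a Graph. Each head gets its own stack of dense layers, the weights are stacked along a leading head axis, and jax.vmap evaluates every head on the same latent node features.

```python
import jax
import jax.numpy as jnp


def init_multi_head_mlp(key, num_heads, in_dim, features):
    """Hypothetical initializer: one stack of dense layers per head.

    Every weight and bias carries a leading num_heads axis so that all
    heads can be evaluated at once with jax.vmap.
    """
    params = []
    dims = [in_dim, *features]
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (num_heads, d_in, d_out)) / jnp.sqrt(d_in)
        b = jnp.zeros((num_heads, d_out))
        params.append((w, b))
    return params


def mlp(layer_params, x, activation=None):
    """Single-head MLP; the activation is applied only between layers."""
    for i, (w, b) in enumerate(layer_params):
        x = x @ w + b
        if activation is not None and i < len(layer_params) - 1:
            x = activation(x)
    return x


def multi_head_readout(params, latent, activation=None):
    """latent: [num_nodes, latent_dim] -> [num_nodes, num_heads, features[-1]]."""
    # Map over the head axis of the parameters only; every head reads the
    # same latent node features. out_axes=1 places the head axis second.
    return jax.vmap(lambda p: mlp(p, latent, activation), out_axes=1)(params)


params = init_multi_head_mlp(jax.random.PRNGKey(0), num_heads=3, in_dim=8, features=[16, 1])
latent = jnp.ones((5, 8))  # [num_nodes, latent_dim]
outputs = multi_head_readout(params, latent, activation=jax.nn.silu)
print(outputs.shape)  # (5, 3, 1)
```

Because the heads only share their input, slicing one head's parameters and running the single-head MLP reproduces that head's column of the output.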

num_heads

The number of readout heads.

Type:

int

features

The output feature dimension of each linear layer, given as a sequence of integers or, in the equivariant case, of irreps.

Type:

Sequence[int | e3nn_jax._src.irreps.Irreps]

activation

The activation function to use. Default is None, which means no activation is applied.

Type:

Callable | None

mlp_kernel_init

The kernel initializer for the linear layers of the non-equivariant MLP, i.e., it is only relevant when use_equiv is False. Default is None.

Type:

jax.nn.initializers.Initializer | mlip.models.options.GradientScaledKernelInit | None

use_equiv

Whether to use equivariant linear layers, in which case the entries of features are interpreted as irreps. Default is False.

Type:

bool

__call__(graph: Graph) → Graph

Call function for this block.

Note that the outputs of this block have shape [num_nodes, num_heads, num_final_readout_layer_output_features].

Parameters:

graph – The input graph. Should have its features in graph.nodes.features["latent"].

Returns:

The output graph with the resulting node readout outputs in graph.nodes.features["outputs"].

Return type:

Graph

class mlip.models.readout.ReadoutBlock(features: typing.Sequence[int | e3nn_jax._src.irreps.Irreps], activation: typing.Callable | None = None, mlp_kernel_init: jax.nn.initializers.Initializer | mlip.models.options.GradientScaledKernelInit | None = None, use_equiv: bool = False, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Readout block that can be used in multiple models.

This block has a single head; see MultiHeadReadoutBlock for the multi-head version.

The readout consists of one linear layer (equivariant or not) per entry in features, with the activation applied between consecutive layers if one is given.

features

The output feature dimension of each linear layer, given as a sequence of integers or, in the equivariant case, of irreps.

Type:

Sequence[int | e3nn_jax._src.irreps.Irreps]

activation

The activation function to use. Default is None, which means no activation is applied.

Type:

Callable | None

mlp_kernel_init

The kernel initializer for the linear layers of the non-equivariant MLP, i.e., it is only relevant when use_equiv is False. Default is None.

Type:

jax.nn.initializers.Initializer | mlip.models.options.GradientScaledKernelInit | None

use_equiv

Whether to use equivariant linear layers, in which case the entries of features are interpreted as irreps. Default is False.

Type:

bool

__call__(graph: Graph) → Graph

Call function for this block.

Parameters:

graph – The input graph. Should have its features in graph.nodes.features["latent"].

Returns:

The output graph with the resulting node readout outputs in graph.nodes.features["outputs"].

Return type:

Graph

mlip.models.readout.select_head(graph: Graph) → Graph

Select a readout head per node based on graph.globals.dataset_idx.

Expects the graph to have “outputs” in graph.nodes.features with shape [num_nodes, num_heads, num_predictions]. Selects one head per node and updates “outputs” to shape [num_nodes, num_predictions].

Within a batch, different graphs may target different heads: each graph's dataset_idx is broadcast to all of its nodes using graph.n_node.

When dataset_idx is None, head 0 is used.
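The selection can be sketched at the array level as follows. This is a minimal illustration under assumptions, not the library's code: select_head_arrays is a hypothetical name, and it takes the outputs, n_node and dataset_idx arrays directly instead of a Graph. jnp.repeat with total_repeat_length broadcasts each graph's head index to its nodes while keeping the output shape static under jax.jit.

```python
from typing import Optional

import jax.numpy as jnp


def select_head_arrays(
    outputs: jnp.ndarray,
    n_node: jnp.ndarray,
    dataset_idx: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
    """Hypothetical array-level sketch of select_head.

    outputs:     [num_nodes, num_heads, num_predictions]
    n_node:      [num_graphs], number of nodes in each graph of the batch
    dataset_idx: [num_graphs], head index per graph, or None
    returns:     [num_nodes, num_predictions]
    """
    num_nodes = outputs.shape[0]
    if dataset_idx is None:
        # No dataset index given: every node uses head 0.
        return outputs[:, 0, :]
    # Repeat each graph's head index once per node of that graph.
    node_head = jnp.repeat(dataset_idx, n_node, total_repeat_length=num_nodes)
    # Pick outputs[i, node_head[i], :] for every node i.
    return outputs[jnp.arange(num_nodes), node_head]


# Two graphs with 2 and 3 nodes; graph 0 uses head 1, graph 1 uses head 0.
outputs = jnp.arange(10.0).reshape(5, 2, 1)  # outputs[i, h, 0] == 2 * i + h
n_node = jnp.array([2, 3])
dataset_idx = jnp.array([1, 0])
print(select_head_arrays(outputs, n_node, dataset_idx)[:, 0])  # [1. 3. 4. 6. 8.]
```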

Parameters:

graph – The input graph, with an “outputs” node feature and globals.dataset_idx.

Returns:

The graph with “outputs” of shape [num_nodes, num_predictions].