Readout blocks
- class mlip.models.readout.MultiHeadReadoutBlock(num_heads: int, features: typing.Sequence[int | e3nn_jax._src.irreps.Irreps], activation: typing.Callable | None = None, mlp_kernel_init: jax.nn.initializers.Initializer | mlip.models.options.GradientScaledKernelInit | None = None, use_equiv: bool = False, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Multi-head readout block that can be used in multiple models.
This block shares most of its attributes with the single-head
ReadoutBlock, but it has the additional num_heads attribute. It applies the readout block independently for each head.
- num_heads
The number of readout heads.
- Type:
int
- features
The feature output dimensions for each linear layer, as a sequence of either integers or irreps (in the equivariant case).
- Type:
Sequence[int | e3nn_jax._src.irreps.Irreps]
- activation
The activation function to use. Default is None, which means no activation is applied.
- Type:
Callable | None
- mlp_kernel_init
The kernel initialization method to use if using a non-equivariant MLP. Default is None.
- Type:
jax.nn.initializers.Initializer | mlip.models.options.GradientScaledKernelInit | None
- use_equiv
Whether to use equivariant linear layers, in which case the features are assumed to be irreps. Default is False.
- Type:
bool
- __call__(graph: Graph) → Graph
Call function for this block.
Note that the outputs of this block have the shape
[num_nodes, num_heads, num_final_readout_layer_output_features].
- Parameters:
graph – The input graph. Its features should be in
graph.nodes.features["latent"].
- Returns:
The output graph with the resulting node readout outputs in
graph.nodes.features["outputs"].
- Return type:
Graph
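To make the [num_nodes, num_heads, ...] output shape concrete, here is a minimal NumPy sketch of a multi-head, non-equivariant readout. It assumes that each head owns independent weight matrices (stored with a leading num_heads axis), omits biases, and applies the activation only between layers; the function name multi_head_readout and this weight layout are hypothetical and are not the mlip implementation.

```python
import numpy as np

# Hypothetical sketch of a multi-head, non-equivariant readout (not the mlip API).
# Each weight array has shape [num_heads, d_in, d_out]: every head owns its weights.
# Biases are omitted for brevity.
def multi_head_readout(latent, head_weights, activation=None):
    # First layer: every head reads the same shared latent features.
    x = np.einsum("nf,hfd->nhd", latent, head_weights[0])
    for w in head_weights[1:]:
        if activation is not None:
            x = activation(x)  # applied only between layers, never after the last
        # Batched matmul: head h multiplies by its own matrix w[h].
        x = np.einsum("nhf,hfd->nhd", x, w)
    return x  # [num_nodes, num_heads, features[-1]]


rng = np.random.default_rng(0)
num_nodes, num_heads, latent_dim = 5, 3, 8
features = [16, 1]  # output size of each linear layer
dims = [latent_dim, *features]
head_weights = [
    rng.normal(size=(num_heads, d_in, d_out))
    for d_in, d_out in zip(dims[:-1], dims[1:])
]
latent = rng.normal(size=(num_nodes, latent_dim))
out = multi_head_readout(latent, head_weights, activation=np.tanh)
print(out.shape)  # (5, 3, 1)
```

Under these assumptions, each slice out[:, h, :] is an ordinary single-head readout computed with head h's weights, which is one way to read "applies the readout block independently for each head".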
- class mlip.models.readout.ReadoutBlock(features: typing.Sequence[int | e3nn_jax._src.irreps.Irreps], activation: typing.Callable | None = None, mlp_kernel_init: jax.nn.initializers.Initializer | mlip.models.options.GradientScaledKernelInit | None = None, use_equiv: bool = False, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Readout block that can be used in multiple models.
The block has only a single head; see
MultiHeadReadoutBlock for the multi-head version. The readout consists of a stack of linear layers (equivariant or not), one per entry in features, with activations in between if requested.
- features
The feature output dimensions for each linear layer, as a sequence of either integers or irreps (in the equivariant case).
- Type:
Sequence[int | e3nn_jax._src.irreps.Irreps]
- activation
The activation function to use. Default is None, which means no activation is applied.
- Type:
Callable | None
- mlp_kernel_init
The kernel initialization method to use if using a non-equivariant MLP. Default is None.
- Type:
jax.nn.initializers.Initializer | mlip.models.options.GradientScaledKernelInit | None
- use_equiv
Whether to use equivariant linear layers, in which case the features are assumed to be irreps. Default is False.
- Type:
bool
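The following is a minimal NumPy sketch of the non-equivariant path (use_equiv=False) described above, assuming the linear layers are plain bias-free matrix multiplications: one weight matrix per entry in features, with the activation applied only between consecutive layers. The function name readout and the explicit weight list are hypothetical; the block's actual kernel initialization via mlp_kernel_init is not modeled here.

```python
import numpy as np

# Hypothetical sketch of a single-head, non-equivariant readout (not the mlip API).
# Biases are omitted; each weight matrix has shape [d_in, d_out].
def readout(latent, weights, activation=None):
    x = latent  # [num_nodes, latent_dim]
    for i, w in enumerate(weights):
        # The activation goes between layers, so it is never applied after the last one.
        if i > 0 and activation is not None:
            x = activation(x)
        x = x @ w
    return x  # [num_nodes, features[-1]]


rng = np.random.default_rng(0)
num_nodes, latent_dim = 4, 8
features = [16, 1]  # output size of each linear layer
dims = [latent_dim, *features]
weights = [rng.normal(size=(d_in, d_out)) for d_in, d_out in zip(dims[:-1], dims[1:])]
latent = rng.normal(size=(num_nodes, latent_dim))
out = readout(latent, weights, activation=np.tanh)
print(out.shape)  # (4, 1)
```

With features=[16, 1], the final dimension of the output is features[-1] = 1, i.e. one prediction per node.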
- mlip.models.readout.select_head(graph: Graph) → Graph
Select a readout head per node based on graph.globals.dataset_idx.
Expects the graph to have "outputs" in graph.nodes.features with shape [num_nodes, num_heads, num_predictions]. Selects one head per node and updates "outputs" to shape [num_nodes, num_predictions].
In a batch, different graphs may target different heads. The per-graph dataset_idx is broadcast to per-node values using graph.n_node. When dataset_idx is None, head 0 is used.
- Parameters:
graph – The input graph with the "outputs" feature and globals.dataset_idx.
- Returns:
The graph with "outputs" of shape [num_nodes, num_predictions].
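The head-selection logic described above can be sketched with NumPy on plain arrays. This is an illustrative reconstruction from the description, not the library's code: the function name select_head_np and its array arguments (standing in for graph.nodes.features["outputs"], graph.n_node and graph.globals.dataset_idx) are hypothetical.

```python
import numpy as np

# Hypothetical NumPy sketch of per-node head selection (not the mlip API);
# it works on plain arrays instead of a Graph.
def select_head_np(outputs, n_node, dataset_idx=None):
    """outputs: [num_nodes, num_heads, num_predictions]
    n_node: [num_graphs] number of nodes in each graph of the batch
    dataset_idx: [num_graphs] head index per graph, or None"""
    num_nodes = outputs.shape[0]
    if dataset_idx is None:
        node_head = np.zeros(num_nodes, dtype=np.int64)  # default: head 0
    else:
        # Broadcast each graph's head index to all nodes of that graph.
        node_head = np.repeat(np.asarray(dataset_idx), np.asarray(n_node))
    # Pick outputs[i, node_head[i], :] for every node i.
    return outputs[np.arange(num_nodes), node_head]


# Batch of two graphs with 2 and 3 nodes; outputs[n, h, 0] = 10 * n + h.
outputs = 10 * np.arange(5)[:, None, None] + np.arange(3)[None, :, None]
selected = select_head_np(outputs, n_node=[2, 3], dataset_idx=[2, 0])
print(selected.shape)  # (5, 1)
print(selected[:, 0].tolist())  # [2, 12, 20, 30, 40]
```

The first graph's two nodes read head 2 and the second graph's three nodes read head 0. In a jitted JAX version, jnp.repeat would additionally need a static total_repeat_length (for example the padded node count), because output shapes must be known at trace time.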