e3j.linen.LinearIndexwise

class e3j.linen.LinearIndexwise(source_irreps: str, target_irreps: str, num_indices: int, num_channels: int | None = None, layout: str | ~e3j.utils.options.Layout = <factory>, kernel_init: ~jax.nn.initializers.Initializer | ~e3j.utils.options.LinearIndexwiseInitialization | str = 'FAN_IN', rescale_gradients: bool = True, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

E3-equivariant species-wise linear mixing.

Weights are learned for each of the C independent channels and I distinct indices, where C is inferred during init.

Each linear block acts on a reducible representation of constant angular momentum l and parity p. This means that for every (l, p) pair, the module carries a weight array

w_lp : (I, C, m_out, m_in)

where m_in and m_out denote the multiplicities of (l, p) in the source and target representations respectively.

Source and target irreps may have mismatching (l, p) blocks: target blocks absent from source produce zero outputs, and source blocks absent from target are discarded. This is useful e.g. for skip-connections from a first scalar-only layer in MLIPs, where the source space is "Kx0e" and the target includes higher-order irreps.
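The block structure described above can be sketched in plain Python. The helpers parse_irreps and plan_blocks are hypothetical and not part of e3j; they assume irreps strings of the form "8x0e + 4x1o" and only illustrate which weight shape each (l, p) pair would get. Whether the module really stores a zero-size weight for a block missing on one side is an assumption.

```python
import re


def parse_irreps(irreps: str) -> dict:
    """Map (l, p) -> multiplicity for a string such as "8x0e + 4x1o".

    Hypothetical helper for illustration; not part of e3j.
    """
    mults = {}
    for term in irreps.split("+"):
        match = re.fullmatch(r"\s*(?:(\d+)x)?(\d+)([eo])\s*", term)
        if match is None:
            raise ValueError(f"cannot parse irrep term: {term!r}")
        mul, l, p = match.groups()
        key = (int(l), p)
        mults[key] = mults.get(key, 0) + int(mul or 1)
    return mults


def plan_blocks(source: str, target: str, num_indices: int, num_channels: int) -> dict:
    """Return the (I, C, m_out, m_in) weight shape for every (l, p) pair."""
    src, tgt = parse_irreps(source), parse_irreps(target)
    plan = {}
    for key in sorted(set(src) | set(tgt)):
        # A multiplicity of 0 marks a block missing on that side.
        m_in, m_out = src.get(key, 0), tgt.get(key, 0)
        plan[key] = (num_indices, num_channels, m_out, m_in)
    return plan


# Scalar-only source feeding a target with higher-order irreps.
plan = plan_blocks("16x0e", "8x0e + 4x1o", num_indices=3, num_channels=2)
for (l, p), shape in plan.items():
    print(f"l={l}, p={p}: w_lp shape {shape}")
```

In this sketch the (1, o) block has m_in = 0, so it would yield zero outputs, matching the skip-connection use case mentioned above.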

Note

Iteration over slices is written as a simple list comprehension, but the jit compiler may still parallelise it. This is also what jax.tree.map ends up being lowered to, see: https://github.com/google/jax/issues/11394

Methods

__call__(x_feats, x_indices)

Transform x_feats linearly with x_indices-dependent weights.

__init__(source_irreps, target_irreps, ...)

apply(variables, *args[, rngs, method, ...])

Applies a module method to variables and returns output and modified variables.

bind(variables, *args[, rngs, mutable])

Creates an interactive Module instance by binding variables and RNGs.

clone(*[, parent, _deep_clone, _reset_names])

Creates a clone of this Module, with optionally updated arguments.

copy(*[, parent, name])

Creates a copy of this Module, with optionally updated arguments.

get_variable(col, name[, default])

Retrieves the value of a Variable.

has_rng(name)

Returns true if a PRNGSequence with name name exists.

has_variable(col, name)

Checks if a variable of given collection and name exists in this Module.

init(rngs, *args[, method, mutable, ...])

Initializes a module method with variables and returns modified variables.

init_with_output(rngs, *args[, method, ...])

Initializes a module method with variables and returns output and modified variables.

is_initializing()

Returns True if running under self.init(...) or nn.init(...)().

is_mutable_collection(col)

Returns true if the collection col is mutable.

join_outputs(ys)

Concatenate outputs, restoring the channel axis layout.

lazy_init(rngs, *args[, method, mutable])

Initializes a module without computing on an actual input.

make_rng([name])

Returns a new RNG key from a given RNG sequence for this Module.

module_paths(rngs, *args[, show_repeated, ...])

Returns a dictionary mapping module paths to module instances.

param(name, init_fn, *init_args[, unbox])

Declares and returns a parameter in this Module.

perturb(name, value[, collection])

Add a zero-value variable ('perturbation') to the intermediate value.

put_variable(col, name, value)

Updates the value of the given variable if it is mutable, or raises an error otherwise.

setup()

Initialize weights from list of LinearBlock descriptors.

slice_inputs(x)

Prepare PyTree of inputs, sliced by irrep.

slice_transform(x_lp, w_lp, block, batch_size)

Mix irreps on a constant-(l,p) block.

slice_weights(indices)

Return species-wise weights acting on x_feats.

sow(col, name, value[, reduce_fn, init_fn])

Stores a value in a collection.

tabulate(rngs, *args[, depth, ...])

Creates a summary of the Module represented as a table.

unbind()

Returns an unbound copy of a Module and its variables.

variable(col, name[, init_fn, unbox])

Declares and returns a variable in this Module.

Attributes

blocks

Return list of LinearBlock descriptors acting on irreducibles.

kernel_init

name

num_channels

parent

path

Get the path of this Module.

rescale_gradients

scope

source

target

variables

Returns the variables in this module.

source_irreps

target_irreps

num_indices

layout

__call__(x_feats: Array, x_indices: Array) → Array

Transform x_feats linearly with x_indices-dependent weights.

Parameters:
  • x_feats (jnp.ndarray) – (N, C, source.dim)-array of equivariant features (leading channels) or (N, source.dim, C) (trailing channels).

  • x_indices (jnp.ndarray) – N-vector of species indices (non-negative and lower than I = num_indices)

Returns:

Array of shape (N, C, target.dim) or (N, target.dim, C) depending on layout.

Return type:

jnp.ndarray
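As a rough illustration of the shapes involved, the following NumPy sketch reproduces the index-dependent mixing for a scalar-only case (a single l = 0 block, so source.dim = m_in and target.dim = m_out). It does not call e3j; the variable names and the einsum formulation are assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
N, C, I = 5, 2, 3        # batch size, channels, number of indices
m_in, m_out = 4, 3       # scalar multiplicities: source.dim = 4, target.dim = 3

w = rng.normal(size=(I, C, m_out, m_in))   # one weight matrix per index and channel
x_feats = rng.normal(size=(N, C, m_in))    # leading-channels layout
x_indices = np.array([0, 2, 1, 0, 2])      # 0 <= index < I

# Gather one weight matrix per sample, then mix multiplicities.
y_leading = np.einsum("ncoi,nci->nco", w[x_indices], x_feats)       # (N, C, target.dim)

# Trailing-channels layout: same computation with the last two axes swapped.
x_trailing = np.swapaxes(x_feats, -1, -2)                           # (N, source.dim, C)
y_trailing = np.einsum("ncoi,nic->noc", w[x_indices], x_trailing)   # (N, target.dim, C)
```

Both layouts carry the same numbers; only the position of the channel axis differs.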

__init__(source_irreps: str, target_irreps: str, num_indices: int, num_channels: int | None = None, layout: str | ~e3j.utils.options.Layout = <factory>, kernel_init: ~jax.nn.initializers.Initializer | ~e3j.utils.options.LinearIndexwiseInitialization | str = 'FAN_IN', rescale_gradients: bool = True, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None) → None

property blocks: list[LinearBlock]

Return list of LinearBlock descriptors acting on irreducibles.

Source and target may have mismatching (l, p) blocks. Target blocks absent from source produce zero outputs, which is useful e.g. for skip-connections from a first scalar-only layer in MLIPs.

join_outputs(ys: list[Array]) → Array

Concatenate outputs, restoring the channel axis layout.
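A minimal NumPy sketch of what such a join could look like, assuming each per-block output has shape (N, C, width) and that restoring the layout amounts to swapping the last two axes. The function join_blocks and its trailing_channels flag are hypothetical, not the e3j implementation.

```python
import numpy as np


def join_blocks(ys, trailing_channels=False):
    """Concatenate per-block outputs of shape (N, C, width) along the feature axis."""
    y = np.concatenate(ys, axis=-1)   # (N, C, target.dim)
    # Restore the trailing-channels layout if that is what the caller uses.
    return np.swapaxes(y, -1, -2) if trailing_channels else y


N, C = 5, 2
# The zero-width array stands for a block with m_out == 0; it vanishes on concatenation.
ys = [np.ones((N, C, 8)), np.zeros((N, C, 0)), np.zeros((N, C, 12))]
y_leading = join_blocks(ys)                            # (N, C, 20)
y_trailing = join_blocks(ys, trailing_channels=True)   # (N, 20, C)
```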

setup()

Initialize weights from list of LinearBlock descriptors.

slice_inputs(x: Array) → list[Array]

Prepare PyTree of inputs, sliced by irrep.

Returns a list of arrays with shape (..., C, m_in * (2l+1)) (leading channels) regardless of input layout. For trailing channels, an axis swap is applied after slicing.
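The slicing can be pictured with the NumPy sketch below. It assumes the blocks are stored contiguously along the feature axis in the order given; slice_by_irrep and its list of (l, m_in) pairs are hypothetical stand-ins for the LinearBlock descriptors.

```python
import numpy as np


def slice_by_irrep(x, blocks, trailing_channels=False):
    """Split the feature axis into per-block slices of width m_in * (2l + 1).

    `blocks` is a list of (l, m_in) pairs in storage order (hypothetical).
    """
    axis = -2 if trailing_channels else -1
    widths = [m_in * (2 * l + 1) for l, m_in in blocks]
    assert x.shape[axis] == sum(widths), "feature axis does not match the blocks"
    stops = np.cumsum(widths)
    slices = np.split(x, stops[:-1], axis=axis)
    if trailing_channels:
        # Swap after slicing so every slice ends up with leading channels.
        slices = [np.swapaxes(s, -1, -2) for s in slices]
    return slices


blocks = [(0, 4), (1, 2)]               # 4 scalars and 2 vectors: 4 + 2 * 3 = 10
x_leading = np.zeros((5, 2, 10))        # (N, C, source.dim)
x_trailing = np.zeros((5, 10, 2))       # (N, source.dim, C)
parts_leading = slice_by_irrep(x_leading, blocks)
parts_trailing = slice_by_irrep(x_trailing, blocks, trailing_channels=True)
```

Either input layout gives slices of shape (5, 2, 4) and (5, 2, 6) in this example.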

slice_transform(x_lp: Array, w_lp: Array, block: LinearBlock, batch_size: int, num_channels: int | None = None) → Array

Mix irreps on a constant-(l,p) block.

The mixing of multiplicities is performed with np.matmul between:

  • w_lp : (N, C, m_out, m_in), the array of species-dependent weights,

  • x_lp : (N, C, m_in * (2l + 1)), the array of equivariant features,

where N is the batch size and C the number of independent channels.

Note

Apart from the normal case m_in * m_out > 0, two different edge cases should be considered:

  • if m_out > 0 and m_in == 0, return zero vector with requested multiplicity (in contrast with e3nn).

  • if m_out == 0, return empty vector to be discarded during concatenation.

It turns out a single branch is enough for both cases, although they are morally different. They require batch_size and num_channels to be passed explicitly.
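The mixing and its edge cases can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the e3j code: mix_block is a hypothetical name, and the feature axis is assumed to reshape to (m_in, 2l + 1), which may differ from the library's actual memory layout.

```python
import numpy as np


def mix_block(x_lp, w_lp, l, batch_size, num_channels):
    """Mix multiplicities on one constant-(l, p) block.

    x_lp: (N, C, m_in * (2l + 1)), w_lp: (N, C, m_out, m_in).
    """
    m_out, m_in = w_lp.shape[-2:]
    dim = 2 * l + 1
    if m_in == 0 or m_out == 0:
        # One branch for both edge cases: zeros of width m_out * dim,
        # which is an empty array when m_out == 0.
        return np.zeros((batch_size, num_channels, m_out * dim), dtype=x_lp.dtype)
    x = x_lp.reshape(batch_size, num_channels, m_in, dim)
    y = np.matmul(w_lp, x)   # (N, C, m_out, dim)
    return y.reshape(batch_size, num_channels, m_out * dim)


rng = np.random.default_rng(0)
N, C, l = 5, 2, 1
# Normal case: m_in = 2, m_out = 4.
y = mix_block(rng.normal(size=(N, C, 6)), rng.normal(size=(N, C, 4, 2)), l, N, C)
# m_in == 0: zero output with the requested multiplicity.
y_zero = mix_block(np.zeros((N, C, 0)), np.zeros((N, C, 4, 0)), l, N, C)
# m_out == 0: empty output, dropped during concatenation.
y_empty = mix_block(rng.normal(size=(N, C, 6)), np.zeros((N, C, 0, 2)), l, N, C)
```

Because the weights only act on the multiplicity axis, the 2l + 1 components of each irrep are left untouched, which is what keeps the operation equivariant.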

slice_weights(indices: Array) → list[Array]

Return species-wise weights acting on x_feats.

The learnable weight arrays, each of shape (I, C, m_out, m_in), are indexed by the length-N integer vector indices (all entries lower than I) to produce batches of weight arrays of shape (N, C, m_out, m_in).

The I/O multiplicities depend both on momentum l and parity p.
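In NumPy terms, the gather could look like the sketch below, using the (I, C, m_out, m_in) ordering from the class description. The two example blocks and their multiplicities are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
I, C = 3, 2
# One weight array per (l, p) block; multiplicities differ from block to block.
weights = [
    rng.normal(size=(I, C, 8, 16)),   # e.g. an l = 0 block with m_out = 8, m_in = 16
    np.zeros((I, C, 4, 0)),           # e.g. an l = 1 block missing from the source
]
indices = np.array([0, 2, 1, 0, 2])   # length N, every entry in [0, I)
assert ((indices >= 0) & (indices < I)).all()

# Integer indexing on the leading axis gathers one weight set per sample.
batched = [w[indices] for w in weights]   # shapes (N, C, m_out, m_in)
```

Samples that share an index (here samples 0 and 3) receive identical weights.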