e3j.linen.Linear

class e3j.linen.Linear(source_irreps: str, target_irreps: str, channels: tuple[int, int] = (1, 1), layout: str | ~e3j.utils.options.Layout = <factory>, kernel_init: ~jax.nn.initializers.Initializer | ~e3j.utils.options.LinearInitialization | str = 'FAN_IN', rescale_gradients: bool = True, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

E3-equivariant linear channel mixing.

The Linear module mixes equivariant channels with the same angular momentum l by:

  • iterating over input slices indexed by l,

  • reshaping each slice x_l from shape (-1, (2l+1) * m_in) to shape (-1, m_in),

  • linearly transforming slices as y_l = x_l @ weights_l,

  • reshaping output slices y_l from shape (-1, m_out) to shape (-1, (2l+1) * m_out),

  • concatenating output slices on axis -1.
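The steps above can be sketched in plain JAX as follows. This is a minimal illustration, not the e3j implementation: irreps are passed as hypothetical lists of `(l, m)` pairs with the same `l` order in source and target, weights are plain `(m_in, m_out)` matrices, and the multiplicity index is assumed to be innermost in each slice, which is what the reshape to `(-1, m_in)` implies.

```python
import jax
import jax.numpy as jnp


def linear_mix(x, source, target, weights):
    """Mix channels of equal l (illustrative sketch, not the e3j API).

    x:       (N, sum_l (2l+1) * m_in_l), multiplicity index innermost.
    source:  hypothetical list of (l, m_in) pairs.
    target:  hypothetical list of (l, m_out) pairs, same l order as source.
    weights: list of (m_in, m_out) matrices, one per l.
    """
    n = x.shape[0]
    outputs = []
    offset = 0
    for (l, m_in), (l_out, m_out), w in zip(source, target, weights):
        assert l == l_out, "source and target must list the same degrees"
        d = 2 * l + 1
        # Take the slice of degree l: shape (N, d * m_in).
        x_l = x[:, offset:offset + d * m_in]
        offset += d * m_in
        # Reshape to (N * d, m_in) so all 2l+1 components share one matmul.
        y_l = x_l.reshape(n * d, m_in) @ w
        # Restore shape (N, d * m_out).
        outputs.append(y_l.reshape(n, d * m_out))
    return jnp.concatenate(outputs, axis=-1)


# Usage: two scalars and three vectors mixed into four scalars and one vector.
source, target = [(0, 2), (1, 3)], [(0, 4), (1, 1)]
k_x, k_0, k_1 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (5, 2 + 3 * 3))
weights = [jax.random.normal(k_0, (2, 4)), jax.random.normal(k_1, (3, 1))]
y = linear_mix(x, source, target, weights)
print(y.shape)  # (5, 7)
```

Because the weights only touch the multiplicity axis, a rotation acting on the 2l+1 component axis commutes with the mixing, which is why the operation is equivariant.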

Iteration over slices is not parallelised by jax.tree.map.

Methods

__call__(x)

Transform equivariant tensors linearly.

__init__(source_irreps, target_irreps[, ...])

apply(variables, *args[, rngs, method, ...])

Applies a module method to variables and returns output and modified variables.

bind(variables, *args[, rngs, mutable])

Creates an interactive Module instance by binding variables and RNGs.

clone(*[, parent, _deep_clone, _reset_names])

Creates a clone of this Module, with optionally updated arguments.

copy(*[, parent, name])

Creates a copy of this Module, with optionally updated arguments.

get_variable(col, name[, default])

Retrieves the value of a Variable.

has_rng(name)

Returns true if a PRNGSequence with name name exists.

has_variable(col, name)

Checks if a variable of given collection and name exists in this Module.

init(rngs, *args[, method, mutable, ...])

Initializes a module method with variables and returns modified variables.

init_with_output(rngs, *args[, method, ...])

Initializes a module method with variables and returns output and modified variables.

is_initializing()

Returns True if running under self.init(...) or nn.init(...)().

is_mutable_collection(col)

Returns true if the collection col is mutable.

join_outputs(ys)

Concatenate outputs, restoring the channel axis if needed.

lazy_init(rngs, *args[, method, mutable])

Initializes a module without computing on an actual input.

make_rng([name])

Returns a new RNG key from a given RNG sequence for this Module.

module_paths(rngs, *args[, show_repeated, ...])

Returns a dictionary mapping module paths to module instances.

param(name, init_fn, *init_args[, unbox])

Declares and returns a parameter in this Module.

perturb(name, value[, collection])

Add a zero-value variable ('perturbation') to the intermediate value.

put_variable(col, name, value)

Updates the value of the given variable if it is mutable, or an error otherwise.

setup()

Initializes a Module lazily (similar to a lazy __init__).

slice_inputs(x)

Prepare PyTree of inputs, sliced by irrep.

slice_transform(x_lp, w_lp, block, batch_size)

Mix irreps on a constant-(l,p) block.

slice_weights()

Return arrays of weights acting on all degrees.

sow(col, name, value[, reduce_fn, init_fn])

Stores a value in a collection.

tabulate(rngs, *args[, depth, ...])

Creates a summary of the Module represented as a table.

unbind()

Returns an unbound copy of a Module and its variables.

variable(col, name[, init_fn, unbox])

Declares and returns a variable in this Module.

Attributes

blocks

channels

kernel_init

name

parent

path

Get the path of this Module.

rescale_gradients

scope

source

target

variables

Returns the variables in this module.

source_irreps

target_irreps

layout

__call__(x: Array) → Array

Transform equivariant tensors linearly.

__init__(source_irreps: str, target_irreps: str, channels: tuple[int, int] = (1, 1), layout: str | ~e3j.utils.options.Layout = <factory>, kernel_init: ~jax.nn.initializers.Initializer | ~e3j.utils.options.LinearInitialization | str = 'FAN_IN', rescale_gradients: bool = True, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None) → None

join_outputs(ys: list[Array]) → Array

Concatenate outputs, restoring the channel axis if needed.
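The docstring does not specify when the channel axis is restored. The sketch below is hypothetical: it assumes the per-block outputs share a flattened leading axis of size batch * num_channels, and that passing a (hypothetical) `num_channels` argument triggers the reshape. Empty blocks with zero width simply vanish in the concatenation.

```python
import jax.numpy as jnp


def join_outputs_sketch(ys, num_channels=None):
    """Concatenate per-block outputs on the last axis (hypothetical sketch).

    ys: list of arrays of shape (B, d_i); B = batch * num_channels is an
        assumption about how the channel axis was flattened.
    """
    # Zero-width blocks (m_out == 0) contribute nothing here.
    y = jnp.concatenate(ys, axis=-1)
    if num_channels is not None:
        # Unflatten the leading axis back into (batch, num_channels).
        y = y.reshape(-1, num_channels, y.shape[-1])
    return y


ys = [jnp.ones((6, 1)), jnp.zeros((6, 0)), jnp.ones((6, 3))]
print(join_outputs_sketch(ys, num_channels=2).shape)  # (3, 2, 4)
```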

slice_inputs(x: Array) → list[Array]

Prepare PyTree of inputs, sliced by irrep.

Parameters:

x (Array) – inputs with shape (-1, num_channels, source.dim)

Returns:

List of 2*(l_max+1) ndarrays of shape (-1, num_channels, m_in*(2*l_in+1)).

Return type:

list[Array]
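A possible way to produce one slice per (l, p) block is sketched below. The `(m, l, p)` triples, the requirement that each (l, p) appears at most once, and the ordering (0, +1), (0, -1), (1, +1), ... of the returned list are assumptions for illustration, not the documented e3j behaviour; absent blocks are returned as empty arrays so that the list always has 2*(l_max+1) entries.

```python
import jax.numpy as jnp


def slice_by_irrep(x, irreps, l_max):
    """Split x into one slice per (l, p) block (illustrative sketch).

    x:      (batch, num_channels, dim).
    irreps: hypothetical list of (m, l, p) triples with p in (+1, -1),
            each (l, p) at most once, listed in storage order.
    """
    blocks = {}
    offset = 0
    for m, l, p in irreps:
        size = m * (2 * l + 1)
        blocks[(l, p)] = x[..., offset:offset + size]
        offset += size
    assert offset == x.shape[-1], "irreps do not match the feature dimension"
    # Missing blocks become arrays whose last axis has size 0.
    empty = jnp.zeros(x.shape[:-1] + (0,), dtype=x.dtype)
    return [blocks.get((l, p), empty) for l in range(l_max + 1) for p in (1, -1)]


# Usage: one even scalar and two odd vectors, with 2 channels.
x = jnp.arange(4 * 2 * 7, dtype=jnp.float32).reshape(4, 2, 7)
slices = slice_by_irrep(x, [(1, 0, 1), (2, 1, -1)], l_max=1)
print([s.shape for s in slices])  # [(4, 2, 1), (4, 2, 0), (4, 2, 0), (4, 2, 6)]
```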

slice_transform(x_lp: Array, w_lp: Array, block: LinearBlock, batch_size: int) → Array

Mix irreps on a constant-(l,p) block.

The mixing of multiplicities is performed with np.matmul between:

  • w_lp : (N, m_out, m_in), the array of species-dependent weights,

  • x_lp : (N, m_in * (2l + 1)), the array of equivariant features,

where N is the batch size.
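The batched product can be sketched with `jnp.matmul`, the JAX counterpart of `np.matmul`. The layout (multiplicity index innermost, matching the reshape in the class description), the name `mix_block`, and the per-species lookup table `table` indexed by `species` are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def mix_block(x_lp, w_lp, l):
    """Batched matmul on one (l, p) block (illustrative sketch).

    x_lp: (N, m_in * (2l+1)), multiplicity index assumed innermost.
    w_lp: (N, m_out, m_in), one weight matrix per sample.
    Returns (N, m_out * (2l+1)) in the same layout.
    """
    n, m_out, m_in = w_lp.shape
    d = 2 * l + 1
    x = x_lp.reshape(n, d, m_in)
    # (N, d, m_in) @ (N, m_in, m_out) -> (N, d, m_out)
    y = jnp.matmul(x, jnp.swapaxes(w_lp, -1, -2))
    return y.reshape(n, d * m_out)


# Usage: gather per-sample weights from a hypothetical per-species table.
k_t, k_x = jax.random.split(jax.random.PRNGKey(0))
table = jax.random.normal(k_t, (3, 4, 2))   # (num_species, m_out, m_in)
species = jnp.array([0, 2, 1, 0, 2])        # (N,)
w_lp = table[species]                       # (N, m_out, m_in)
x_lp = jax.random.normal(k_x, (5, 2 * 3))   # l = 1, m_in = 2
y_lp = mix_block(x_lp, w_lp, l=1)           # shape (5, 12)
```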

Note

Apart from the regular case m_in * m_out > 0, two edge cases must be handled, and both require batch_size to be passed explicitly:

  • if m_out > 0 and m_in == 0, return a zero vector with the requested multiplicity (in contrast with e3nn).

  • if m_out == 0, return an empty vector, which is discarded during the final concatenation.

A single branch turns out to be enough for both cases, although they are conceptually different.
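The single branch covering both edge cases can be sketched as below: returning zeros of width (2l+1) * m_out yields the requested zero vector when m_out > 0 and an empty (batch_size, 0) array when m_out == 0. The function name and the multiplicity-innermost layout are assumptions; this is not the e3j source.

```python
import jax.numpy as jnp


def slice_transform_sketch(x_lp, w_lp, l, batch_size):
    """One (l, p) block including the edge cases of the note (sketch).

    x_lp: (batch_size, m_in * (2l+1)), multiplicity index assumed innermost.
    w_lp: (batch_size, m_out, m_in).
    """
    m_out, m_in = w_lp.shape[-2:]
    d = 2 * l + 1
    if m_in == 0 or m_out == 0:
        # One branch for both edge cases: zeros of width d * m_out,
        # i.e. an empty (batch_size, 0) array when m_out == 0.
        return jnp.zeros((batch_size, d * m_out), dtype=x_lp.dtype)
    x = x_lp.reshape(batch_size, d, m_in)
    y = jnp.matmul(x, jnp.swapaxes(w_lp, -1, -2))
    return y.reshape(batch_size, d * m_out)


zeros = slice_transform_sketch(jnp.zeros((4, 0)), jnp.zeros((4, 2, 0)), l=1, batch_size=4)
empty = slice_transform_sketch(jnp.ones((4, 9)), jnp.ones((4, 0, 3)), l=1, batch_size=4)
print(zeros.shape, empty.shape)  # (4, 6) (4, 0)
```

The empty block disappears when the per-block outputs are concatenated on axis -1.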

slice_weights() → list[Array]

Return arrays of weights acting on all degrees.
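A minimal sketch of building one weight matrix per block is shown below, assuming (hypothetically) that the default kernel_init='FAN_IN' means entries drawn from N(0, 1/m_in); the real initialisation and its interaction with rescale_gradients may differ. The `(m_in, m_out)` shape follows y_l = x_l @ weights_l from the class description, and degenerate blocks receive empty arrays.

```python
import jax
import jax.numpy as jnp


def init_block_weights(key, blocks):
    """One weight array per block with fan-in scaling (sketch).

    blocks: hypothetical list of (m_in, m_out) multiplicity pairs.
    """
    keys = jax.random.split(key, len(blocks))
    weights = []
    for k, (m_in, m_out) in zip(keys, blocks):
        if m_in == 0 or m_out == 0:
            # Degenerate block: keep the shape, no parameters to draw.
            weights.append(jnp.zeros((m_in, m_out)))
            continue
        # Assumed FAN_IN convention: standard deviation 1 / sqrt(m_in).
        weights.append(m_in ** -0.5 * jax.random.normal(k, (m_in, m_out)))
    return weights


ws = init_block_weights(jax.random.PRNGKey(0), [(400, 50), (0, 3), (2, 0)])
print([w.shape for w in ws])  # [(400, 50), (0, 3), (2, 0)]
```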