e3j.linen.LinearIndexwise
- class e3j.linen.LinearIndexwise(source_irreps: str, target_irreps: str, num_indices: int, num_channels: int | None = None, layout: str | e3j.utils.options.Layout = <factory>, kernel_init: jax.nn.initializers.Initializer | e3j.utils.options.LinearIndexwiseInitialization | str = 'FAN_IN', rescale_gradients: bool = True, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
E3-equivariant species-wise linear mixing.
Weights are learned for each of the C independent channels and I distinct indices, where C is inferred during init.
Each linear block acts on a reducible representation of constant angular momentum l and parity p. This means that for every (l, p) pair, the module carries a weight array
w_lp : (I, C, m_out, m_in)
where m_in and m_out denote the multiplicities of (l, p) in the source and target representations respectively.
Source and target irreps may have mismatching (l, p) blocks: target blocks absent from source produce zero outputs, and source blocks absent from target are discarded. This is useful e.g. for skip-connections from a first scalar-only layer in MLIPs, where the source space is "Kx0e" and the target includes higher-order irreps.
Note
Iteration over slices may be parallelised by the jit compiler, although it is written as a simple list comprehension. This is what jax.tree.map ends up being lowered to anyway, see: https://github.com/google/jax/issues/11394
Methods
__call__(x_feats, x_indices): Transform x_feats linearly with x_indices-dependent weights.
__init__(source_irreps, target_irreps, ...)
apply(variables, *args[, rngs, method, ...]): Applies a module method to variables and returns output and modified variables.
bind(variables, *args[, rngs, mutable]): Creates an interactive Module instance by binding variables and RNGs.
clone(*[, parent, _deep_clone, _reset_names]): Creates a clone of this Module, with optionally updated arguments.
copy(*[, parent, name]): Creates a copy of this Module, with optionally updated arguments.
get_variable(col, name[, default]): Retrieves the value of a Variable.
has_rng(name): Returns true if a PRNGSequence with name name exists.
has_variable(col, name): Checks if a variable of given collection and name exists in this Module.
init(rngs, *args[, method, mutable, ...]): Initializes a module method with variables and returns modified variables.
init_with_output(rngs, *args[, method, ...]): Initializes a module method with variables and returns output and modified variables.
is_initializing(): Returns True if running under self.init(...) or nn.init(...)().
is_mutable_collection(col): Returns true if the collection col is mutable.
join_outputs(ys): Concatenate outputs, restoring the channel axis layout.
lazy_init(rngs, *args[, method, mutable]): Initializes a module without computing on an actual input.
make_rng([name]): Returns a new RNG key from a given RNG sequence for this Module.
module_paths(rngs, *args[, show_repeated, ...]): Returns a dictionary mapping module paths to module instances.
param(name, init_fn, *init_args[, unbox]): Declares and returns a parameter in this Module.
perturb(name, value[, collection]): Adds a zero-value variable ('perturbation') to the intermediate value.
put_variable(col, name, value): Updates the value of the given variable if it is mutable, or raises an error otherwise.
setup(): Initialize weights from list of LinearBlock descriptors.
slice_inputs(x): Prepare PyTree of inputs, sliced by irrep.
slice_transform(x_lp, w_lp, block, batch_size): Mix irreps on a constant-(l,p) block.
slice_weights(indices): Return species-wise weights acting on x_feats.
sow(col, name, value[, reduce_fn, init_fn]): Stores a value in a collection.
tabulate(rngs, *args[, depth, ...]): Creates a summary of the Module represented as a table.
unbind(): Returns an unbound copy of a Module and its variables.
variable(col, name[, init_fn, unbox]): Declares and returns a variable in this Module.
Attributes
blocks: Return list of LinearBlock descriptors acting on irreducibles.
kernel_init
name
num_channels
parent
path: Get the path of this Module.
rescale_gradients
scope
source
target
variables: Returns the variables in this module.
source_irreps
target_irreps
num_indices
layout
- __call__(x_feats: Array, x_indices: Array) → Array
Transform x_feats linearly with x_indices-dependent weights.
- Parameters:
x_feats (jnp.ndarray) – (N, C, source.dim)-array of equivariant features (leading channels) or (N, source.dim, C)-array (trailing channels).
x_indices (jnp.ndarray) – N-vector of species indices (non-negative and lower than I = num_indices).
- Returns:
Array of shape (N, C, target.dim) or (N, target.dim, C) depending on layout.
- Return type:
jnp.ndarray
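For a scalar-only source and target (e.g. "5x0e" to "2x0e"), the transform reduces to one (m_out, m_in) matrix per (index, channel) pair, selected by each sample's species index. The minimal NumPy sketch below illustrates that shape contract for both layouts with hypothetical random features and weights; it is not the e3j implementation, although the same array operations exist in jax.numpy.

```python
import numpy as np

rng = np.random.default_rng(0)
N, C, I = 6, 3, 4   # batch size, channels, number of indices (species)
K, M = 5, 2         # hypothetical source "5x0e", target "2x0e": dim == multiplicity

x_feats = rng.normal(size=(N, C, K))     # leading-channel layout: (N, C, source.dim)
x_indices = rng.integers(0, I, size=N)   # one species index per sample, 0 <= index < I
w = rng.normal(size=(I, C, M, K))        # one (M, K) matrix per (species, channel) pair

w_batch = w[x_indices]                   # gathered weights: (N, C, M, K)
out = np.einsum("ncok,nck->nco", w_batch, x_feats)   # (N, C, target.dim)

# Trailing-channel layout: features (N, source.dim, C), output (N, target.dim, C).
x_trailing = x_feats.swapaxes(1, 2)
out_trailing = np.einsum("ncok,nkc->noc", w_batch, x_trailing)

print(out.shape, out_trailing.shape)   # (6, 3, 2) (6, 2, 3)
```

Both layouts compute the same numbers; only the position of the channel axis differs.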
- __init__(source_irreps: str, target_irreps: str, num_indices: int, num_channels: int | None = None, layout: str | e3j.utils.options.Layout = <factory>, kernel_init: jax.nn.initializers.Initializer | e3j.utils.options.LinearIndexwiseInitialization | str = 'FAN_IN', rescale_gradients: bool = True, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None) → None
- property blocks: list[LinearBlock]
Return list of LinearBlock descriptors acting on irreducibles.
Source and target may have mismatching (l, p) blocks. Target blocks absent from source produce zero outputs, which is useful e.g. for skip-connections from a first scalar-only layer in MLIPs.
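The block bookkeeping can be illustrated with a small pure-Python sketch. It assumes e3nn-style irreps strings (e.g. "4x0e + 3x2e") in which each (l, p) pair appears at most once; Block, parse_irreps and linear_blocks are hypothetical names, not part of e3j, and the real LinearBlock descriptors may store or order blocks differently.

```python
import re
from typing import NamedTuple


class Block(NamedTuple):
    """Hypothetical stand-in for a LinearBlock descriptor."""
    l: int
    p: str      # "e" (even) or "o" (odd)
    m_in: int   # multiplicity of (l, p) in the source irreps
    m_out: int  # multiplicity of (l, p) in the target irreps


def parse_irreps(irreps: str) -> dict[tuple[int, str], int]:
    """Parse a string such as "4x0e + 2x1o" into {(l, p): multiplicity}."""
    mults: dict[tuple[int, str], int] = {}
    for term in irreps.split("+"):
        match = re.fullmatch(r"\s*(\d+)x(\d+)([eo])\s*", term)
        if match is None:
            raise ValueError(f"cannot parse irrep term {term!r}")
        mul, l, p = int(match.group(1)), int(match.group(2)), match.group(3)
        if (l, p) in mults:
            raise ValueError("this sketch assumes each (l, p) appears at most once")
        mults[(l, p)] = mul
    return mults


def linear_blocks(source: str, target: str) -> list[Block]:
    """One block per (l, p) pair present in either the source or the target."""
    src, tgt = parse_irreps(source), parse_irreps(target)
    keys = list(tgt) + [k for k in src if k not in tgt]
    return [Block(l, p, src.get((l, p), 0), tgt.get((l, p), 0)) for l, p in keys]


blocks = linear_blocks("4x0e + 3x2e", "8x0e + 2x1o")
num_indices, num_channels = 5, 16
for b in blocks:
    # Each (l, p) block carries a weight array w_lp of shape (I, C, m_out, m_in).
    print(b, (num_indices, num_channels, b.m_out, b.m_in))
```

Here the 1o block has m_in == 0 (its output is zero) and the 2e block has m_out == 0 (its input is discarded), which are the two mismatch cases described above.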
- join_outputs(ys: list[Array]) → Array
Concatenate outputs, restoring the channel axis layout.
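A minimal NumPy sketch of the joining step, assuming each per-block output is in leading-channel layout (N, C, m_out * (2l + 1)) and that restoring a trailing layout is a single axis swap. join_outputs_sketch is a hypothetical helper, not the e3j method.

```python
import numpy as np


def join_outputs_sketch(ys: list[np.ndarray], trailing_channels: bool) -> np.ndarray:
    """Hypothetical sketch: concatenate per-block outputs of shape (N, C, d_k)."""
    # Blocks with m_out == 0 have d_k == 0 and simply vanish in the concatenation.
    out = np.concatenate(ys, axis=-1)                  # (N, C, target.dim)
    return out.swapaxes(-1, -2) if trailing_channels else out


N, C = 4, 3
ys = [
    np.ones((N, C, 8 * 1)),    # 8x0e: m_out * (2l + 1) = 8
    np.zeros((N, C, 2 * 3)),   # 2x1o with m_in == 0: zero block of size 6
    np.ones((N, C, 0)),        # source-only block, m_out == 0: empty
]
print(join_outputs_sketch(ys, trailing_channels=False).shape)  # (4, 3, 14)
print(join_outputs_sketch(ys, trailing_channels=True).shape)   # (4, 14, 3)
```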
- setup()
Initialize weights from list of LinearBlock descriptors.
- slice_inputs(x: Array) → list[Array]
Prepare PyTree of inputs, sliced by irrep.
Returns a list of arrays with shape (..., C, m_in * (2l+1)), i.e. in leading-channel layout regardless of the input layout. For trailing-channel inputs, the axes are swapped after slicing.
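The slicing can be sketched with np.split on the irreps axis, followed by the axis swap for trailing-channel inputs. This is a minimal sketch assuming the source irreps are laid out contiguously in the order given; slice_inputs_sketch and its (multiplicity, l) input format are hypothetical, and parity is omitted because it does not affect the block sizes.

```python
import numpy as np


def slice_inputs_sketch(x, source, trailing_channels):
    """Hypothetical sketch: split features into one array per source irrep block.

    `source` lists (multiplicity, l) pairs in order; each block spans
    m_in * (2l + 1) entries of the irreps axis.
    """
    sizes = [mul * (2 * l + 1) for mul, l in source]
    split_points = np.cumsum(sizes)[:-1]          # boundaries between consecutive blocks
    irreps_axis = x.ndim - 2 if trailing_channels else x.ndim - 1
    slices = np.split(x, split_points, axis=irreps_axis)
    if trailing_channels:
        # (..., m_in * (2l + 1), C) -> (..., C, m_in * (2l + 1)): swap after slicing.
        slices = [s.swapaxes(-1, -2) for s in slices]
    return slices


N, C = 2, 3
source = [(4, 0), (3, 2)]                         # e.g. "4x0e + 3x2e", source.dim = 4 + 15
x_lead = np.arange(N * C * 19).reshape(N, C, 19)
x_trail = x_lead.swapaxes(1, 2)                   # same features, trailing channels

lead = slice_inputs_sketch(x_lead, source, trailing_channels=False)
trail = slice_inputs_sketch(x_trail, source, trailing_channels=True)
print([s.shape for s in lead])                    # [(2, 3, 4), (2, 3, 15)]
```

Both layouts yield identical slices, so downstream code only has to handle the leading-channel case.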
- slice_transform(x_lp: Array, w_lp: Array, block: LinearBlock, batch_size: int, num_channels: int | None = None) → Array
Mix irreps on a constant-(l,p) block.
The mixing of multiplicities is performed with np.matmul between:
- w_lp : (N, C, m_out, m_in), the array of species-dependent weights,
- x_lp : (N, C, m_in * (2l + 1)), the array of equivariant features,
where N is the batch size and C the number of independent channels.
Note
Apart from the normal case m_in * m_out > 0, two edge cases must be considered:
- if m_out > 0 and m_in == 0, return a zero vector with the requested multiplicity (in contrast with e3nn);
- if m_out == 0, return an empty vector, to be discarded during concatenation.
It turns out a single branch is enough for both cases, although they are morally different. Both require batch_size and num_channels to be passed explicitly.
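A minimal NumPy sketch of the block transform, including the single edge-case branch. It assumes multiplicity-major ordering inside x_lp (copy k occupies entries k*(2l+1) to (k+1)*(2l+1)), so that x_lp can be viewed as (N, C, m_in, 2l + 1) before the matmul; slice_transform_sketch and its explicit l, m_in, m_out arguments are hypothetical stand-ins for the block descriptor.

```python
import numpy as np


def slice_transform_sketch(x_lp, w_lp, l, m_in, m_out, batch_size, num_channels):
    """Hypothetical sketch of mixing multiplicities on one constant-(l, p) block.

    x_lp: (N, C, m_in * (2l + 1)) features, w_lp: (N, C, m_out, m_in) weights.
    Returns an array of shape (N, C, m_out * (2l + 1)).
    """
    d = 2 * l + 1
    if m_in == 0 or m_out == 0:
        # One branch covers both edge cases: zeros when m_in == 0 < m_out, and an
        # empty (N, C, 0) array when m_out == 0. No input shape is available to
        # infer from, hence the explicit batch_size and num_channels.
        return np.zeros((batch_size, num_channels, m_out * d))
    # Assumed multiplicity-major ordering: view features as (N, C, m_in, 2l + 1).
    x = x_lp.reshape(batch_size, num_channels, m_in, d)
    y = np.matmul(w_lp, x)   # (N, C, m_out, m_in) @ (N, C, m_in, d) -> (N, C, m_out, d)
    return y.reshape(batch_size, num_channels, m_out * d)


rng = np.random.default_rng(0)
N, C, l, m_in, m_out = 3, 2, 1, 4, 5
x_lp = rng.normal(size=(N, C, m_in * (2 * l + 1)))
w_lp = rng.normal(size=(N, C, m_out, m_in))

y = slice_transform_sketch(x_lp, w_lp, l, m_in, m_out, N, C)
zeros = slice_transform_sketch(None, None, 2, 0, 3, N, C)      # target-only block
empty = slice_transform_sketch(x_lp, None, l, m_in, 0, N, C)   # source-only block
print(y.shape, zeros.shape, empty.shape)   # (3, 2, 15) (3, 2, 15) (3, 2, 0)
```

Because the weights only act on the multiplicity axis and never mix the 2l + 1 components, rotating the components of every input copy rotates every output copy in the same way, which is what keeps the block equivariant.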
- slice_weights(indices: Array) → list[Array]
Return species-wise weights acting on x_feats.
The learnable weight arrays, each of shape (I, C, m_out, m_in), are indexed along their first axis by the length-N integer vector x_indices (each entry lower than I) to produce batches of weight arrays of shape (N, C, m_out, m_in).
The I/O multiplicities depend on both the angular momentum l and the parity p.
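The gather amounts to fancy indexing along the leading axis of each block's weight array. This is a minimal NumPy sketch with hypothetical shapes; slice_weights_sketch is not the e3j method, and the list-of-arrays container is an assumption.

```python
import numpy as np


def slice_weights_sketch(weights, indices):
    """Hypothetical sketch: gather per-sample weights along the leading (index) axis.

    Each entry of `weights` has shape (I, C, m_out, m_in); each output has shape
    (N, C, m_out, m_in) with N = len(indices).
    """
    return [w[indices] for w in weights]


rng = np.random.default_rng(0)
I, C = 3, 2
weights = [
    rng.normal(size=(I, C, 8, 4)),   # e.g. 4x0e -> 8x0e
    rng.normal(size=(I, C, 2, 0)),   # target-only block, m_in == 0
]
x_indices = np.array([0, 2, 2, 1, 0])   # N = 5, every entry in [0, I)

batched = slice_weights_sketch(weights, x_indices)
print([w.shape for w in batched])   # [(5, 2, 8, 4), (5, 2, 2, 0)]
```

With JAX arrays, the same comprehension could be written as jax.tree.map(lambda w: w[indices], weights); per the Note in the class description, both forms lower to the same per-slice computation.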