e3j.linen.Linear
- class e3j.linen.Linear(source_irreps: str, target_irreps: str, channels: tuple[int, int] = (1, 1), layout: str | ~e3j.utils.options.Layout = <factory>, kernel_init: ~jax.nn.initializers.Initializer | ~e3j.utils.options.LinearInitialization | str = 'FAN_IN', rescale_gradients: bool = True, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
E3-equivariant linear channel mixing.
The Linear module mixes equivariant channels with the same angular momentum l by:
  - iterating over input slices indexed by l,
  - reshaping each slice x_l from shape (-1, (2l+1) * m_in) to shape (-1, m_in),
  - linearly transforming slices as y_l = x_l @ weights_l,
  - reshaping output slices y_l from shape (-1, m_out) to shape (-1, (2l+1) * m_out),
  - concatenating output slices on axis -1.
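The five steps above can be sketched in NumPy (jax.numpy exposes the same reshape, matmul and concatenate calls). This is a minimal illustration, not the e3j implementation: the hypothetical `blocks` argument lists `(l, m_in, m_out)` for each slice, `weights` holds one `(m_in, m_out)` matrix per slice, and each slice is assumed to store its multiplicity index fastest, which is the layout that the reshape to `(-1, m_in)` implies.

```python
import numpy as np


def linear_mix(x, blocks, weights):
    """Mix channels of equal angular momentum l (illustrative sketch).

    x:       (batch, sum_l (2l+1) * m_in) input features.
    blocks:  hypothetical list of (l, m_in, m_out) tuples, in slice order.
    weights: list of (m_in, m_out) matrices, one per block.
    """
    batch = x.shape[0]
    outputs = []
    start = 0
    for (l, m_in, m_out), w in zip(blocks, weights):
        dim = 2 * l + 1
        # Iterate over input slices indexed by l.
        x_l = x[:, start:start + dim * m_in]
        start += dim * m_in
        # (batch, (2l+1) * m_in) -> (batch * (2l+1), m_in)
        x_l = x_l.reshape(-1, m_in)
        # The same matrix acts on every one of the 2l+1 components.
        y_l = x_l @ w
        # (batch * (2l+1), m_out) -> (batch, (2l+1) * m_out)
        outputs.append(y_l.reshape(batch, dim * m_out))
    # Concatenate output slices on the last axis.
    return np.concatenate(outputs, axis=-1)


rng = np.random.default_rng(0)
blocks = [(0, 2, 3), (1, 2, 1)]
weights = [rng.normal(size=(2, 3)), rng.normal(size=(2, 1))]
x = rng.normal(size=(4, 1 * 2 + 3 * 2))
print(linear_mix(x, blocks, weights).shape)  # (4, 6)
```

Because a single matrix multiplies all 2l+1 components of a slice, the mixing commutes with any rotation acting on the component index, which is what keeps the layer equivariant.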
Iteration over slices is done with jax.tree.map and is not parallelised.
Methods
__call__(x): Transform equivariant tensors linearly.
__init__(source_irreps, target_irreps[, ...])
apply(variables, *args[, rngs, method, ...]): Applies a module method to variables and returns output and modified variables.
bind(variables, *args[, rngs, mutable]): Creates an interactive Module instance by binding variables and RNGs.
clone(*[, parent, _deep_clone, _reset_names]): Creates a clone of this Module, with optionally updated arguments.
copy(*[, parent, name]): Creates a copy of this Module, with optionally updated arguments.
get_variable(col, name[, default]): Retrieves the value of a Variable.
has_rng(name): Returns true if a PRNGSequence with the given name exists.
has_variable(col, name): Checks if a variable of given collection and name exists in this Module.
init(rngs, *args[, method, mutable, ...]): Initializes a module method with variables and returns modified variables.
init_with_output(rngs, *args[, method, ...]): Initializes a module method with variables and returns output and modified variables.
is_initializing(): Returns True if running under self.init(...) or nn.init(...)().
is_mutable_collection(col): Returns true if the collection col is mutable.
join_outputs(ys): Concatenate outputs, restoring the channel axis if needed.
lazy_init(rngs, *args[, method, mutable]): Initializes a module without computing on an actual input.
make_rng([name]): Returns a new RNG key from a given RNG sequence for this Module.
module_paths(rngs, *args[, show_repeated, ...]): Returns a dictionary mapping module paths to module instances.
param(name, init_fn, *init_args[, unbox]): Declares and returns a parameter in this Module.
perturb(name, value[, collection]): Add a zero-value variable ('perturbation') to the intermediate value.
put_variable(col, name, value): Updates the value of the given variable if it is mutable, or raises an error otherwise.
setup(): Initializes a Module lazily (similar to a lazy __init__).
slice_inputs(x): Prepare a PyTree of inputs, sliced by irrep.
slice_transform(x_lp, w_lp, block, batch_size): Mix irreps on a constant-(l,p) block.
slice_weights(): Return arrays of weights acting on all degrees.
sow(col, name, value[, reduce_fn, init_fn]): Stores a value in a collection.
tabulate(rngs, *args[, depth, ...]): Creates a summary of the Module represented as a table.
unbind(): Returns an unbound copy of a Module and its variables.
variable(col, name[, init_fn, unbox]): Declares and returns a variable in this Module.
Attributes
blocks
channels
kernel_init
name
parent
path: Get the path of this Module.
rescale_gradients
scope
source
target
variables: Returns the variables in this module.
source_irreps
target_irreps
layout
- __call__(x: Array) -> Array
Transform equivariant tensors linearly.
- __init__(source_irreps: str, target_irreps: str, channels: tuple[int, int] = (1, 1), layout: str | ~e3j.utils.options.Layout = <factory>, kernel_init: ~jax.nn.initializers.Initializer | ~e3j.utils.options.LinearInitialization | str = 'FAN_IN', rescale_gradients: bool = True, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None) -> None
- join_outputs(ys: list[Array]) -> Array
Concatenate outputs, restoring the channel axis if needed.
- slice_inputs(x: Array) -> list[Array]
Prepare a PyTree of inputs, sliced by irrep.
- Parameters:
x (Array) – inputs with shape (-1, num_channels, source.dim)
- Returns:
List of 2*(l_max+1) arrays, one per (l, parity) pair, each of shape (-1, num_channels, m_in*(2*l_in+1)).
- Return type:
list[Array]
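A sketch of what this slicing could look like, written in NumPy under stated assumptions rather than taken from e3j: the hypothetical `source` argument is a list of `(l, parity, m_in)` entries covering every `(l, p)` pair up to `l_max` in storage order (with `m_in = 0` for absent irreps), and each irrep's `m_in*(2l+1)` components are contiguous on the last axis. Zero-multiplicity entries still produce a size-zero slice, so the list always has `2*(l_max+1)` elements.

```python
import numpy as np


def slice_by_irrep(x, source):
    """Split (-1, num_channels, source_dim) inputs into one slice per (l, p).

    source: hypothetical list of (l, parity, m_in) tuples in storage order,
            one per (l, p) pair, with m_in = 0 for irreps that are absent.
    """
    sizes = [m_in * (2 * l + 1) for l, _, m_in in source]
    # Split points are the cumulative sizes, excluding the final total.
    split_points = np.cumsum(sizes)[:-1]
    # Each slice has shape (-1, num_channels, m_in * (2l + 1)).
    return np.split(x, split_points, axis=-1)


l_max = 1
# Two even scalars (l=0) and three odd vectors (l=1); the other pairs are empty.
source = [(0, 1, 2), (0, -1, 0), (1, 1, 0), (1, -1, 3)]
x = np.zeros((5, 4, 2 * 1 + 3 * 3))
slices = slice_by_irrep(x, source)
print(len(slices), [s.shape for s in slices])
```

Concatenating these slices on the last axis recovers the original array.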
- slice_transform(x_lp: Array, w_lp: Array, block: LinearBlock, batch_size: int) -> Array
Mix irreps on a constant-(l,p) block.
The mixing of multiplicities is performed with np.matmul between:
  - w_lp : (N, m_out, m_in), the array of species-dependent weights,
  - x_lp : (N, m_in * (2l + 1)), the array of equivariant features,
where N is the batch size.
Note
Apart from the normal case m_in * m_out > 0, two edge cases must be handled, and both require batch_size to be passed explicitly:
  - if m_out > 0 and m_in == 0, return a zero vector with the requested multiplicity (this contrasts with e3nn);
  - if m_out == 0, return an empty vector, which is discarded during the final concatenation.
A single branch turns out to be enough for both cases, although they are conceptually different.
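The NumPy sketch below shows one plausible way a single branch covers the normal case and both edge cases. It is not the e3j code: the hypothetical `mix_block` helper assumes each slice is stored multiplicity-major as (m_in, 2l+1) (the layout attribute may select otherwise), so np.matmul contracts w_lp (N, m_out, m_in) against x_lp reshaped to (N, m_in, 2l+1). When m_in == 0 the contraction runs over an empty axis and NumPy returns zeros; when m_out == 0 the result is empty. In both cases the arrays have size zero, and reshaping a size-zero array with a -1 dimension is ambiguous (NumPy rejects it), which is a likely reason for passing batch_size explicitly.

```python
import numpy as np


def mix_block(x_lp, w_lp, l, batch_size):
    """Mix multiplicities on one constant-(l, p) block (illustrative sketch).

    x_lp: (N, m_in * (2l + 1)) features, assumed multiplicity-major.
    w_lp: (N, m_out, m_in) species-dependent weights.
    """
    m_out, m_in = w_lp.shape[1:]
    dim = 2 * l + 1
    # Explicit batch_size: a -1 here would be ambiguous when m_in == 0.
    x = x_lp.reshape(batch_size, m_in, dim)
    # (N, m_out, m_in) @ (N, m_in, 2l+1) -> (N, m_out, 2l+1); zeros if m_in == 0.
    y = np.matmul(w_lp, x)
    # Explicit batch_size again: the result is empty when m_out == 0.
    return y.reshape(batch_size, m_out * dim)


N, l = 3, 1
rng = np.random.default_rng(0)
normal = mix_block(rng.normal(size=(N, 2 * 3)), rng.normal(size=(N, 4, 2)), l, N)
zeros = mix_block(np.empty((N, 0)), np.empty((N, 4, 0)), l, N)
empty = mix_block(rng.normal(size=(N, 2 * 3)), np.empty((N, 0, 2)), l, N)
print(normal.shape, zeros.shape, empty.shape)  # (3, 12) (3, 12) (3, 0)
```

Only the reshapes need the batch size; the matmul itself requires no special-casing for either edge case.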
- slice_weights() -> list[Array]
Return arrays of weights acting on all degrees.