Graph

class mlip.graph.Graph(nodes: GraphNodes, edges: GraphEdges, globals: GraphGlobals, n_node: Array | ndarray | bool | number | bool | int | float | complex, n_edge: Array | ndarray | bool | number | bool | int | float | complex, senders: Array | ndarray | bool | number | bool | int | float | complex, receivers: Array | ndarray | bool | number | bool | int | float | complex, n_edge_long_range: float | Array | ndarray | bool | number | bool | int | complex | None = None, senders_long_range: float | Array | ndarray | bool | number | bool | int | complex | None = None, receivers_long_range: float | Array | ndarray | bool | number | bool | int | complex | None = None, edges_long_range: GraphEdges | None = None)

The Graph class defining a single graph or a batch of graphs.

Modeled after jraph.GraphsTuple, but with additional methods.

nodes

The node features of the graph.

Type:

mlip.graph.graph.GraphNodes

edges

The edge features of the graph.

Type:

mlip.graph.graph.GraphEdges

globals

The global features of the graph.

Type:

mlip.graph.graph.GraphGlobals

n_node

The number of nodes in the graph (or a vector if a batch).

Type:

jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex

n_edge

The number of edges in the graph (or a vector if a batch).

Type:

jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex

senders

The sender indices of the edges of the graph.

Type:

jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex

receivers

The receiver indices of the edges of the graph.

Type:

jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex

n_edge_long_range

The number of long range edges in the graph.

Type:

float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None

senders_long_range

The sender indices of the long range edges.

Type:

float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None

receivers_long_range

The receiver indices of the long range edges.

Type:

float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None

edges_long_range

Edge information for long-range edges if present.

Type:

mlip.graph.graph.GraphEdges | None

classmethod from_chemical_system(chemical_system: ChemicalSystem, graph_cutoff_angstrom: float, long_range_cutoff_angstrom: float | None = None) Self

Create a Graph object from a chemical system and the given distance cutoffs.

This includes computing the senders/receivers/shifts for the system and otherwise just transferring data 1-to-1 to the graph.

Parameters:
  • chemical_system – The chemical system object.

  • graph_cutoff_angstrom – The graph distance cutoff in Angstrom.

  • long_range_cutoff_angstrom – The long range distance cutoff in Angstrom. If None, long range interactions are not computed.

Returns:

The Graph object for the given chemical system.
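To make "computing the senders/receivers" concrete, here is a minimal NumPy sketch of building a cutoff graph for a non-periodic system. It is not the library's implementation: the helper name `build_edges` is hypothetical, periodic boundary conditions and shifts are ignored, and the all-pairs distance matrix is only sensible for small systems.

```python
import numpy as np


def build_edges(positions: np.ndarray, cutoff: float):
    """Return (senders, receivers) for all ordered pairs closer than cutoff.

    Hypothetical helper for illustration only; open boundaries, no shifts.
    """
    # Pairwise displacement vectors, shape (n, n, 3).
    diff = positions[None, :, :] - positions[:, None, :]
    dist = np.linalg.norm(diff, axis=-1)
    # Keep pairs within the cutoff and exclude self-pairs on the diagonal.
    mask = (dist < cutoff) & ~np.eye(len(positions), dtype=bool)
    senders, receivers = np.nonzero(mask)
    return senders, receivers


# Three atoms on a line; only the first two are within 2 Angstrom of each other.
positions = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
senders, receivers = build_edges(positions, cutoff=2.0)
```

Each neighbouring pair appears twice, once per direction, so the resulting graph is directed.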

num_graphs

Number of graphs in the (possibly batched, possibly padded) graph.

replace_nodes(**kwargs) Self

Returns the Graph object with its nodes attribute replaced.

Keyword arguments are forwarded to the .replace() call on the nodes dataclass.

replace_edges(**kwargs) Self

Returns the Graph object with its edges attribute replaced.

Keyword arguments are forwarded to the .replace() call on the edges dataclass.

replace_globals(**kwargs) Self

Returns the Graph object with its globals attribute replaced.

Keyword arguments are forwarded to the .replace() call on the globals dataclass.
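The forwarding pattern shared by the replace_* methods can be sketched with standard-library dataclasses. `ToyNodes` and `ToyGraph` are hypothetical stand-ins, not the real `GraphNodes` and `Graph` classes, and `dataclasses.replace` is used here in place of the `.replace()` method mentioned above.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class ToyNodes:
    """Hypothetical stand-in for the nodes dataclass."""
    positions: tuple
    forces: Optional[tuple] = None


@dataclass(frozen=True)
class ToyGraph:
    """Hypothetical stand-in for the graph container."""
    nodes: ToyNodes

    def replace_nodes(self, **kwargs) -> "ToyGraph":
        # Forward the keyword arguments to the nested dataclass,
        # then return a new graph holding the updated copy.
        return replace(self, nodes=replace(self.nodes, **kwargs))


graph = ToyGraph(nodes=ToyNodes(positions=((0.0, 0.0, 0.0),)))
updated = graph.replace_nodes(forces=((0.1, 0.0, 0.0),))
```

Because the dataclasses are frozen, the original `graph` is left unchanged and `updated` is a new object; fields that are not passed as keyword arguments are carried over as they are.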

update_node_features(**kwargs) Self

Returns the Graph object with its nodes attribute replaced.

Keyword arguments are forwarded to the .replace() call on the nodes dataclass.

update_edge_features(**kwargs) Self

Returns the Graph object with its edges attribute replaced.

Keyword arguments are forwarded to the .replace() call on the edges dataclass.

update_global_features(**kwargs) Self

Returns the Graph object with its globals attribute replaced.

Keyword arguments are forwarded to the .replace() call on the globals dataclass.

request_full_hessian() Self

Returns a graph that has sample_hessian_rows=np.array(True) in the globals.

Required for inference pipelines.

node_mask() Array

Evaluates the node padding mask array for the graph.

True refers to a real node, while False refers to a dummy node in the (batched) graph.

Returns:

The node padding mask.

graph_mask() Array

Evaluates the graph padding mask array for the batched graph.

True refers to a real graph, while False refers to a dummy graph in the batched graph.

Returns:

The graph padding mask.
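As an illustration of what the two masks look like, the following NumPy sketch derives them from n_node alone. It assumes a jraph-style padding layout in which exactly one trailing padding graph holds all dummy nodes; whether the library uses precisely this layout is an assumption, not something stated above.

```python
import numpy as np

# Two real graphs (3 and 2 nodes) followed by one padding graph with 3 dummy nodes.
n_node = np.array([3, 2, 3])
num_graphs = n_node.shape[0]

# Assumption: the last graph in the batch is the only padding graph.
graph_mask = np.arange(num_graphs) < num_graphs - 1

# Real nodes come first, so every node index below the real-node count is valid.
n_real_nodes = int(n_node[:-1].sum())
node_mask = np.arange(int(n_node.sum())) < n_real_nodes
```

Masks of this kind are typically multiplied into per-node or per-graph quantities so that dummy entries do not contribute to sums or losses.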

to_prediction() Prediction

Creates a Prediction object from the current graph, which contains all the properties of the graph that are also part of a prediction.

Returns:

The prediction.

edge_vectors(use_np: bool = False) Array | ndarray | bool | number | bool | int | float | complex

Compute the relative edge vectors from senders to receivers.

We use displ_fun if available, otherwise edge vectors are computed directly using the positions and shifts. In the case of PBCs, sender nodes are translated from the unit cell to the receiver’s nearest neighbouring cell:

# If `displ_fun` is None, this method returns:
vectors = positions[receivers] - (positions[senders] - shifts @ cell)

# Equivalent to this line from the ASE docs:
D = positions[j] - positions[i] + S.dot(cell)

Parameters:

use_np – Whether to use numpy or jax.numpy for the computation. Default is False, which means jax.numpy is used.

Returns:

The relative edge vectors, labelled D by ASE.
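A small worked example of the shifts-based formula, written directly in NumPy rather than through the Graph class. The cell, positions, and shift values are made up for illustration, and the cell is assumed to store lattice vectors as rows.

```python
import numpy as np

# Cubic cell with 10 Angstrom sides; rows are the lattice vectors (assumed layout).
cell = 10.0 * np.eye(3)
positions = np.array([[0.5, 0.0, 0.0], [9.5, 0.0, 0.0]])

# One edge from node 0 (sender) to node 1 (receiver), crossing the -x boundary.
senders = np.array([0])
receivers = np.array([1])
shifts = np.array([[-1.0, 0.0, 0.0]])

# Formula used when displ_fun is None.
vectors = positions[receivers] - (positions[senders] - shifts @ cell)

# The ASE-style expression gives the same result (i = sender, j = receiver).
ase_form = positions[receivers] - positions[senders] + shifts.dot(cell)
```

Without the shift, the two atoms would appear 9 Angstrom apart; with it, the edge vector is (-1, 0, 0), which is the distance to the nearest periodic image.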

long_range_edge_vectors(use_np: bool = False) Array | ndarray | bool | number | bool | int | float | complex

Compute the relative long-range edge vectors from senders to receivers.

Mirrors self.edge_vectors() for the long range graph.

Parameters:

use_np – Whether to use numpy or jax.numpy for the computation. Default is False, which means jax.numpy is used.

Returns:

The relative long-range edge vectors.

class mlip.graph.GraphNodes(atomic_numbers: float | Array | ndarray | bool | number | bool | int | complex | None = None, positions: float | Array | ndarray | bool | number | bool | int | complex | None = None, forces: float | Array | ndarray | bool | number | bool | int | complex | None = None, partial_charges: float | Array | ndarray | bool | number | bool | int | complex | None = None, hessian: float | Array | ndarray | bool | number | bool | int | complex | None = None, features: dict[str, ~jax.Array | ~numpy.ndarray | ~numpy.bool | ~numpy.number | bool | int | float | complex | ~e3nn_jax._src.irreps_array.IrrepsArray]=<factory>)

Features of the Graph object related to nodes.

atomic_numbers

The atomic numbers of the nodes.

Type:

float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None

positions

The positions of the nodes.

Type:

float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None

forces

The forces on the nodes.

Type:

float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None

partial_charges

The partial charges of the nodes.

Type:

float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None

hessian

The Hessian matrix (second derivatives of energy w.r.t. node positions).

Type:

float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None

features

Any additional node features stored inside a dictionary / PyTree.

Type:

dict[str, jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex | e3nn_jax._src.irreps_array.IrrepsArray]

class mlip.graph.GraphEdges(shifts: float | ~jax.Array | ~numpy.ndarray | ~numpy.bool | ~numpy.number | bool | int | complex | None = None, displ_fun: ~typing.Callable[[~jax.Array | ~numpy.ndarray | ~numpy.bool | ~numpy.number | bool | int | float | complex, ~jax.Array | ~numpy.ndarray | ~numpy.bool | ~numpy.number | bool | int | float | complex], ~jax.Array | ~numpy.ndarray | ~numpy.bool | ~numpy.number | bool | int | float | complex] | None = None, features: dict[str, ~jax.Array | ~numpy.ndarray | ~numpy.bool | ~numpy.number | bool | int | float | complex | ~e3nn_jax._src.irreps_array.IrrepsArray] = <factory>)

Features of the Graph object related to edges.

shifts

Shift vectors used to compute edge vectors from positions, taking PBCs into account. Specify either this attribute or displ_fun, but not both. If both are set, shifts will be used.

Type:

float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None

displ_fun

Alternative to shifts to compute edge vectors from positions taking into account PBCs. The displacement function should be vmapped already, meaning it can take in a position matrix for senders and receivers and output the edge vector matrix. Moreover, the displacement function must be wrapped in jax.tree_util.Partial in order to be compatible with jitting. Note that if the displacement function pathway is applied, stress cannot be calculated as a property. In the future, these two pathways may be unified.

Type:

Callable[[jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex, jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex], jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex] | None
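The two requirements on displ_fun (already vmapped, and wrapped in jax.tree_util.Partial) can be sketched as follows. The minimum-image function for a cubic box, its name `pair_displacement`, the box length, and the (senders, receivers) argument order are all assumptions made for this example; they are not taken from the library.

```python
import jax
import jax.numpy as jnp


def pair_displacement(box_length, r_sender, r_receiver):
    """Minimum-image displacement from sender to receiver in a cubic box.

    Hypothetical example function, not part of the library.
    """
    d = r_receiver - r_sender
    return d - box_length * jnp.round(d / box_length)


# vmap over the leading (edge) axis of both position arrays; box_length is shared.
batched = jax.vmap(pair_displacement, in_axes=(None, 0, 0))

# Partial is registered as a pytree, so the callable can be passed into jitted code.
displ_fun = jax.tree_util.Partial(batched, 10.0)

positions = jnp.array([[0.5, 0.0, 0.0], [9.5, 0.0, 0.0]])
senders = jnp.array([0])
receivers = jnp.array([1])


@jax.jit
def compute_edge_vectors(fn, pos):
    # fn receives the sender and receiver position matrices, one row per edge.
    return fn(pos[senders], pos[receivers])


vectors = compute_edge_vectors(displ_fun, positions)
```

Passing a plain Python function as an argument to a jitted function would fail because it is not a valid JAX type; wrapping it in Partial avoids this while keeping the bound box length as an ordinary pytree leaf.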

features

Any additional edge features stored inside a dictionary / PyTree.

Type:

dict[str, jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex | e3nn_jax._src.irreps_array.IrrepsArray]

class mlip.graph.GraphGlobals(cell: Array | ndarray | bool | number | bool | int | float | complex, weight: Array | ndarray | bool | number | bool | int | float | complex, energy: float | Array | ndarray | bool | number | bool | int | complex | None = None, stress: float | Array | ndarray | bool | number | bool | int | complex | None = None, pressure: float | Array | ndarray | bool | number | bool | int | complex | None = None, charge: float | Array | ndarray | bool | number | bool | int | complex | None = None, dipole_moment: float | Array | ndarray | bool | number | bool | int | complex | None = None, non_corrected_charge: float | Array | ndarray | bool | number | bool | int | complex | None = None, spin_multiplicity: float | Array | ndarray | bool | number | bool | int | complex | None = None, sample_hessian_rows: float | Array | ndarray | bool | number | bool | int | complex | None = None, is_dummy_for_init: float | Array | ndarray | bool | number | bool | int | complex | None = None, dataset_idx: float | Array | ndarray | bool | number | bool | int | complex | None = None, features: dict[str, ~jax.Array | ~numpy.ndarray | ~numpy.bool | ~numpy.number | bool | int | float | complex | ~e3nn_jax._src.irreps_array.IrrepsArray]=<factory>)

Global features of the Graph object.

cell

The cell definition. Important for PBCs.

Type:

jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex

weight

The weight of the graph, which can be used inside a loss function.

Type:

jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex

energy

The energy.

Type:

float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None

stress

The stress.

Type:

float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None

pressure

The pressure.

Type:

float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None

charge

The total charge of the graph.

Type:

float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None

non_corrected_charge

The total charge of the graph before correction. Required for the total charge term of the loss during training.

Type:

float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None

spin_multiplicity

The spin multiplicity of the graph.

Type:

float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None

sample_hessian_rows

Indices of force terms to sample for sampled Hessian prediction.

Type:

float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None

dataset_idx

An index pointing to which dataset this graph belongs to.

Type:

float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None

is_dummy_for_init

Whether this graph is a dummy graph used only for model initialization. The default is None, which means false; None is used rather than False to allow shape-based evaluation of this field. It is set to np.array(True) for the dummy initialization graph.

Type:

float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None

features

Any additional global features stored inside a dictionary / PyTree.

Type:

dict[str, jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex | e3nn_jax._src.irreps_array.IrrepsArray]