Graph
- class mlip.graph.Graph(nodes: GraphNodes, edges: GraphEdges, globals: GraphGlobals, n_node: Array | ndarray | bool | number | bool | int | float | complex, n_edge: Array | ndarray | bool | number | bool | int | float | complex, senders: Array | ndarray | bool | number | bool | int | float | complex, receivers: Array | ndarray | bool | number | bool | int | float | complex, n_edge_long_range: float | Array | ndarray | bool | number | bool | int | complex | None = None, senders_long_range: float | Array | ndarray | bool | number | bool | int | complex | None = None, receivers_long_range: float | Array | ndarray | bool | number | bool | int | complex | None = None, edges_long_range: GraphEdges | None = None)
The Graph class, which defines a single graph or a batch of graphs. It is modeled after jraph.GraphsTuple but adds extra methods.
- nodes
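The node features of the graph.
- Type:
mlip.graph.graph.GraphNodes
- edges
The edge features of the graph.
- Type:
mlip.graph.graph.GraphEdges
- globals
The global features of the graph.
- Type:
mlip.graph.graph.GraphGlobals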
- n_node
The number of nodes in the graph, or a vector of per-graph node counts for a batch.
- Type:
jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex
- n_edge
The number of edges in the graph, or a vector of per-graph edge counts for a batch.
- Type:
jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex
- senders
The sender indices of the edges of the graph.
- Type:
jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex
- receivers
The receiver indices of the edges of the graph.
- Type:
jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex
- n_edge_long_range
The number of long-range edges in the graph.
- Type:
float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None
- senders_long_range
The sender indices of the long-range edges.
- Type:
float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None
- receivers_long_range
The receiver indices of the long-range edges.
- Type:
float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None
- edges_long_range
Edge features of the long-range edges, if present.
- Type:
mlip.graph.graph.GraphEdges | None
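Because a batch stores all nodes and edges in flat arrays, with per-graph counts in n_node and n_edge, the sender and receiver indices of each graph have to be offset by the number of nodes that precede it. The following is a minimal NumPy sketch of that jraph-style convention, assuming mlip batches graphs the same way jraph does; batch_graphs is a hypothetical helper for illustration (not part of the mlip API) and only handles positions, senders, and receivers.

```python
import numpy as np


def batch_graphs(graphs):
    """Hypothetical jraph-style batching of (positions, senders, receivers) triples."""
    positions, senders, receivers, n_node, n_edge = [], [], [], [], []
    offset = 0
    for pos, snd, rcv in graphs:
        positions.append(np.asarray(pos, dtype=float))
        # Offset edge indices so they point into the concatenated node array.
        senders.append(np.asarray(snd, dtype=int) + offset)
        receivers.append(np.asarray(rcv, dtype=int) + offset)
        n_node.append(len(pos))
        n_edge.append(len(snd))
        offset += len(pos)
    return (
        np.concatenate(positions),
        np.concatenate(senders),
        np.concatenate(receivers),
        np.array(n_node),
        np.array(n_edge),
    )


# Graph A: 2 nodes, edges 0->1 and 1->0. Graph B: 3 nodes, edge 0->2.
pos, senders, receivers, n_node, n_edge = batch_graphs([
    (np.zeros((2, 3)), [0, 1], [1, 0]),
    (np.zeros((3, 3)), [0], [2]),
])
```

In this sketch, graph B's edge 0->2 becomes 2->4 in the batch, and the number of graphs equals len(n_node).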
- classmethod from_chemical_system(chemical_system: ChemicalSystem, graph_cutoff_angstrom: float, long_range_cutoff_angstrom: float | None = None) → Self
Create a Graph object from a chemical system and the given distance cutoffs. This computes the senders, receivers, and shifts for the system; all other data is transferred one-to-one to the graph.
- Parameters:
chemical_system – The chemical system object.
graph_cutoff_angstrom – The graph distance cutoff in Angstrom.
long_range_cutoff_angstrom – The long-range distance cutoff in Angstrom. If None, long-range interactions are not computed.
- Returns:
The Graph object for the given chemical system.
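To make the senders/receivers/shifts triple concrete, here is a brute-force neighbour search that produces edges in the convention used by edge_vectors(), D = positions[receiver] - positions[sender] + shift @ cell. It is a sketch under stated assumptions, not the mlip implementation: brute_force_edges is a hypothetical helper, the cell is assumed periodic in all three directions, positions are assumed to be wrapped into the cell, and the cutoff is assumed not to exceed the smallest perpendicular cell width, so checking the 27 neighbouring images suffices. A real pipeline would use a proper neighbour list.

```python
import itertools

import numpy as np


def brute_force_edges(positions, cell, cutoff):
    """Hypothetical O(27 * N^2) neighbour search returning senders, receivers, shifts."""
    positions = np.asarray(positions, dtype=float)
    cell = np.asarray(cell, dtype=float)
    senders, receivers, shifts = [], [], []
    for shift in itertools.product((-1, 0, 1), repeat=3):
        offset = np.array(shift, dtype=float) @ cell
        for i in range(len(positions)):
            for j in range(len(positions)):
                if i == j and not any(shift):
                    continue  # skip the self-edge in the home cell
                d = positions[j] - positions[i] + offset
                if np.linalg.norm(d) < cutoff:
                    senders.append(i)
                    receivers.append(j)
                    shifts.append(shift)
    return (
        np.array(senders, dtype=int),
        np.array(receivers, dtype=int),
        np.array(shifts, dtype=int).reshape(-1, 3),
    )


# Two atoms close to opposite faces of a 10 Angstrom cubic cell are neighbours
# only through the periodic boundary.
cell = 10.0 * np.eye(3)
senders, receivers, shifts = brute_force_edges(
    [[0.5, 0.0, 0.0], [9.5, 0.0, 0.0]], cell, cutoff=2.0
)
```

Each undirected neighbour pair appears twice here, once in each direction, with opposite shifts.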
- num_graphs
Number of graphs in the (possibly batched, possibly padded) graph.
- replace_nodes(**kwargs) → Self
Returns the Graph object with its nodes attribute replaced. Keyword arguments are forwarded to the .replace() call on the nodes dataclass.
- replace_edges(**kwargs) → Self
Returns the Graph object with its edges attribute replaced. Keyword arguments are forwarded to the .replace() call on the edges dataclass.
- replace_globals(**kwargs) → Self
Returns the Graph object with its globals attribute replaced. Keyword arguments are forwarded to the .replace() call on the globals dataclass.
- update_node_features(**kwargs) → Self
Returns the Graph object with its nodes attribute replaced. Keyword arguments are forwarded to the .replace() call on the nodes dataclass.
- update_edge_features(**kwargs) → Self
Returns the Graph object with its edges attribute replaced. Keyword arguments are forwarded to the .replace() call on the edges dataclass.
- update_global_features(**kwargs) → Self
Returns the Graph object with its globals attribute replaced. Keyword arguments are forwarded to the .replace() call on the globals dataclass.
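The replace_* and update_* methods follow an immutable-update pattern: keyword arguments go to the inner dataclass's .replace(), and a new graph wrapping the new inner object is returned while the original is left untouched. The sketch below imitates that forwarding with standard-library frozen dataclasses; ToyNodes and ToyGraph are hypothetical stand-ins, and the real mlip classes may be built on a different dataclass implementation.

```python
import dataclasses
from typing import Optional, Tuple


@dataclasses.dataclass(frozen=True)
class ToyNodes:
    positions: Tuple[float, ...]
    forces: Optional[Tuple[float, ...]] = None

    def replace(self, **kwargs) -> "ToyNodes":
        # Return a new instance with the given fields swapped out.
        return dataclasses.replace(self, **kwargs)


@dataclasses.dataclass(frozen=True)
class ToyGraph:
    nodes: ToyNodes

    def replace_nodes(self, **kwargs) -> "ToyGraph":
        # Forward keyword arguments to the nodes dataclass, then rewrap.
        return dataclasses.replace(self, nodes=self.nodes.replace(**kwargs))


graph = ToyGraph(nodes=ToyNodes(positions=(0.0, 1.0)))
updated = graph.replace_nodes(forces=(0.5, -0.5))
```

Only the named field changes; untouched fields such as positions are carried over to the new object.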
- request_full_hessian() → Self
Returns a graph with sample_hessian_rows=np.array(True) in its globals. Required for inference pipelines that need the full Hessian.
- node_mask() → Array
Evaluates the node padding mask array for the graph.
True refers to a real node, while False refers to a dummy node in the (batched) graph.
- Returns:
The node padding mask.
- graph_mask() → Array
Evaluates the graph padding mask array for the batched graph.
True refers to a real graph, while False refers to a dummy graph in the batched graph.
- Returns:
The graph padding mask.
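How such masks can be derived from n_node alone is sketched here in NumPy. This assumes the jraph padding convention (padding graphs are appended at the end of the batch, the first padding graph carries all padding nodes, and any further padding graphs are empty, so at least one padding graph is present); whether node_mask() and graph_mask() compute their results exactly this way is an assumption, and padding_masks is a hypothetical helper.

```python
import numpy as np


def padding_masks(n_node):
    """Hypothetical helper: (node_mask, graph_mask) from per-graph node counts."""
    n_node = np.asarray(n_node)
    # Trailing graphs with zero nodes are extra padding graphs; add one for the
    # padding graph that holds the padding nodes.
    n_trailing_empty = int(np.argmin(n_node[::-1] == 0))
    n_padding_graphs = n_trailing_empty + 1
    n_padding_nodes = int(n_node[-n_padding_graphs])

    total_nodes = int(n_node.sum())
    n_graphs = n_node.shape[0]
    node_mask = np.arange(total_nodes) < total_nodes - n_padding_nodes
    graph_mask = np.arange(n_graphs) < n_graphs - n_padding_graphs
    return node_mask, graph_mask


# Two real graphs (3 and 2 nodes), one padding graph with 4 dummy nodes,
# and one empty padding graph.
node_mask, graph_mask = padding_masks([3, 2, 4, 0])
```

Here the first five nodes and the first two graphs are marked as real (True).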
- to_prediction() → Prediction
Creates a Prediction object from the current graph, containing all properties of the graph that are also part of a prediction.
- Returns:
The prediction.
- edge_vectors(use_np: bool = False) → Array | ndarray | bool | number | bool | int | float | complex
Compute the relative edge vectors from senders to receivers.
We use displ_fun if available; otherwise, edge vectors are computed directly from the positions and shifts. In the case of PBCs, sender nodes are translated from the unit cell to the receiver’s nearest neighbouring cell (in the ASE notation, i indexes the senders and j the receivers):
    # If `displ_fun` is None, this method returns:
    vectors = positions[receivers] - (positions[senders] - shifts @ cell)
    # Equivalent to this line from the ASE docs:
    D = positions[j] - positions[i] + S.dot(cell)
- Parameters:
use_np – Whether to use numpy or jax.numpy for the computation. Default is False, which means jax.numpy is used.
- Returns:
The relative edge vectors, labelled D by ASE.
- long_range_edge_vectors(use_np: bool = False) → Array | ndarray | bool | number | bool | int | float | complex
Compute the relative long-range edge vectors from senders to receivers.
Mirrors self.edge_vectors() for the long-range graph.
- Parameters:
use_np – Whether to use numpy or jax.numpy for the computation. Default is False, which means jax.numpy is used.
- Returns:
The relative long-range edge vectors.
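The shift-based pathway described for edge_vectors() (and mirrored by long_range_edge_vectors()) can be reproduced in a few lines. This is a minimal sketch for a single, unbatched graph with one 3x3 cell (in a batch, each edge would need the cell of its own graph); edge_vectors_from_shifts is a hypothetical helper, and the use_np switch only imitates the documented choice between numpy and jax.numpy.

```python
import jax.numpy as jnp
import numpy as np


def edge_vectors_from_shifts(positions, senders, receivers, shifts, cell, use_np=False):
    """Hypothetical sketch: vectors = positions[receivers] - (positions[senders] - shifts @ cell)."""
    xp = np if use_np else jnp
    positions = xp.asarray(positions)
    shifts = xp.asarray(shifts, dtype=positions.dtype)
    cell = xp.asarray(cell, dtype=positions.dtype)
    senders = np.asarray(senders)
    receivers = np.asarray(receivers)
    # Move each sender by -shift @ cell, then take receiver minus sender.
    return positions[receivers] - (positions[senders] - shifts @ cell)


positions = [[0.5, 0.0, 0.0], [9.5, 0.0, 0.0]]
cell = 10.0 * np.eye(3)
# Edge 0 -> 1 across the x boundary of a 10 Angstrom cubic cell.
vec_jnp = edge_vectors_from_shifts(positions, [0], [1], [[-1, 0, 0]], cell)
vec_np = edge_vectors_from_shifts(positions, [0], [1], [[-1, 0, 0]], cell, use_np=True)
```

Both calls give the short vector [-1, 0, 0] rather than the in-cell difference [9, 0, 0].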
- class mlip.graph.GraphNodes(atomic_numbers: float | Array | ndarray | bool | number | bool | int | complex | None = None, positions: float | Array | ndarray | bool | number | bool | int | complex | None = None, forces: float | Array | ndarray | bool | number | bool | int | complex | None = None, partial_charges: float | Array | ndarray | bool | number | bool | int | complex | None = None, hessian: float | Array | ndarray | bool | number | bool | int | complex | None = None, features: dict[str, Array | ndarray | bool | number | bool | int | float | complex | IrrepsArray] = <factory>)
Features of the Graph object related to nodes.
- atomic_numbers
The atomic numbers of the nodes.
- Type:
float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None
- positions
The positions of the nodes.
- Type:
float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None
- forces
The forces on the nodes.
- Type:
float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None
- partial_charges
The partial charges of the nodes.
- Type:
float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None
- hessian
The Hessian matrix (second derivatives of the energy with respect to the node positions).
- Type:
float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None
- features
Any additional node features stored inside a dictionary / PyTree.
- Type:
dict[str, jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex | e3nn_jax._src.irreps_array.IrrepsArray]
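The forces and hessian fields relate to the energy by differentiation: forces are the negative gradient of the energy with respect to positions, and the Hessian holds the second derivatives. The JAX sketch below shows this for a hypothetical two-atom harmonic bond (toy_energy is illustrative, not an mlip model); the (N, 3, N, 3) layout is simply what jax.hessian returns and is not necessarily how mlip stores the hessian field.

```python
import jax
import jax.numpy as jnp


def toy_energy(positions):
    # Hypothetical harmonic bond between atoms 0 and 1 (k = 1, r_eq = 1).
    bond = positions[1] - positions[0]
    return 0.5 * (jnp.linalg.norm(bond) - 1.0) ** 2


positions = jnp.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]])
forces = -jax.grad(toy_energy)(positions)     # shape (2, 3)
hessian = jax.hessian(toy_energy)(positions)  # shape (2, 3, 2, 3)
```

The stretched bond pulls the two atoms toward each other, and the Hessian, flattened to (6, 6), is symmetric.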
- class mlip.graph.GraphEdges(shifts: float | Array | ndarray | bool | number | bool | int | complex | None = None, displ_fun: Callable[[Array | ndarray | bool | number | bool | int | float | complex, Array | ndarray | bool | number | bool | int | float | complex], Array | ndarray | bool | number | bool | int | float | complex] | None = None, features: dict[str, Array | ndarray | bool | number | bool | int | float | complex | IrrepsArray] = <factory>)
Features of the Graph object related to edges.
- shifts
Shift vectors used to compute edge vectors from positions while accounting for PBCs. Either this or the displ_fun attribute should be specified, but not both. If both are present, shifts is used.
- Type:
float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None
- displ_fun
An alternative to shifts for computing edge vectors from positions while accounting for PBCs. The displacement function should already be vmapped, meaning it takes a position matrix for the senders and one for the receivers and outputs the matrix of edge vectors. Moreover, the displacement function must be wrapped in jax.tree_util.Partial to be compatible with jitting. Note that if the displacement function pathway is used, stress cannot be calculated as a property. In the future, these two pathways may be unified.
- Type:
Callable[[jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex, jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex], jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex] | None
- features
Any additional edge features stored inside a dictionary / PyTree.
- Type:
dict[str, jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex | e3nn_jax._src.irreps_array.IrrepsArray]
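The two requirements on displ_fun (already vmapped, and wrapped in jax.tree_util.Partial) can be illustrated with a minimum-image displacement in a cubic box. Partial turns the function into a pytree, so it can be passed as an argument into jitted code, with its bound arguments treated as array leaves. Assumptions in this sketch: cubic_minimum_image and edge_vectors_with are hypothetical, and the argument order (receiver positions first, giving r_receiver - r_sender) is not confirmed by this page.

```python
import jax
import jax.numpy as jnp


def cubic_minimum_image(box_size, r_a, r_b):
    # Minimum-image displacement r_a - r_b in a cubic periodic box.
    d = r_a - r_b
    return d - box_size * jnp.round(d / box_size)


# vmap over per-edge position rows; the box size is shared, so it is not mapped.
batched = jax.vmap(cubic_minimum_image, in_axes=(None, 0, 0))
# Partial binds the box size and makes the callable a valid jit argument.
displ_fun = jax.tree_util.Partial(batched, jnp.asarray(10.0))


@jax.jit
def edge_vectors_with(displ_fun, receiver_positions, sender_positions):
    # Assumed argument order: receivers first, so the result is r_receiver - r_sender.
    return displ_fun(receiver_positions, sender_positions)


vectors = edge_vectors_with(
    displ_fun, jnp.array([[9.5, 0.0, 0.0]]), jnp.array([[0.5, 0.0, 0.0]])
)
```

Passing the bare vmapped function to the jitted edge_vectors_with would fail, because plain Python functions are not valid array arguments; the Partial wrapper is what makes this work.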
- class mlip.graph.GraphGlobals(cell: Array | ndarray | bool | number | bool | int | float | complex, weight: Array | ndarray | bool | number | bool | int | float | complex, energy: float | Array | ndarray | bool | number | bool | int | complex | None = None, stress: float | Array | ndarray | bool | number | bool | int | complex | None = None, pressure: float | Array | ndarray | bool | number | bool | int | complex | None = None, charge: float | Array | ndarray | bool | number | bool | int | complex | None = None, dipole_moment: float | Array | ndarray | bool | number | bool | int | complex | None = None, non_corrected_charge: float | Array | ndarray | bool | number | bool | int | complex | None = None, spin_multiplicity: float | Array | ndarray | bool | number | bool | int | complex | None = None, sample_hessian_rows: float | Array | ndarray | bool | number | bool | int | complex | None = None, is_dummy_for_init: float | Array | ndarray | bool | number | bool | int | complex | None = None, dataset_idx: float | Array | ndarray | bool | number | bool | int | complex | None = None, features: dict[str, Array | ndarray | bool | number | bool | int | float | complex | IrrepsArray] = <factory>)
Global features of the Graph object.
- cell
The cell definition. Important for PBCs.
- Type:
jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex
- weight
The weight of the graph, which can be used inside a loss function.
- Type:
jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex
- energy
The energy.
- Type:
float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None
- stress
The stress.
- Type:
float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None
- pressure
The pressure.
- Type:
float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None
- charge
The total charge of the graph.
- Type:
float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None
- non_corrected_charge
The total charge of the graph before correction. Required for the total charge term of the loss during training.
- Type:
float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None
- spin_multiplicity
The spin multiplicity of the graph.
- Type:
float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None
- sample_hessian_rows
Indices of force terms to sample for sampled Hessian prediction.
- Type:
float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None
- dataset_idx
An index pointing to which dataset this graph belongs to.
- Type:
float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None
- is_dummy_for_init
Whether this graph is a dummy graph used only for model initialization. By default, this is None, which means false (False itself is not used, in order to allow shape-based evaluation of this field). It is set to np.array(True) for the dummy initialization graph.
- Type:
float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | complex | None
- features
Any additional global features stored inside a dictionary / PyTree.
- Type:
dict[str, jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex | e3nn_jax._src.irreps_array.IrrepsArray]
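The weight field is described as usable inside a loss function. One plausible way to combine it with the graph padding mask is a weighted mean squared error over per-graph energies, where padding graphs get zero weight. The sketch below is an assumption for illustration: weighted_energy_mse is hypothetical, and mlip's actual loss terms and normalisation are not specified on this page.

```python
import jax.numpy as jnp


def weighted_energy_mse(pred_energy, ref_energy, weight, graph_mask):
    """Hypothetical per-graph weighted MSE that ignores padding graphs."""
    # Padding graphs contribute neither to the sum nor to the normaliser.
    w = jnp.where(graph_mask, weight, 0.0)
    sq_err = (pred_energy - ref_energy) ** 2
    return jnp.sum(w * sq_err) / jnp.maximum(jnp.sum(w), 1e-12)


loss = weighted_energy_mse(
    pred_energy=jnp.array([1.0, 2.0, 5.0]),
    ref_energy=jnp.array([0.0, 2.0, 0.0]),
    weight=jnp.array([2.0, 1.0, 1.0]),
    graph_mask=jnp.array([True, True, False]),  # last graph is padding
)
```

The padding graph's large error is ignored, giving (2 * 1 + 1 * 0) / 3 = 2/3. Normalising by the total weight keeps the loss scale independent of how much padding a batch contains.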