Hessian Predictor

class mlip.models.predictors.hessian_predictor.HessianPredictor(mlip_network: ~mlip.models.mlip_network.MLIPNetwork, required_properties: ~mlip.typing.properties.Properties, energy_head: ~typing.Callable[[~mlip.graph.graph.Graph], ~jax.Array] | None = None, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Subclass of ConservativePredictor used to predict either the full energy Hessian matrix of a system or a subsample of Hessian rows, depending on the graph.sample_hessian_rows attribute, which takes one of the following values:

  • an array of indices of shape (G, R), used to subsample the full Hessian matrix. An array of shape (n, R, 3) is then returned.

  • array(True). In this case, the full Hessian matrix of shape (N+1, 3, N+1, 3) is returned.

  • None, in which case no Hessian is returned. This is useful for skipping the additional automatic differentiation (AD) pass, e.g. in mixed-label training.

Here, N is the total number of graph nodes including padding nodes, n is the number of real graph nodes, R is the number of Hessian rows, and G is the number of graphs in the batch.
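A minimal JAX sketch of the shapes involved, using a toy pair energy in place of an MLIPNetwork. The `energy_fn` below is hypothetical and not part of mlip. It assumes no padding nodes, so the full Hessian comes out as (N, 3, N, 3) rather than the (N+1, 3, N+1, 3) listed above, and it assumes row indices are flat indices into the 3N Cartesian coordinates; the exact index convention of sample_hessian_rows is not stated here.

```python
import jax
import jax.numpy as jnp

def energy_fn(positions):
    """Hypothetical toy energy: 0.5 * sum_{i != j} exp(-|x_i - x_j|^2)."""
    diff = positions[:, None, :] - positions[None, :, :]  # (N, N, 3)
    d2 = jnp.sum(diff**2, axis=-1)                         # (N, N)
    off_diag = 1.0 - jnp.eye(positions.shape[0])           # drop i == j terms
    return 0.5 * jnp.sum(off_diag * jnp.exp(-d2))

N = 4
positions = jax.random.normal(jax.random.PRNGKey(0), (N, 3))

# Full Hessian d^2E / (dx_{i,a} dx_{j,b}), shape (N, 3, N, 3).
full_hessian = jax.hessian(energy_fn)(positions)

# Assumed convention: each row index is a flat index into the 3N coordinates.
rows = jnp.array([0, 7])                                  # R = 2
H_flat = full_hessian.reshape(3 * N, 3 * N)
selected = H_flat[rows].reshape(-1, N, 3)                 # (R, N, 3)
subsampled = jnp.transpose(selected, (1, 0, 2))           # (N, R, 3)

print(full_hessian.shape, subsampled.shape)  # (4, 3, 4, 3) (4, 2, 3)
```

Storing only R rows costs O(N * R) memory instead of O(N^2) for the full matrix, which is the practical reason to subsample on large systems.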

__call__(graph: Graph) → Graph

Evaluates the Hessian predictor on a given graph.

Computes the required properties, including the energy Hessian, and updates the input graph with these quantities. If the Hessian is not required, falls back to evaluating the parent conservative predictor.

Parameters:

graph – The input graph.

Returns:

An updated graph containing all predicted properties.
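The dispatch described for __call__ can be sketched as follows. This is an illustration, not the HessianPredictor source: `predict` and `energy_fn` are hypothetical, a single graph without padding is assumed (so sample_hessian_rows has shape (1, R) and n = N), and row indices are again assumed to be flat coordinate indices.

```python
import jax
import jax.numpy as jnp

def energy_fn(positions):
    """Hypothetical toy energy standing in for the network's energy output."""
    diff = positions[:, None, :] - positions[None, :, :]
    d2 = jnp.sum(diff**2, axis=-1)
    off_diag = 1.0 - jnp.eye(positions.shape[0])
    return 0.5 * jnp.sum(off_diag * jnp.exp(-d2))

def predict(positions, sample_hessian_rows=None):
    # First AD pass: energy and forces, as a conservative predictor would do.
    energy, grad = jax.value_and_grad(energy_fn)(positions)
    out = {"energy": energy, "forces": -grad}
    if sample_hessian_rows is None:
        return out  # no Hessian requested: skip the second AD pass
    if jnp.issubdtype(sample_hessian_rows.dtype, jnp.bool_):
        # array(True): full Hessian, shape (N, 3, N, 3) here (no padding).
        out["hessian"] = jax.hessian(energy_fn)(positions)
    else:
        rows = sample_hessian_rows.reshape(-1)  # single graph: (1, R) -> (R,)

        def selected_grad(x):
            # dE/dx at the selected flat coordinates, shape (R,)
            return jax.grad(energy_fn)(x).reshape(-1)[rows]

        jac = jax.jacrev(selected_grad)(positions)       # (R, N, 3)
        out["hessian"] = jnp.transpose(jac, (1, 0, 2))   # (N, R, 3)
    return out

positions = jax.random.normal(jax.random.PRNGKey(1), (4, 3))
no_hessian = predict(positions)
full = predict(positions, jnp.array(True))
partial = predict(positions, jnp.array([[0, 7]]))
```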

compute_sum_forces_subsample(positions: Array, graph: Graph) → tuple[Array, Graph]

Returns the pair (sum(F[sample_rows]), graph) for downstream automatic differentiation. The auxiliary Graph object can be forwarded by downstream methods, while the caller may differentiate through the subsampled force components F[sample_rows] to compute the Hessian rows.
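One way such a summed value can be differentiated, sketched under assumptions: because the graphs in a batch are disjoint, summing the selected force component of every graph and differentiating once yields each graph's Hessian row in its own node block, and `has_aux=True` lets the auxiliary output pass through untouched. The helper names, the two-graph toy energy, the dict used as aux in place of a Graph, and the convention that sample_hessian_rows holds global flat coordinate indices are all assumptions, not mlip's API.

```python
import jax
import jax.numpy as jnp

# Hypothetical batch: two disjoint graphs of 3 nodes each (no padding).
node_graph = jnp.array([0, 0, 0, 1, 1, 1])

def energy_fn(positions):
    """Hypothetical toy energy with pair terms only inside each graph."""
    diff = positions[:, None, :] - positions[None, :, :]
    d2 = jnp.sum(diff**2, axis=-1)
    same_graph = node_graph[:, None] == node_graph[None, :]
    off_diag = ~jnp.eye(positions.shape[0], dtype=bool)
    pair_mask = (same_graph & off_diag).astype(positions.dtype)
    return 0.5 * jnp.sum(pair_mask * jnp.exp(-d2))

def sum_forces_subsample(positions, rows_r):
    """Hypothetical analogue: (sum of selected force components, aux)."""
    energy, grad = jax.value_and_grad(energy_fn)(positions)
    forces = -grad
    aux = {"energy": energy, "forces": forces}
    return jnp.sum(forces.reshape(-1)[rows_r]), aux

def hessian_rows(positions, sample_rows):
    """sample_rows: (G, R) global flat indices -> Hessian rows, shape (N, R, 3)."""
    def one_column(rows_r):
        # One gradient per column r covers all G graphs at once.
        d_sum_f, aux = jax.grad(sum_forces_subsample, has_aux=True)(positions, rows_r)
        return -d_sum_f, aux  # Hessian = -dF/dx
    rows, aux = jax.vmap(one_column)(sample_rows.T)      # rows: (R, N, 3)
    aux = jax.tree_util.tree_map(lambda a: a[0], aux)    # aux is identical per column
    return jnp.transpose(rows, (1, 0, 2)), aux

positions = jax.random.normal(jax.random.PRNGKey(0), (6, 3))
# Graph 0 owns flat coordinates 0..8, graph 1 owns 9..17.
sample_rows = jnp.array([[0, 4], [9, 17]])
rows, aux = hessian_rows(positions, sample_rows)
print(rows.shape)  # (6, 2, 3)
```

The design point is that the cost scales with R rather than with G * R: each backward pass recovers one row for every graph in the batch simultaneously.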