Hessian Predictor
- class mlip.models.predictors.hessian_predictor.HessianPredictor(mlip_network: ~mlip.models.mlip_network.MLIPNetwork, required_properties: ~mlip.typing.properties.Properties, energy_head: ~typing.Callable[[~mlip.graph.graph.Graph], ~jax.Array] | None = None, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Subclass of ConservativePredictor used to predict either the full energy Hessian matrix of a system or a subsample of Hessian rows, depending on the graph.sample_hessian_rows attribute, which is one of:
- an array of indices of shape (G, R), used to subsample the full Hessian matrix. An array of shape (n, R, 3) is then returned.
- array(True). In this case, the full Hessian matrix of shape (N+1, 3, N+1, 3) is returned.
- None, in which case no Hessian is returned. This is useful for skipping the additional automatic differentiation (AD) pass, e.g. in mixed-labels training.
Where N is the total number of graph nodes including padding nodes, n is the number of real graph nodes, and R is the number of Hessian rows.
- __call__(graph: Graph) → Graph
Evaluates the Hessian predictor on a given graph.
Computes the required properties, including the energy Hessian, and updates the input graph with these quantities. If the Hessian is not required, it falls back to evaluating the parent conservative predictor.
- Parameters:
graph – The input graph.
- Returns:
An updated graph containing all predicted properties.
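The three documented modes of graph.sample_hessian_rows can be illustrated with a small, self-contained JAX sketch. It does not use the mlip API: toy_energy, toy_forces and predict_hessian are hypothetical stand-ins for the network and the predictor. It also handles a single unpadded graph, so the full Hessian has shape (N, 3, N, 3) rather than (N+1, 3, N+1, 3), and it assumes that an index array holds flat force-component indices of shape (R,). The only fixed relation it relies on is that a Hessian row is the negative derivative of a force component, since F = -dE/dr.

```python
import jax
import jax.numpy as jnp


def toy_energy(positions: jax.Array) -> jax.Array:
    """Hypothetical harmonic-chain energy standing in for an MLIP network."""
    bonds = positions[1:] - positions[:-1]
    return 0.5 * jnp.sum(bonds**2)


def toy_forces(positions: jax.Array) -> jax.Array:
    """Conservative forces F = -dE/dr, shape (N, 3)."""
    return -jax.grad(toy_energy)(positions)


def predict_hessian(positions, sample_hessian_rows):
    """Hypothetical dispatch over the three documented sample_hessian_rows modes."""
    if sample_hessian_rows is None:
        # No Hessian requested: skip the additional AD pass entirely.
        return None
    rows = jnp.asarray(sample_hessian_rows)
    if jnp.issubdtype(rows.dtype, jnp.bool_):
        # array(True): full Hessian, shape (N, 3, N, 3) for this unpadded toy.
        return jax.hessian(toy_energy)(positions)

    # Index array (assumed: flat force-component indices, shape (R,)).
    # Hessian row i equals -dF_i/dr, i.e. one reverse-mode pass per row.
    def hessian_row(i):
        return jax.grad(lambda p: -toy_forces(p).reshape(-1)[i])(positions)

    sampled = jax.vmap(hessian_row)(rows)  # shape (R, N, 3)
    return jnp.moveaxis(sampled, 0, 1)  # shape (N, R, 3), per-node layout
```

Computing only R rows costs R reverse passes through the force computation, whereas the full Hessian needs one per coordinate (3N), which is why subsampling pays off for large systems.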
- compute_sum_forces_subsample(positions: Array, graph: Graph) → tuple[Array, Graph]
Returns a (sum(F[sample_rows]), graph) pair for downstream automatic differentiation. The auxiliary Graph object can be forwarded by downstream methods, while the caller may differentiate through the subsampled force components F[sample_rows] to compute Hessian rows.
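One plausible reading of why a sum of force components is enough: graphs in a batch do not interact, so the derivative of one graph's forces with respect to another graph's positions is zero. If each graph contributes one sampled component, a single reverse pass over their sum recovers one Hessian row per graph, each landing in that graph's own node slots, and repeating this over the R columns of a (G, R) index array yields R rows per graph. The sketch below demonstrates that reasoning in plain JAX with the (value, aux) convention of jax.grad(..., has_aux=True). Everything in it is hypothetical: batch_energy, sum_forces_subsample and hessian_rows are illustrative names, the dict replaces a real Graph, and the flat-index layout of sample_rows is an assumption rather than the library's documented format.

```python
import jax
import jax.numpy as jnp


def chain_energy(p: jax.Array) -> jax.Array:
    """Hypothetical anharmonic chain energy for a single graph."""
    bonds = p[1:] - p[:-1]
    return 0.5 * jnp.sum(bonds**2) + 0.25 * jnp.sum(bonds**4)


def batch_energy(positions: jax.Array) -> jax.Array:
    # Two independent 3-node graphs packed into one (6, 3) array:
    # nodes 0-2 form graph 0, nodes 3-5 form graph 1, with no cross terms.
    return chain_energy(positions[:3]) + chain_energy(positions[3:])


def sum_forces_subsample(positions, sample_rows, aux):
    """Hypothetical analogue: sum of selected force components plus aux data.

    sample_rows: flat force-component indices, shape (G,), one per graph.
    """
    forces = -jax.grad(batch_energy)(positions)
    return jnp.sum(forces.reshape(-1)[sample_rows]), aux


def hessian_rows(positions, sample_rows, aux):
    """Hessian rows for a (G, R) index array; returns ((N, R, 3), aux)."""
    grad_fn = jax.grad(sum_forces_subsample, has_aux=True)

    def one_column(rows_r):
        # d(sum F)/dr is minus the sum of the selected Hessian rows;
        # rows from different graphs occupy disjoint node slots.
        grad_sum_forces, _ = grad_fn(positions, rows_r, aux)
        return -grad_sum_forces  # shape (N, 3)

    rows = jax.vmap(one_column, in_axes=1)(sample_rows)  # (R, N, 3)
    return jnp.moveaxis(rows, 0, 1), aux  # aux is forwarded unchanged
```

Returning the auxiliary object alongside the scalar lets a caller differentiate and still pass the updated graph along, without evaluating the model a second time.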