netket_foundation.VMC_NG
- class netket_foundation.VMC_NG(hamiltonian, optimizer, *, diag_shift, proj_reg=None, momentum=None, linear_solver=<function cholesky>, variational_state=None, chunk_size_bwd=None, mode=None, use_ntk=False, on_the_fly=False)
Energy minimization using Variational Monte Carlo (VMC) and Stochastic Reconfiguration (SR) with or without its kernel formulation. The two approaches lead to exactly the same parameter updates. In the kernel SR framework, the updates of the parameters can be written as:
\[\delta \theta = \tau X(X^TX + \lambda \mathbb{I}_{2M})^{-1} f,\]
where \(X \in \mathbb{R}^{P \times 2M}\) is the concatenation of the real and imaginary parts of the centered Jacobian, with \(P\) the number of parameters and \(M\) the number of samples. The vector \(f\) is the concatenation of the real and imaginary parts of the centered local energy. Note that, to compute the updates, it is sufficient to invert a \(2M \times 2M\) matrix instead of a \(P \times P\) one. As a consequence, this formulation is useful in the typical deep learning regime where \(P \gg M\).
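The equivalence between the kernel form above and the standard SR form can be checked numerically. The following is a minimal NumPy sketch on random data, not the driver's implementation: `X` and `f` are random stand-ins for the centered Jacobian and local energies, and the values of `P`, `M`, `lam` and `tau` are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
P, M = 50, 4              # P parameters, M samples, so X has 2M columns
lam, tau = 1e-2, 1.0      # diagonal shift and step size (illustrative values)

X = rng.normal(size=(P, 2 * M))   # random stand-in for the centered Jacobian
f = rng.normal(size=2 * M)        # random stand-in for the centered local energy

# Kernel form: only a 2M x 2M linear system is solved.
delta_kernel = tau * X @ np.linalg.solve(X.T @ X + lam * np.eye(2 * M), f)

# Standard SR form: a P x P linear system is solved.
delta_sr = tau * np.linalg.solve(X @ X.T + lam * np.eye(P), X @ f)

# Both forms agree up to floating-point error.
print(np.allclose(delta_kernel, delta_sr))
```

The agreement follows from the identity \(X(X^TX + \lambda \mathbb{I})^{-1} = (XX^T + \lambda \mathbb{I})^{-1}X\), so the smaller of the two systems can always be used.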
See R. Rende, L. L. Viteritti, L. Bardone, F. Becca and S. Goldt for a detailed description of the derivation. A similar result can be obtained by minimizing the Fubini-Study distance with a specific constraint; see A. Chen and M. Heyl for details.
When momentum is used, this driver implements the SPRING optimizer of G. Goldshlager, N. Abrahamsen and L. Lin, which accumulates previous updates to better approximate exact SR with no significant performance penalty.
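As a rough illustration of how momentum enters, the sketch below implements a SPRING-style step as a damped least-squares problem: the new update is pulled toward `momentum * prev_update` instead of toward zero. This is an assumption-laden sketch, not the driver's code. The helper `spring_update` is hypothetical, and the `proj_reg` regularization and the learning-rate scaling are deliberately left out.

```python
import numpy as np

def spring_update(X, f, prev_update, diag_shift, momentum):
    """Hypothetical helper: minimize
    ||X^T u - f||^2 + diag_shift * ||u - momentum * prev_update||^2 over u."""
    n_cols = X.shape[1]
    # Part of f not already explained by the decayed previous update.
    residual = f - momentum * (X.T @ prev_update)
    kernel = X.T @ X + diag_shift * np.eye(n_cols)
    return momentum * prev_update + X @ np.linalg.solve(kernel, residual)

rng = np.random.default_rng(1)
P, M = 30, 3
X = rng.normal(size=(P, 2 * M))   # random stand-in for the centered Jacobian
f = rng.normal(size=2 * M)        # random stand-in for the centered local energy
prev = rng.normal(size=P)         # stand-in for the previous update
lam, mu = 1e-2, 0.9               # illustrative diagonal shift and momentum

new = spring_update(X, f, prev, lam, mu)

# Gradient of the damped least-squares objective at the new update (up to a factor 2).
grad = X @ (X.T @ new - f) + lam * (new - mu * prev)
print(np.abs(grad).max())
```

With `momentum=0` the helper reduces to the kernel SR update with \(\tau = 1\).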
- __init__(hamiltonian, optimizer, *, diag_shift, proj_reg=None, momentum=None, linear_solver=<function cholesky>, variational_state=None, chunk_size_bwd=None, mode=None, use_ntk=False, on_the_fly=False)
Initialize the driver.
- Parameters:
hamiltonian (AbstractOperator) – The Hamiltonian of the system.
optimizer (Any) – Determines how optimization steps are performed given the bare energy gradient.
diag_shift (Union[Any,Callable[[Union[Array,ndarray,bool,number,bool,int,float,complex,TypedNdArray]],Union[Array,ndarray,bool,number,bool,int,float,complex,TypedNdArray]]]) – The diagonal shift of the curvature matrix.
proj_reg (Union[Any,Callable[[Union[Array,ndarray,bool,number,bool,int,float,complex,TypedNdArray]],Union[Array,ndarray,bool,number,bool,int,float,complex,TypedNdArray]],None]) – Weight of the matrix \(\frac{1}{N_{\text{samples}}} \mathbf{1}\mathbf{1}^T\) used to regularize the linear solver in SPRING.
momentum (Union[Any,Callable[[Union[Array,ndarray,bool,number,bool,int,float,complex,TypedNdArray]],Union[Array,ndarray,bool,number,bool,int,float,complex,TypedNdArray]],None]) – Momentum used to accumulate updates in SPRING.
linear_solver (Callable[[Union[ndarray,Array],Union[ndarray,Array]],Union[ndarray,Array]]) – Callable used to solve the linear problem associated with the parameter updates.
mode (JacobianMode|None) – The mode used to compute the jacobian or vjp of the variational state. Can be ‘real’ or ‘complex’ (defaults to the dtype of the output of the model). ‘real’ can be used for real wavefunctions with a sign to further reduce the computational cost.
on_the_fly (bool|None) – Whether to compute the QGT or NTK matrix without evaluating the full jacobian. Defaults to True. This usually lowers the memory requirement and is necessary for large calculations.
use_ntk (bool) – Whether to use the NTK instead of the QGT for the computation of the updates.
variational_state (MCState) – The netket.vqs.MCState to be optimised. Other variational states are not supported.
chunk_size_bwd (int|None) – The chunk size to use for the backward pass (jacobian or vjp evaluation).
collect_quadratic_model – Whether to collect the quadratic model. The quantities collected are the linear and quadratic terms in the approximation of the loss function. They are stored in the info dictionary of the driver.
- Returns:
The new parameters, the old updates, and the info dictionary.
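The type of diag_shift allows either a plain scalar or a callable mapping a scalar to a scalar. A minimal sketch of such a callable follows, assuming (this is not stated above) that it is evaluated with the current step count and returns the shift to use at that step. The helper name `decaying_diag_shift` and its arguments are hypothetical.

```python
def decaying_diag_shift(init_value, end_value, decay_rate):
    """Hypothetical helper: build a schedule that decays geometrically
    from init_value toward end_value as the step count grows."""
    def schedule(step):
        # Plain arithmetic only, so it works for Python numbers and array scalars.
        return end_value + (init_value - end_value) * decay_rate**step
    return schedule

shift = decaying_diag_shift(init_value=1e-2, end_value=1e-4, decay_rate=0.9)
for step in (0, 10, 100):
    print(step, shift(step))
```

Under that assumption, the resulting `shift` object would be passed as `diag_shift=shift` in place of a constant. A large shift early on damps noisy updates, and a smaller one later tracks the unregularized SR direction more closely.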
Methods
__init__(hamiltonian, optimizer, *, diag_shift) – Initialize the driver.
compute_loss_and_update()
estimate(observables[, fullsum]) – Return MCMC statistics for the expectation value of observables in the current state of the driver.
replace(**kwargs) – Replace the values of the fields of the object with the values of the keyword arguments.
reset()
reset_step([hard]) – Resets the state of the driver at the beginning of a new step.
run(n_iter[, out, obs, step_size, ...]) – Runs this variational driver, updating the weights of the network stored in this driver for n_iter steps and dumping values of the observables obs in the output logger.
update_parameters(dp) – Updates the parameters of the machine using the optimizer in this driver.
Attributes
chunk_size_bwd – Chunk size for backward-mode differentiation.
diag_shift – The diagonal shift \(\lambda\) in the curvature matrix.
info – PyTree to pass on information from the solver, e.g. the quadratic model.
mode – The mode used to compute the jacobian of the variational state.
momentum – Flag specifying whether to use momentum in the optimisation.
on_the_fly – Whether to use a lazy implementation of the NTK or QGT which does not concretize the jacobian.
optimizer – The optimizer used to update the parameters at every iteration.
proj_reg
state – Returns the machine that is optimized by this driver.
step_count – Returns a monotonic integer labelling all the steps performed by this driver.
update_fn – Returns the function to compute the NGD update based on the evaluation mode.
use_ntk – Whether to use the Neural Tangent Kernel (NTK) instead of the Quantum Geometric Tensor (QGT) to compute the update.