netket_foundation.VMC_SR
- class netket_foundation.VMC_SR
Bases: VMC_SR
Energy minimization using Variational Monte Carlo (VMC) and Stochastic Reconfiguration (SR) / Natural Gradient Descent, specialized for the foundational training scheme.
This driver tracks netket.driver.VMC_SR, and we refer to its documentation for a detailed description of the method, the available formulations (standard vs. kernel/minSR), the matrix-inversion solvers, and the momentum/SPRING accelerator. All of those options behave as documented there.
The difference is that this driver computes the SR/NGD formulas a bit differently in order to make better use of the foundation training scheme: a single variational state holds several replicas (one per point in ParameterSpace, e.g. a range of Hamiltonian parameters) and they are all trained simultaneously. The Jacobian and the local energies carry an extra replica dimension, which is handled explicitly so that the natural-gradient updates are computed per replica rather than mixing samples across different physical points. The progress bar and logged loss therefore report per-replica energy statistics (see log_replica_stats) rather than a single-state energy.
For the underlying SR/NGD derivation and references, see netket.driver.VMC_SR.
- __init__(hamiltonian, optimizer, *, diag_shift, proj_reg=None, momentum=None, linear_solver=<function cholesky_with_fallback>, variational_state=None, chunk_size_bwd=None, mode=None, use_ntk=False, on_the_fly=False, log_replica_stats=False)
Initialize the driver.
- Parameters:
hamiltonian (AbstractOperator) – The Hamiltonian of the system.
optimizer (Any) – Determines how optimization steps are performed given the bare energy gradient.
diag_shift (Union[Any,Callable[[Union[Array,ndarray,bool,number,bool,int,float,complex]],Union[Array,ndarray,bool,number,bool,int,float,complex]]]) – The diagonal shift of the curvature matrix.
proj_reg (Union[Any,Callable[[Union[Array,ndarray,bool,number,bool,int,float,complex]],Union[Array,ndarray,bool,number,bool,int,float,complex]],None]) – Weight before the matrix $\frac{1}{N_{\text{samples}}} \bm{1} \bm{1}^T$ used to regularize the linear solver in SPRING.
momentum (Union[Any,Callable[[Union[Array,ndarray,bool,number,bool,int,float,complex]],Union[Array,ndarray,bool,number,bool,int,float,complex]],None]) – Momentum used to accumulate updates in SPRING.
linear_solver (Callable[[Union[ndarray,Array],Union[ndarray,Array]],Union[ndarray,Array]]) – Callable to solve the linear problem associated to the updates of the parameters.
mode (JacobianMode|None) – The mode used to compute the jacobian or vjp of the variational state. Can be ‘real’ or ‘complex’ (defaults to the dtype of the output of the model). real can be used for real wavefunctions with a sign to further reduce the computational costs.
on_the_fly (bool|None) – Whether to compute the NTK matrix without evaluating the full jacobian. This usually lowers the memory requirement and is necessary for large calculations. Only supported together with use_ntk=True (importance-sampling weights / pdf are not yet supported in the on-the-fly path).
use_ntk (bool) – Whether to use the NTK instead of the QGT for the computation of the updates.
variational_state (MCState) – The netket.vqs.MCState to be optimised. Other variational states are not supported.
chunk_size_bwd (int|None) – The chunk size to use for the backward pass (jacobian or vjp evaluation).
collect_quadratic_model – Whether to collect the quadratic model. The quantities collected are the linear and quadratic term in the approximation of the loss function. They are stored in the info dictionary of the driver.
log_replica_stats (bool)
- Returns:
The new parameters, the old updates, and the info dictionary.
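The type annotation of diag_shift allows either a plain value or a callable mapping a scalar to a scalar. A minimal sketch of such a callable follows, assuming the scalar argument is the optimization step count. The helper name exponential_diag_shift and its arguments are hypothetical and not part of the library; a JAX-traceable schedule may be needed if the driver evaluates the schedule inside compiled code.

```python
import math


def exponential_diag_shift(init_value: float, end_value: float, decay_steps: int):
    """Build a schedule mapping a step count to a diagonal shift.

    Hypothetical helper: only the "scalar -> scalar" shape of the callable
    is taken from the `diag_shift` type annotation.
    """
    if init_value <= 0 or end_value <= 0:
        raise ValueError("diagonal shifts must be positive")
    # Per-step rate chosen so that schedule(decay_steps) == end_value.
    rate = math.log(end_value / init_value) / decay_steps

    def schedule(step):
        # Clamp so the shift stays at end_value after decay_steps.
        clamped = min(max(step, 0), decay_steps)
        return init_value * math.exp(rate * clamped)

    return schedule


# Decay from 1e-2 to 1e-4 over 100 steps, then stay constant.
diag_shift = exponential_diag_shift(1e-2, 1e-4, decay_steps=100)
shifts = [diag_shift(step) for step in (0, 50, 100, 200)]
```

The resulting diag_shift object could then be passed as the diag_shift keyword argument in place of a constant, under the assumption stated above.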
- Attributes
- log_replica_stats: bool = False
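To illustrate why per-replica energy statistics differ from a single pooled estimate, the following sketch computes both for made-up local energies. It is not the driver's implementation: the data, the [replica][sample] layout, and the helper replica_stats are assumptions chosen for illustration.

```python
from statistics import fmean, pvariance

# Made-up local energies, indexed as [replica][sample]. Each replica stands
# for a different physical point, so the energy scales differ.
local_energies = [
    [-1.0, -1.2, -0.8, -1.0],  # replica 0
    [-3.0, -3.1, -2.9, -3.0],  # replica 1
]


def replica_stats(energies):
    """Mean and variance over the sample axis, separately for each replica."""
    return [{"mean": fmean(row), "variance": pvariance(row)} for row in energies]


per_replica = replica_stats(local_energies)

# Pooling all samples mixes different physical points: the variance is then
# dominated by the gap between replicas, not by sampling noise.
pooled = [e for row in local_energies for e in row]
pooled_variance = pvariance(pooled)
```

In this toy data each replica has a small variance of its own, while the pooled variance is roughly two orders of magnitude larger, which is why a single pooled number would be misleading as a convergence indicator.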
- Methods
- compute_loss_and_update()
Performs a step of the optimization driver, returning the PyTree of the gradients that will be optimized.
Concrete drivers must override this method.
Note
When implementing this function on a subclass, you must return the gradient which must match the pytree structure of the parameters of the variational state.
The gradient will then be passed on to the optimizer in order to update the parameters.
Moreover, if you are minimising a loss function you must set the field self._loss_stats with the current value of the loss function.
This will be logged to any logger during optimisation.
- Returns:
the update for the weights.
- estimate(observables, fullsum=False)
Return MCMC statistics for the expectation value of observables in the current state of the driver.
- Parameters:
observables – A pytree of operators for which statistics should be computed.
fullsum (bool)
- Returns:
A pytree of the same structure as the input, containing MCMC statistics for the corresponding operators as leaves.
- replace(**kwargs)
Replace the values of the fields of the object with the values of the keyword arguments. If the object is a dataclass, dataclasses.replace will be used. Otherwise, a new object will be created with the same type as the original object.
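The dataclass branch of this behaviour can be illustrated with the standard library alone. In the sketch below, ToySettings is a hypothetical stand-in and not a class from the library; it only shows that dataclasses.replace returns a new object and leaves the original untouched.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToySettings:
    """Hypothetical stand-in for an object with replaceable fields."""

    diag_shift: float
    log_replica_stats: bool = False


original = ToySettings(diag_shift=1e-3)

# Returns a new instance with the given field changed; other fields are copied.
updated = dataclasses.replace(original, log_replica_stats=True)
```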
- reset()
Deprecated since version 3.22: Use reset_step() to reset the sampler state at the beginning of a step. Note that the old reset() also reset step_count to 0; this behaviour is no longer supported.
- reset_step(hard=False)
Resets the state of the driver at the beginning of a new step.
This method is called at the beginning of every step in the optimization.
- Parameters:
hard (bool) – If True, the reset is a hard reset, resulting in a complete resampling even if resample_fraction is not None.
- run(n_iter, out=(), obs=None, step_size=1, show_progress=True, save_params_every=50, write_every=50, callback=None, timeit=False, _graceful_keyboard_interrupt=True)
Runs this variational driver, updating the weights of the network stored in this driver for n_iter steps and dumping values of the observables obs in the output logger.
It is possible to control more specifically what quantities are logged, when to stop the optimisation, or to execute arbitrary code at every step by specifying one or more callbacks, which are passed as a list of functions to the keyword argument callback.
Callbacks are functions that follow this signature:
    def callback(step, log_data, driver) -> bool:
        ...
        return True  # or False to stop the optimisation
If a callback returns True, the optimisation continues, otherwise it is stopped. The log_data is a dictionary that can be modified in-place to change what is logged at every step. For example, this can be used to log additional quantities such as the acceptance rate of a sampler.
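A minimal sketch of a callback following this signature is shown below. It logs an extra quantity by editing log_data in place and stops the run when that quantity falls below a threshold. The attribute driver.acceptance and the factory make_acceptance_callback are hypothetical; a real driver exposes sampler information elsewhere, and a SimpleNamespace stands in for the driver here.

```python
from types import SimpleNamespace


def make_acceptance_callback(min_acceptance: float):
    """Return a callback that logs an acceptance rate and stops if it collapses."""

    def callback(step, log_data, driver) -> bool:
        # `driver.acceptance` is a hypothetical attribute used for illustration.
        acceptance = float(driver.acceptance)
        # In-place modification changes what is logged at this step.
        log_data["acceptance"] = acceptance
        # True continues the optimisation, False stops it.
        return acceptance >= min_acceptance

    return callback


callback = make_acceptance_callback(min_acceptance=0.1)

# Exercise the callback with stand-in driver objects.
log_data = {}
keep_going = callback(0, log_data, SimpleNamespace(acceptance=0.45))
stopped = callback(1, {}, SimpleNamespace(acceptance=0.02))
```

Wrapping the callback in a factory keeps the threshold configurable without relying on global state; the returned function is what would be passed to the callback keyword argument.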
Alternatively, AbstractCallback subclasses can be used to hook into more stages of the loop. To stop the optimisation early from any callback hook, raise StopRun: the driver will catch it, finalise all callbacks via their on_run_end method, and return normally without propagating the exception.
Loggers are specified as an iterable passed to the keyword argument out. If only a string is specified, this will create by default a nk.logging.JsonLog. For details of the output format, see its documentation. The logger object is also returned at the end of this function so that you can inspect the results without reading the json output.
- Parameters:
n_iter (int) – The total number of iterations to be performed during this run.
out (Iterable[AbstractLog] | None) – A logger object, or an iterable of loggers, to be used to store simulation log and data. If this argument is a string, it will be used as output prefix for the standard JSON logger.
obs (dict[str, AbstractObservable] | None) – A dictionary containing all observables that should be computed.
step_size (int) – Interval, in steps, at which observables are logged to disk (default=1).
callback (Union[Callable[[int,dict,AbstractDriver],bool],AbstractCallback,None]) – Callable or list of callable callback functions to stop training given a condition.
show_progress (bool) – If True, displays a progress bar (default=True).
save_params_every (int) – Interval, in steps, at which the parameters of the network are serialized to disk (ignored if logger is provided).
write_every (int) – Interval, in steps, at which the json data is flushed to disk (ignored if logger is provided).
timeit (bool) – If True, provide timing information.
_graceful_keyboard_interrupt (bool) – (Internal flag, defaults to True) If True, the driver will gracefully handle a KeyboardInterrupt, usually arising from doing ctrl-C, returning the current state of the simulation. If False, the KeyboardInterrupt will be raised as usual. This only has an effect when running in interactive mode.