Source code for diffopt.kdescent.descent

from functools import partial

import jax.numpy as jnp
import jax.random
import numpy as np
import scipy.optimize
import tqdm.auto as tqdm
from jax.example_libraries import optimizers as jax_opt

from . import keygen


def adam(lossfunc, guess, nsteps=100, param_bounds=None, learning_rate=0.01,
         randkey=1, const_randkey=False, thin=1, progress=True,
         **other_kwargs):
    """
    Perform gradient descent

    Parameters
    ----------
    lossfunc : callable
        Function to be minimized via gradient descent. Must be compatible with
        jax.jit and jax.grad. Must have signature f(params, **other_kwargs)
    guess : array-like
        The starting parameters.
    nsteps : int, optional
        Number of gradient descent iterations to perform, by default 100
    param_bounds : Sequence, optional
        Lower and upper bounds of each parameter of "shape" (ndim, 2). Pass
        `None` as the bound for each unbounded parameter, by default None
    learning_rate : float, optional
        Initial Adam learning rate, by default 0.01
    randkey : int, optional
        Random seed or key, by default 1. If not None, lossfunc must accept
        the "randkey" keyword argument, e.g. `lossfunc(params, randkey=key)`
    const_randkey : bool, optional
        By default (False), randkey is regenerated at each gradient descent
        iteration. Remove this behavior by setting const_randkey=True
    thin : int, optional
        Return parameters for every `thin` iterations, by default 1. Set
        `thin=0` to only return final parameters
    progress : bool, optional
        Display tqdm progress bar, by default True
    **other_kwargs
        Forwarded unchanged to `lossfunc` at every evaluation.

    Returns
    -------
    params : jnp.array
        The trial parameters at each iteration.
    losses : jnp.array
        The loss values at each iteration.

    Raises
    ------
    ValueError
        If `param_bounds` is given and its length differs from `guess`.
    """
    if param_bounds is None:
        return adam_unbounded(
            lossfunc, guess, nsteps, learning_rate, randkey,
            const_randkey, thin, progress, **other_kwargs)

    # Explicit check instead of `assert`: asserts are stripped under -O, and
    # the zip-based transforms below would then silently drop parameters.
    if len(guess) != len(param_bounds):
        raise ValueError(
            f"guess has {len(guess)} parameters but param_bounds has "
            f"{len(param_bounds)} entries")

    # Bounds are used as static (hashable) jit arguments by the transforms,
    # so normalise arrays / nested lists into tuples.
    if hasattr(param_bounds, "tolist"):
        param_bounds = param_bounds.tolist()
    param_bounds = [b if b is None else tuple(b) for b in param_bounds]

    def ulossfunc(uparams, *args, **kwargs):
        # Optimise in unbounded space; evaluate the loss in bounded space.
        params = apply_inverse_transforms(uparams, param_bounds)
        return lossfunc(params, *args, **kwargs)

    init_uparams = apply_transforms(guess, param_bounds)
    uparams, loss = adam_unbounded(
        ulossfunc, init_uparams, nsteps, learning_rate, randkey,
        const_randkey, thin, progress, **other_kwargs)

    # Transpose so each row holds one parameter's history, matching the
    # per-parameter iteration in apply_inverse_transforms, then restore the
    # original orientation. (`.T` is a no-op on the 1-D thin=0 result.)
    params = apply_inverse_transforms(uparams.T, param_bounds).T
    return params, loss
def adam_unbounded(lossfunc, guess, nsteps=100, learning_rate=0.01,
                   randkey=1, const_randkey=False, thin=1, progress=True,
                   **other_kwargs):
    """
    Run Adam gradient descent without any parameter bounds

    Parameters
    ----------
    lossfunc : callable
        Function to be minimized. It is wrapped in jax.value_and_grad and
        jax.jit, and called as `lossfunc(params, **kwargs)`.
    guess : array-like
        The starting parameters.
    nsteps : int, optional
        Number of Adam updates to perform, by default 100. The loss is
        evaluated `nsteps + 1` times (at the guess and after each update).
    learning_rate : float, optional
        Step size passed to `jax_opt.adam`, by default 0.01
    randkey : int, optional
        Random seed or key, by default 1. If not None, a key is passed to
        lossfunc through the "randkey" keyword argument.
    const_randkey : bool, optional
        If False (default), a new key is split off for every evaluation.
        If True, the same key is reused for every evaluation.
    thin : int, optional
        Thinning factor for the returned history, by default 1. The initial
        guess and the final iterate are always retained. Set `thin=0` to
        return only the final parameters and loss.
    progress : bool, optional
        Display tqdm progress bar, by default True
    **other_kwargs
        Forwarded unchanged to `lossfunc` at every evaluation.

    Returns
    -------
    params : jnp.array
        The retained trial parameters (only the final ones if `thin=0`).
    losses : jnp.array
        The loss values matching `params`.
    """
    kwargs = {**other_kwargs}
    if randkey is not None:
        randkey = keygen.init_randkey(randkey)
        randkey, key_i = jax.random.split(randkey)
        kwargs["randkey"] = key_i
        if const_randkey:
            # Dropping the parent key freezes kwargs["randkey"] for the run.
            randkey = None

    loss_and_grad = jax.jit(jax.value_and_grad(lossfunc))
    opt_init, opt_update, get_params = jax_opt.adam(learning_rate)
    opt_state = opt_init(guess)
    params_i = guess

    loss = []
    params = []
    # thin=0 means "keep only the last iterate". max(..., 1) guards the
    # modulo below against a zero divisor when nsteps is also 0.
    thindiv = thin if thin else max(nsteps, 1)

    for i in tqdm.trange(nsteps + 1, disable=not progress,
                         desc="Adam Gradient Descent Progress"):
        if randkey is not None:
            randkey, key_i = jax.random.split(randkey)
            kwargs["randkey"] = key_i

        loss_i, grad = loss_and_grad(params_i, **kwargs)

        if (i - 1) % thindiv == 0 or not len(params):
            # Start a new slot in the thinned history.
            loss.append(loss_i)
            params.append(params_i)
        else:
            # Overwrite the current slot so the latest iterate is retained.
            loss[-1] = loss_i
            params[-1] = params_i

        # The final pass only evaluates the loss; no update follows it.
        if i < nsteps:
            opt_state = opt_update(i, grad, opt_state)
            params_i = get_params(opt_state)

    if not thin:
        params = params[-1]
        loss = loss[-1]

    return jnp.array(params), jnp.array(loss)
def bfgs(lossfunc, guess, maxsteps=100, param_bounds=None, randkey=None,
         thin=1, progress=True):
    """
    Run BFGS to descend the gradient and optimize the model parameters,
    given an initial guess. Stochasticity must be held fixed via a random key

    The optimization is performed by `scipy.optimize.minimize` with
    `method="L-BFGS-B"`, using gradients from `jax.value_and_grad`.

    Parameters
    ----------
    lossfunc : callable
        Function to be minimized via gradient descent. Must be compatible
        with jax.grad. It is called as `lossfunc(params)`, or as
        `lossfunc(params, randkey=key)` when `randkey` is not None.
    guess : array-like
        The starting parameters.
    maxsteps : int, optional
        The maximum number of steps to take, by default 100.
    param_bounds : Sequence, optional
        Lower and upper bounds of each parameter of "shape" (ndim, 2). Pass
        `None` as the bound for each unbounded parameter, by default None
    randkey : int | PRNG Key, optional
        Since BFGS requires a deterministic function, this key will be
        passed to `lossfunc` as the "randkey" kwarg as a constant at every
        iteration, by default None
    thin : int, optional
        Return parameters for every `thin` iterations, by default 1. Set
        `thin=0` to only return final parameters
    progress : bool, optional
        Display tqdm progress bar, by default True

    Returns
    -------
    params : jnp.array
        The trial parameters at each iteration.
    losses : jnp.array
        The loss values at each iteration.
    result : OptimizeResult (contains the following attributes):
        message : str, describes reason of termination
        success : boolean, True if converged
        fun : float, minimum loss found
        x : array of parameters at minimum loss found
        jac : array of gradient of loss at minimum loss found
        nfev : int, number of function evaluations
        nit : int, number of gradient descent iterations
    """
    kwargs = {}
    if randkey is not None:
        kwargs["randkey"] = keygen.init_randkey(randkey)

    params = []
    loss = []
    step = -1
    # thin=0 means "keep only the last iterate". max(..., 1) guards the
    # modulo in the callback against a zero divisor.
    thindiv = thin if thin else max(maxsteps * len(guess), 1)

    loss_and_grad_fn = jax.value_and_grad(
        lambda x: lossfunc(x, **kwargs))

    # The context manager closes the bar even if the optimizer or the loss
    # function raises.
    with tqdm.trange(maxsteps, desc="BFGS Gradient Descent Progress",
                     disable=not progress) as pbar:

        def callback(intermediate_result):
            nonlocal step
            if step % thindiv == 0 or not len(params):
                # Start a new slot in the thinned history.
                params.append(intermediate_result.x)
                loss.append(intermediate_result.fun)
            else:
                # Overwrite the current slot so the latest iterate is kept.
                params[-1] = intermediate_result.x
                loss[-1] = intermediate_result.fun
            step += 1
            pbar.update()

        result = scipy.optimize.minimize(
            loss_and_grad_fn, x0=guess, method="L-BFGS-B", jac=True,
            options=dict(maxiter=maxsteps), callback=callback,
            bounds=param_bounds)

    if not thin:
        if params:
            params = params[-1]
            loss = loss[-1]
        else:
            # The callback never fired (e.g. the guess already satisfied
            # the convergence test), so take the final state from `result`.
            params = result.x
            loss = result.fun

    return jnp.array(params), jnp.array(loss), result
def _hashable_bound(bound):
    """Return `bound` as None or a tuple so it can be a static jit argument."""
    return bound if bound is None else tuple(bound)


def _is_finite_bound(value):
    """Return True if `value` is a real, finite bound (not None/inf/nan)."""
    return value is not None and bool(np.isfinite(value))


def _positive_root(u):
    """Evaluate (u + sqrt(u**2 + 4)) / 2 without catastrophic cancellation.

    This is the positive root d of d - 1/d = u. For u < 0 the direct formula
    subtracts two nearly equal numbers (and rounds to exactly 0 in float32
    for large |u|), so the algebraically identical conjugate form
    2 / (sqrt(u**2 + 4) - u) is used there instead.
    """
    is_neg = u < 0
    # Feed a harmless 0 into the conjugate branch where it is not selected,
    # so its denominator stays >= 2 and its gradient stays finite.
    u_neg = jnp.where(is_neg, u, 0.0)
    conjugate = 2.0 / (jnp.sqrt(u_neg**2 + 4) - u_neg)
    direct = 0.5 * (u + jnp.sqrt(u**2 + 4))
    return jnp.where(is_neg, conjugate, direct)


def apply_transforms(params, bounds):
    """Map each bounded parameter to its unbounded counterpart.

    `params` and `bounds` are iterated in lockstep; each bound is either
    None or a (low, high) pair as accepted by `transform`.
    """
    return jnp.array([transform(param, _hashable_bound(bound))
                      for param, bound in zip(params, bounds)])


def apply_inverse_transforms(uparams, bounds):
    """Map each unbounded parameter back into its bounded range.

    `uparams` and `bounds` are iterated in lockstep; each bound is either
    None or a (low, high) pair as accepted by `inverse_transform`.
    """
    return jnp.array([inverse_transform(uparam, _hashable_bound(bound))
                      for uparam, bound in zip(uparams, bounds)])


@partial(jax.jit, static_argnums=[1])
def transform(param, bounds):
    """Transform param into unbound param

    `bounds` is None or a hashable (low, high) pair; a bound of None or a
    non-finite value means that side is unbounded.
    """
    if bounds is None:
        return param
    low, high = bounds
    low_is_finite = _is_finite_bound(low)
    high_is_finite = _is_finite_bound(high)
    if low_is_finite and high_is_finite:
        # Stretch (low, high) onto the real line with a tangent.
        mid = (high + low) / 2.0
        scale = (high - low) / jnp.pi
        return scale * jnp.tan((param - mid) / scale)
    elif low_is_finite:
        # With d = param - low > 0 this is d - 1/d.
        return param - low + 1.0 / (low - param)
    elif high_is_finite:
        # With d = high - param > 0 this is 1/d - d.
        return param - high + 1.0 / (high - param)
    else:
        return param


@partial(jax.jit, static_argnums=[1])
def inverse_transform(uparam, bounds):
    """Transform unbound param back into param

    Inverse of `transform` for the same `bounds`. The result stays strictly
    inside a half-open bound even for large |uparam|.
    """
    if bounds is None:
        return uparam
    low, high = bounds
    low_is_finite = _is_finite_bound(low)
    high_is_finite = _is_finite_bound(high)
    if low_is_finite and high_is_finite:
        mid = (high + low) / 2.0
        scale = (high - low) / jnp.pi
        return mid + scale * jnp.arctan(uparam / scale)
    elif low_is_finite:
        # Equivalent to 0.5 * (2 * low + u + sqrt(u**2 + 4)).
        return low + _positive_root(uparam)
    elif high_is_finite:
        # Equivalent to 0.5 * (2 * high + u - sqrt(u**2 + 4)).
        return high - _positive_root(-uparam)
    else:
        return uparam