Source code for diffopt.multigrad.multigrad

"""
"""
import math
from dataclasses import dataclass
from typing import Any, Union

import jax
import numpy as np
from jax import numpy as jnp

from . import util
from .adam import run_adam
from .bfgs import run_bfgs

# Optional MPI support: without mpi4py, every MPI handle is a None
# placeholder and the communicator "types" collapse to NoneType.
try:
    from mpi4py import MPI
except ImportError:
    MPI = COMM = None
    Comm = Intracomm = type(None)
else:
    COMM = MPI.COMM_WORLD
    Comm = MPI.Comm
    Intracomm = MPI.Intracomm

# Optional progress bars: only a non-MPI run or the rank-0 task imports
# tqdm, so that a single progress bar is shown.
tqdm = None
if COMM is None or not COMM.rank:
    try:
        from tqdm import auto as tqdm
    except ImportError:
        tqdm = None


def trange_no_tqdm(n, desc=None):
    """Return a plain ``range(n)``; used when tqdm is unavailable.

    ``desc`` is accepted only for signature parity with
    ``trange_with_tqdm`` and is ignored.
    """
    del desc  # unused by design
    return range(n)


def trange_with_tqdm(n, desc=None):
    """Return ``tqdm.trange(n, desc=desc)``: a range with a progress bar."""
    progress_range = tqdm.trange
    return progress_range(n, desc=desc)


# Pick the progress-bar-aware range only when tqdm was imported above
if tqdm is None:
    trange = trange_no_tqdm
else:
    trange = trange_with_tqdm


def split_subcomms_by_node(comm=None):
    """
    Split comm into sub-comms (grouped by nodes)

    Parameters
    ----------
    comm : MPI.Comm, optional
        Specify a sub-communicator to split into sub-sub-communicators.
        Defaults to the module-level ``COMM`` (``MPI.COMM_WORLD``)

    Returns
    -------
    subcomm: MPI.Comm
        The sub-comm that now controls this process
    num_groups: int
        The number of groups of subcomms (= number of nodes)
    group_rank: int
        The rank of this group (0 <= group_rank < num_groups)

    Raises
    ------
    ImportError
        If mpi4py is not available
    """
    if MPI is None:
        raise ImportError("MPI is not available. "
                          "Please install mpi4py.")
    if comm is None:
        comm = COMM

    # Every rank learns every rank's processor (node) name. Sorting the
    # unique names gives all ranks the same node -> group index mapping.
    node_name = MPI.Get_processor_name()
    nodelist = comm.allgather(node_name)
    unique_nodelist = sorted(set(nodelist))
    node_number = unique_nodelist.index(node_name)

    # Ranks on the same node share a color, so they land in one sub-comm
    subcomm = comm.Split(color=node_number)
    subcomm.Set_name(f"{comm.name}.{node_number}".replace(
        "MPI_COMM_WORLD.", ""))
    # NOTE: the caller owns `subcomm`; calling subcomm.Free() once it is no
    # longer needed can help prevent memory leaks
    return subcomm, len(unique_nodelist), node_number
def split_subcomms(num_groups=None, ranks_per_group=None, comm=None):
    """
    Split comm into sub-comms (not grouped by nodes)

    Parameters
    ----------
    num_groups : int, optional
        Specify the number of evenly divided groups of subcomms
    ranks_per_group : list[int], optional
        Specify the number of ranks given to each sub-comm
    comm : MPI.Comm, optional
        Specify a sub-communicator to split into sub-sub-communicators.
        Defaults to the module-level ``COMM`` (``MPI.COMM_WORLD``)

    Exactly one of `num_groups` and `ranks_per_group` must be given.

    Returns
    -------
    subcomm: MPI.Comm
        The sub-comm that now controls this process
    num_groups: int
        The number of groups of subcomms (same as input if not None)
    group_rank: int
        The rank of this group (0 <= group_rank < num_groups)

    Raises
    ------
    ValueError
        If no communicator is available (mpi4py not installed)
    AssertionError
        If the group specification is invalid
    """
    if comm is None:
        comm = COMM
    if comm is None:
        raise ValueError("MPI communicator is not available. "
                         "Please install mpi4py.")
    main_msg = "Specify either num_groups OR ranks_per_group"
    sumrps_msg = "The sum of ranks_per_group must equal comm.size"
    nsub_msg = "Cannot create more groups than there are ranks"

    if num_groups is not None:
        assert ranks_per_group is None, main_msg
        assert (comm.size >= num_groups), nsub_msg
        num_groups = int(num_groups)
        # Each group label repeated ceil(size / num_groups) times; the
        # array_split below spreads these labels evenly over the ranks
        # (every group is guaranteed at least one rank since
        # num_groups <= comm.size)
        labels_per_group = math.ceil(comm.size / num_groups)
        subnames = np.repeat(np.arange(num_groups), labels_per_group)
    else:
        assert ranks_per_group is not None, main_msg
        assert sum(ranks_per_group) == comm.size, sumrps_msg
        num_groups = len(ranks_per_group)
        subnames = np.repeat(np.arange(num_groups), ranks_per_group)

    # Rank i takes the first label of the i-th of comm.size near-equal
    # chunks. Every rank computes identical `subnames`, so no communication
    # is needed to agree on the grouping.
    subname = str(np.array_split(subnames, comm.size)[comm.rank][0])
    group_rank = int(subname)

    # Ranks sharing a group label share a color -> same sub-comm
    sub_comm = comm.Split(color=group_rank)
    sub_comm.Set_name(f"{comm.name}.{subname}".replace(
        "MPI_COMM_WORLD.", ""))
    # NOTE: the caller owns `sub_comm`; calling sub_comm.Free() once it is
    # no longer needed can help prevent memory leaks
    return sub_comm, num_groups, group_rank
def reduce_sum(value, root=None, comm=None):
    """Returns the sum of `value` across all MPI processes

    Parameters
    ----------
    value : np.ndarray | float | int
        value input by each MPI process to be summed
    root : int, optional
        rank of the process to receive and sum the values,
        by default None (broadcast result to all ranks)
    comm : MPI.Intracomm (default = MPI.COMM_WORLD)
        option to pass a sub-communicator in case the operation
        is not performed by all MPI ranks

    Returns
    -------
    np.ndarray | float
        Sum of values given by each rank of the communicator. When `root`
        is given, only that rank receives the sum; every other rank
        receives zeros of the same shape/dtype. Without MPI, `value` is
        returned unchanged.
    """
    if comm is None:
        comm = COMM
    if comm is None:
        # No MPI: a "sum" over a single process is the value itself
        return value

    # Scalar inputs (no __len__) are converted back to Python scalars
    return_to_scalar = not hasattr(value, "__len__")
    value = np.asarray(value)
    if root is None:
        # All-to-all sum: every rank receives the total
        total = np.empty_like(value)
        comm.Allreduce(value, total, op=MPI.SUM)
    else:
        # All-to-root sum: only `root` writes into the receive buffer, so
        # it is zero-initialized to keep non-root results deterministic
        # instead of returning uninitialized memory
        total = np.zeros_like(value)
        comm.Reduce(value, total, op=MPI.SUM, root=root)
    if return_to_scalar:
        total = total.tolist()
    return total
@dataclass
class OnePointModel:
    """
    Allows differentiable one-point calculations to be performed on separate
    MPI ranks, and automatically sums over each rank controlled by the comm.

    This is an abstract base class only. The user must personally define the
    `calc_partial_sumstats_from_params` and `calc_loss_from_sumstats` methods

    Parameters
    ----------
    aux_data : Any (default=None)
        Any auxiliary data for easy access within sumstats or loss functions
    comm : Comm (default=COMM_WORLD)
        MPI communicator (None when mpi4py is unavailable)
    loss_func_has_aux : bool (default=False)
        If true, `calc_loss_from_sumstats(...) -> (loss, aux)` signature
        will be assumed
    sumstats_func_has_aux : bool (default=False)
        If true, `calc_partial_sumstats_from_params(x) -> (y, aux)` and
        `calc_loss_from_sumstats(y, aux) -> ...` signatures will be assumed
    """
    aux_data: Any = None
    comm: Any = None
    loss_func_has_aux: bool = False
    sumstats_func_has_aux: bool = False

    def calc_partial_sumstats_from_params(self, params, randkey=None):
        """Custom method to map parameters to summary statistics"""
        raise NotImplementedError(
            "Subclass must implement `calc_partial_sumstats_from_params`"
        )

    def calc_loss_from_sumstats(self, sumstats, sumstats_aux=None,
                                randkey=None):
        """Custom method to map summary statistics to loss"""
        raise NotImplementedError(
            "Subclass must implement `calc_loss_from_sumstats`"
        )

    # NOTE: Never jit this method because it uses mpi4py
    def run_simple_grad_descent(self: Any, guess, nsteps=100,
                                learning_rate=0.01, thin=1, progress=True):
        """
        Descend the gradient with a fixed learning rate to optimize
        parameters, given an initial guess. Stochasticity not allowed.

        Parameters
        ----------
        guess : array-like
            The starting parameters.
        nsteps : int (default=100)
            The number of steps to take.
        learning_rate : float (default=0.01)
            The fixed learning rate.
        thin : int, optional
            Return parameters for every `thin` iterations, by default 1. Set
            `thin=0` to only return final parameters
        progress : bool, optional
            Display tqdm progress bar, by default True

        Returns
        -------
        GradientDescentResult (contains the following attributes):
            loss : array of loss values returned at each iteration
            params : array of trial parameters at each iteration
            aux : array of aux values returned at each iteration
        """
        return util.simple_grad_descent(
            None, guess=guess, nsteps=nsteps, learning_rate=learning_rate,
            loss_and_grad_func=self.calc_loss_and_grad_from_params,
            has_aux=False, thin=thin, progress=progress
        )

    # NOTE: Never jit this method because it uses mpi4py
    def run_adam(self: Any, guess, nsteps=100, param_bounds=None,
                 learning_rate=0.01, randkey=None, const_randkey=False,
                 thin=1, progress=True, comm=None):
        """
        Run adam to descend the gradient and optimize the model parameters,
        given an initial guess. Stochasticity is allowed if randkey is passed.

        Parameters
        ----------
        guess : array-like
            The starting parameters.
        nsteps : int (default=100)
            The number of steps to take.
        param_bounds : Sequence, optional
            Lower and upper bounds of each parameter of "shape" (ndim, 2).
            Pass `None` as the bound for each unbounded parameter, by default
            None
        learning_rate : float (default=0.01)
            The adam learning rate.
        randkey : int | PRNG Key (default=None)
            If given, a new PRNG Key will be generated at each iteration and
            be passed to `calc_loss_and_grad_from_params()` as the "randkey"
            kwarg
        const_randkey : bool (default=False)
            By default, randkey is regenerated at each gradient descent
            iteration. Remove this behavior by setting const_randkey=True
        thin : int, optional
            Return parameters for every `thin` iterations, by default 1. Set
            `thin=0` to only return final parameters
        progress : bool, optional
            Display tqdm progress bar, by default True
        comm : Comm, optional
            Communicator passed to the optimizer, by default `self.comm`

        Returns
        -------
        params : jnp.array
            The trial parameters at each iteration.
        losses : jnp.array
            The loss values at each iteration.
        """
        comm = self.comm if comm is None else comm
        guess = jnp.asarray(guess)
        if const_randkey:
            assert randkey is not None, "Must pass randkey if const_randkey"
            init_randkey = randkey
            # The optimizer receives randkey=None; the closure below
            # supplies the same key at every iteration instead
            randkey = None

            def loss_and_grad_fn(x, _vestigial_data_arg, **kw):
                return self.calc_loss_and_grad_from_params(
                    x, randkey=init_randkey, **kw)
        else:
            def loss_and_grad_fn(x, _vestigial_data_arg, **kw):
                return self.calc_loss_and_grad_from_params(x, **kw)

        params, losses = run_adam(
            loss_and_grad_fn, params=guess, data=None, nsteps=nsteps,
            param_bounds=param_bounds, learning_rate=learning_rate,
            randkey=randkey, thin=thin, progress=progress, comm=comm
        )
        return params, losses

    # NOTE: Never jit this method because it uses mpi4py
    def run_bfgs(self: Any, guess, maxsteps=100, param_bounds=None,
                 randkey=None, thin=1, progress=True, comm=None):
        """
        Run BFGS to descend the gradient and optimize the model parameters,
        given an initial guess. Stochasticity must be held fixed via a random
        key

        Parameters
        ----------
        guess : array-like
            The starting parameters.
        maxsteps : int (default=100)
            The number of steps to take.
        param_bounds : Sequence, optional
            Lower and upper bounds of each parameter of "shape" (ndim, 2).
            Pass `None` as the bound for each unbounded parameter, by default
            None
        randkey : int | PRNG Key (default=None)
            Since BFGS requires a deterministic function, this key will be
            passed to `calc_loss_and_grad_from_params()` as the "randkey"
            kwarg as a constant at every iteration
        thin : int, optional
            Return parameters for every `thin` iterations, by default 1. Set
            `thin=0` to only return final parameters
        progress : bool, optional
            Display tqdm progress bar, by default True
        comm : Comm, optional
            Communicator passed to the optimizer, by default `self.comm`

        Returns
        -------
        params : jnp.array
            The trial parameters at each iteration.
        losses : jnp.array
            The loss values at each iteration.
        result : OptimizeResult (contains the following attributes):
            message : str, describes reason of termination
            success : boolean, True if converged
            fun : float, minimum loss found
            x : array of parameters at minimum loss found
            jac : array of gradient of loss at minimum loss found
            nfev : int, number of function evaluations
            nit : int, number of gradient descent iterations
        """
        comm = self.comm if comm is None else comm
        return run_bfgs(
            self.calc_loss_and_grad_from_params, guess, maxsteps=maxsteps,
            param_bounds=param_bounds, randkey=randkey, thin=thin,
            progress=progress, comm=comm)

    def run_lhs_param_scan(self, xmins, xmaxs, n_dim, num_evaluations,
                           seed=None, randkey=None):
        """
        Compute sumstat and loss values over a Latin Hypercube sample

        Parameters
        ----------
        xmins : float | array-like
            Lower bound on each parameter
        xmaxs : float | array-like
            Upper bound on each parameter
        n_dim : int
            Number of parameters
        num_evaluations : int
            Number of Latin Hypercube samples to draw and evaluate
        seed : int (default=None)
            Seed to make LHD draws reproducible, randomized by default
        randkey : PRNGKey | int (default=None)
            Random key passed to each sumstat and loss evaluation

        Returns
        -------
        params : array-like
            Parameters (drawn in Latin Hypercube shape)
        sumstats : array-like
            Sumstats evaluated at each draw of parameters (the sumstats aux
            value is not included when `sumstats_func_has_aux` is True)
        losses : array-like
            Loss evaluated at each draw of parameters
        """
        params = util.latin_hypercube_sampler(xmins, xmaxs, n_dim,
                                              num_evaluations, seed=seed)
        rk = {} if randkey is None else {"randkey": randkey}
        results = [self.calc_sumstats_from_params(x, **rk) for x in params]
        if self.sumstats_func_has_aux:
            # Each result is a (sumstats, aux) pair; the loss needs both
            sumstats = [res[0] for res in results]
            losses = [self.calc_loss_from_sumstats(*res, **rk)
                      for res in results]
        else:
            sumstats = results
            losses = [self.calc_loss_from_sumstats(res, **rk)
                      for res in results]
        return params, np.array(sumstats), np.array(losses)

    def __post_init__(self):
        if self.comm is None:
            self.comm = COMM
        # Create auto-diff functions needed for gradient descent
        self._grad_loss_from_sumstats = jax.grad(
            self.calc_loss_from_sumstats, has_aux=self.loss_func_has_aux)

    # sumstats functions
    # NOTE: Never jit this method because it uses mpi4py (when total=True)
    def calc_sumstats_from_params(self, params, total=True, randkey=None):
        """Compute summary statistics at given parameters

        Parameters
        ----------
        params : array-like
            Model parameters
        total : bool (default=True)
            If true (default), sumstats will be summed over all MPI ranks
        randkey : PRNGKey | int (default=None)
            If set to a value other than None, the "randkey" kwarg will be
            passed to user-defined methods

        Returns
        -------
        array
            Summary statistics evaluated at given parameters
            (a `(sumstats, aux)` tuple if `sumstats_func_has_aux` is True)
        """
        kwargs = {} if randkey is None else {"randkey": randkey}
        result, aux = self.calc_partial_sumstats_from_params(
            params, **kwargs), None
        if self.sumstats_func_has_aux:
            result, aux = result
        if total:
            result = jnp.asarray(reduce_sum(result, comm=self.comm))
        result = (result, aux) if self.sumstats_func_has_aux else result
        return result

    # loss functions
    def calc_dloss_dsumstats(self, sumstats, sumstats_aux=None,
                             randkey=None):
        """Gradient of the loss with respect to (total) summary statistics

        `sumstats_aux` is forwarded only if `sumstats_func_has_aux` is True.
        Returns `(grad, aux)` instead of `grad` when `loss_func_has_aux`
        is True.
        """
        kwargs = {} if randkey is None else {"randkey": randkey}
        sumstats = jnp.asarray(sumstats)
        args = (sumstats, sumstats_aux) if self.sumstats_func_has_aux else (
            sumstats,)
        return self._grad_loss_from_sumstats(*args, **kwargs)

    # NOTE: Never jit this method because it uses mpi4py
    def calc_loss_from_params(self, params, randkey=None):
        """Calculate the loss evaluated at a given set of parameters

        Parameters
        ----------
        params : array-like
            Model parameters
        randkey : PRNGKey | int (default=None)
            If set to a value other than None, the "randkey" kwarg will be
            passed to user-defined methods

        Returns
        -------
        float
            The loss evaluated at the parameters given
        """
        kwargs = {} if randkey is None else {"randkey": randkey}
        sumstats = self.calc_sumstats_from_params(params, **kwargs)
        if not self.sumstats_func_has_aux:
            sumstats = (sumstats,)
        return self.calc_loss_from_sumstats(*sumstats, **kwargs)

    # NOTE: Never jit this method because it uses mpi4py
    def calc_dloss_dparams(self, params, randkey=None):
        """Calculate the gradient of the loss w.r.t. model parameters given

        Parameters
        ----------
        params : array-like
            Model parameters
        randkey : PRNGKey | int (default=None)
            If set to a value other than None, the "randkey" kwarg will be
            passed to user-defined methods

        Returns
        -------
        array
            Gradient of the loss with respect to each parameter
        """
        return self._vjp(params, randkey=randkey, include_loss=False)

    # NOTE: Never jit this method because it uses mpi4py
    def calc_loss_and_grad_from_params(self, params, randkey=None):
        """
        Calculate the loss and its gradient. This function returns the
        equivalent of `(calc_loss_from_params(x), calc_dloss_dparams(x))`
        but it is significantly cheaper than calling them separately

        Parameters
        ----------
        params : array-like
            Model parameters
        randkey : PRNGKey | int (default=None)
            If set to a value other than None, the "randkey" kwarg will be
            passed to user-defined methods

        Returns
        -------
        float
            The loss evaluated at the parameters given
        array
            Gradient of the loss with respect to each parameter
        """
        return self._vjp(params, randkey=randkey, include_loss=True)

    # NOTE: Never jit this method because it uses mpi4py
    def _vjp(self, params, randkey=None, include_loss=True):
        """Shared implementation of the loss gradient (and optionally loss)

        Applies the chain rule dL/dp = dL/ds . ds/dp, where the partial
        sumstats s are summed over `self.comm` before the loss is evaluated,
        and each rank's ds/dp contribution is summed afterwards.
        """
        kwargs = {} if randkey is None else {"randkey": randkey}
        params = jnp.asarray(params)

        def sumstats_func(params):
            return self.calc_partial_sumstats_from_params(params, **kwargs)

        # Calculate sumstats AND save VJP func to perform chain rule later
        vjp_results = jax.vjp(
            sumstats_func, params,
            has_aux=self.sumstats_func_has_aux)  # type: ignore
        sumstats, vjp_func = vjp_results[:2]
        sumstats = jnp.asarray(reduce_sum(sumstats, comm=self.comm))
        # vjp_results[2:] holds the sumstats aux value (if any)
        args = (sumstats, *vjp_results[2:])

        # Calculate dloss_dsumstats for chain rule. Should be inexpensive
        dloss_dsumstats = self.calc_dloss_dsumstats(*args, **kwargs)
        if self.loss_func_has_aux:
            # jax.grad(..., has_aux=True) returns (grad, aux)
            dloss_dsumstats = dloss_dsumstats[0]

        # Use VJP for the chain rule dL/dp[i] = sum(dL/ds[j] * ds[j]/dp[i])
        dloss_dparams = jnp.asarray(reduce_sum(
            vjp_func(dloss_dsumstats)[0], comm=self.comm))

        if include_loss:
            # Return (loss_and_aux, dloss_dparams)
            return self.calc_loss_from_sumstats(*args, **kwargs), dloss_dparams
        else:
            return dloss_dparams

    def __hash__(self):
        # comm is None when mpi4py is unavailable; hashing must still work
        comm_name = None if self.comm is None else self.comm.name
        return hash((comm_name, self.calc_loss_from_sumstats))

    def __eq__(self, other):
        # Identity semantics: a model is equal only to itself
        return isinstance(other, OnePointModel) and self is other
@dataclass
class OnePointGroup:
    """
    Allows different OnePointModels to simultaneously perform their
    calc_loss_and_grad_from_params method. The results are summed.

    Parameters
    ----------
    models : tuple[OnePointModel]
        Sequence of models, each providing a loss component to be summed.
        A single OnePointModel is wrapped into a 1-tuple.
    main_comm : Comm (default=COMM_WORLD)
        MPI communicator for the entire group (each model should be
        assigned its own sub-communicator)
    """
    models: Union[tuple[OnePointModel, ...], OnePointModel]
    main_comm: Any = None

    def __post_init__(self):
        if self.main_comm is None:
            self.main_comm = COMM
        if isinstance(self.models, OnePointModel):
            self.models = (self.models,)
        assert isinstance(self.models[0], OnePointModel)

    # NOTE: Never jit this method because it uses mpi4py
    def calc_loss_and_grad_from_params(self, params, randkey=None):
        """
        Sum the loss and gradient of every model across `main_comm`

        Parameters
        ----------
        params : array-like
            Model parameters (shared by all models)
        randkey : PRNGKey | int (default=None)
            If not None, passed to each model's
            `calc_loss_and_grad_from_params` as the "randkey" kwarg. This is
            what `run_adam` and `run_bfgs` use to forward a random key.

        Returns
        -------
        float
            Total loss summed over all models
        array
            Total gradient summed over all models
        """
        kwargs = {} if randkey is None else {"randkey": randkey}
        # Without MPI there is nothing to gather: this process holds it all
        gather = ((lambda x: [x]) if self.main_comm is None
                  else self.main_comm.allgather)
        loss, grad = [], []
        for model in self.models:
            res = model.calc_loss_and_grad_from_params(params, **kwargs)
            # Each model's result is already summed over its own comm, so
            # only that comm's rank 0 contributes; other ranks add zeros to
            # avoid double counting.
            model_rank = 0 if model.comm is None else model.comm.rank
            loss.append(res[0] * 0 if model_rank else res[0])
            grad.append(res[1] * 0 if model_rank else res[1])
        loss = jnp.concatenate(jnp.array(gather(loss)))
        grad = jnp.concatenate(jnp.array(gather(grad)))
        return loss.sum(), grad.sum(axis=0)

    # NOTE: Never jit this method because it uses mpi4py
    def run_simple_grad_descent(self, guess, nsteps=100, learning_rate=0.01,
                                thin=1, progress=True):
        """Fixed-learning-rate gradient descent on the summed group loss

        See `OnePointModel.run_simple_grad_descent` for details.
        """
        return OnePointModel.run_simple_grad_descent(
            self, guess, nsteps, learning_rate, thin=thin, progress=progress)

    # NOTE: Never jit this method because it uses mpi4py
    def run_bfgs(self, guess, maxsteps=100, param_bounds=None, randkey=None,
                 thin=1, progress=True):
        """BFGS on the summed group loss, communicating over `main_comm`

        See `OnePointModel.run_bfgs` for details.
        """
        return OnePointModel.run_bfgs(
            self, guess, maxsteps, param_bounds=param_bounds, randkey=randkey,
            thin=thin, progress=progress, comm=self.main_comm)

    # NOTE: Never jit this method because it uses mpi4py
    def run_adam(self, guess, nsteps=100, param_bounds=None,
                 learning_rate=0.01, randkey=None, const_randkey=False,
                 thin=1, progress=True):
        """Adam on the summed group loss, communicating over `main_comm`

        See `OnePointModel.run_adam` for details.
        """
        return OnePointModel.run_adam(
            self, guess, nsteps, param_bounds, learning_rate, randkey,
            const_randkey=const_randkey, thin=thin, progress=progress,
            comm=self.main_comm)

    def __hash__(self):
        if isinstance(self.models, OnePointModel):
            self.models = (self.models,)
        # main_comm is None when mpi4py is unavailable
        comm_name = None if self.main_comm is None else self.main_comm.name
        return hash((comm_name, self.models[0]))

    def __eq__(self, other):
        # Identity semantics: a group is equal only to itself
        return isinstance(other, OnePointGroup) and self is other