Source code for diffopt.multiswarm.pso_update

"""Implementation of PSO algorithm described in arXiv:1108.5600 & arXiv:1310.7034"""  # noqa
from time import time

import jax
import numpy as np
import tqdm.auto as tqdm
from jax import numpy as jnp
from jax import random as jran
from scipy.stats import qmc

from .mpi_utils import split_subcomms

try:
    from mpi4py.MPI import COMM_WORLD
except ImportError:
    COMM_WORLD = None

# Alternative weight values, currently disabled:
# INERTIAL_WEIGHT = (0.5 / np.log(2))
# ACC_CONST = (0.5 + np.log(2))
# Default fraction of a particle's previous velocity retained at each step
INERTIAL_WEIGHT = 1.0
# Default weight pulling a particle toward its own best-known position
COGNITIVE_WEIGHT = 0.21
# Default weight pulling a particle toward the swarm's best-known position
SOCIAL_WEIGHT = 0.07
# Default per-dimension speed limit, as a fraction of the box width (xmax - xmin)
VMAX_FRAC = 0.4


class ParticleSwarm:
    """Particle swarm optimizer whose particles are spread over MPI ranks.

    Construction draws the initial positions and velocities and decides
    which particles this rank is responsible for; `run_pso` then performs
    the actual optimization of a user-supplied loss function.
    """

    def __init__(self, nparticles, ndim, xlow, xhigh, seed=0,
                 inertial_weight=INERTIAL_WEIGHT,
                 cognitive_weight=COGNITIVE_WEIGHT,
                 social_weight=SOCIAL_WEIGHT,
                 vmax_frac=VMAX_FRAC,
                 ranks_per_particle=None, comm=None):
        """
        Initialize particles and MPI communicators to be used for PSO

        Parameters
        ----------
        nparticles : int
            Number of particles (~100+ recommended)
        ndim : int
            Dimensionality (i.e. number of model parameters to fit)
        xlow : float | Array[float]
            Lower bounds on each parameter
        xhigh : float | Array[float]
            Upper bounds on each parameter
        seed : int | PRNGKey, optional
            Seed for all pseudo-randomness, by default 0
        inertial_weight : float, optional
            Retain this fraction of the velocity from previous timestep,
            by default 1.0
        cognitive_weight : float, optional
            Weight pulling particles towards their personal best location
            ever found, by default 0.21
        social_weight : float, optional
            Weight pulling particles towards the global best location ever
            found, recommended ~1/3 of `cognitive_weight`, by default 0.07
        vmax_frac : float, optional
            Maximum velocity particles are allowed to travel, as a fraction
            of their box width per dimension, by default 0.4
        ranks_per_particle : int, optional
            Set this to manually control intra-particle parallelization,
            even if there are not enough ranks for
            nparticles * ranks_per_particle. By default (None),
            inter-particle parallelization is prioritized
        comm : MPI.Comm, optional
            MPI Communicator, by default COMM_WORLD
        """
        if comm is None:
            comm = COMM_WORLD
        randkey = init_randkey(seed)
        subcomm, particles_on_this_rank = get_subcomm(
            nparticles, ranks_per_particle, comm=comm,
            return_particles_on_this_rank=True)
        num_particles_on_this_rank = len(particles_on_this_rank)

        # One key for the initial conditions plus one key per particle.
        # All keys are derived on every rank; each rank then keeps only
        # the keys of the particles it owns.
        init_key, *particle_keys = jran.split(randkey, nparticles + 1)
        particle_keys = [particle_keys[i] for i in particles_on_this_rank]

        # Initial conditions are generated for the full swarm on every rank
        init_cond = get_lhs_initial_conditions(
            nparticles, ndim, xlo=xlow, xhi=xhigh, vmax_frac=vmax_frac,
            ran_key=init_key)
        xmin, xmax, x_init, v_init = init_cond

        self.nparticles = nparticles
        self.ndim = ndim
        self.xlow, self.xhigh = xlow, xhigh
        self.comm = comm
        self.particles_on_this_rank = particles_on_this_rank
        self.num_particles_on_this_rank = num_particles_on_this_rank
        self.particle_keys = particle_keys
        self.subcomm = subcomm
        self.xmin, self.xmax = xmin, xmax
        self.x_init, self.v_init = x_init, v_init
        self.inertial_weight = inertial_weight
        self.cognitive_weight = cognitive_weight
        self.social_weight = social_weight
        self.vmax_frac = vmax_frac

    def run_pso(self, lossfunc, nsteps=100, progress=True,
                keep_init_random_state=False):
        """
        Run particle swarm optimization (PSO)

        Parameters
        ----------
        lossfunc : callable
            The function we want to find the global minimum of. To be
            called with signature `lossfunc(x)` where x is an array of
            shape `(ndim,)`
        nsteps : int, optional
            Number of time step iterations, by default 100
        progress : bool, optional
            Display tqdm progress bar, by default True
        keep_init_random_state : bool, optional
            Set True to be able to rerun an identical run, or False
            (default) to continue a run by manually setting swarm.x_init
            and swarm.v_init

        Returns
        -------
        Results dictionary with the following keys:
        "swarm_x_history" : np.ndarray of shape (nsteps, nparticles, ndim)
            Position of all particles (trial params) at each time step
        "swarm_v_history": np.ndarray of shape (nsteps, nparticles, ndim)
            Velocity of all particles at each time step
        "swarm_loss_history": np.ndarray of shape (nsteps, nparticles)
            Loss of all particles at each time step
        "runtime": float
            Time in seconds, as measured on each rank, to perform PSO
        """
        if keep_init_random_state:
            # Advance a copy so the keys stored on the swarm stay untouched
            particle_keys = self.particle_keys.copy()
        else:
            # Advance the stored keys in place
            particle_keys = self.particle_keys

        nlocal = self.num_particles_on_this_rank
        x = [self.x_init[pr] for pr in self.particles_on_this_rank]
        v = [self.v_init[pr] for pr in self.particles_on_this_rank]
        loc_loss_best = [lossfunc(xi) for xi in x]
        loc_x_best = [np.copy(xi) for xi in x]
        swarm_x_best, swarm_loss_best = self._get_global_best(x, loc_loss_best)

        loc_x_history = [[] for _ in range(nlocal)]
        loc_v_history = [[] for _ in range(nlocal)]
        loc_loss_history = [[] for _ in range(nlocal)]

        start = time()

        # Only the rank with comm.rank == 0 gets a tqdm progress bar
        if self.comm.rank:
            steps = range(nsteps)
        else:
            steps = tqdm.trange(
                nsteps, desc="PSO Progress", disable=not progress)

        for _ in steps:
            istep_loss = [None] * nlocal
            for ip in range(nlocal):
                update_key = jran.split(particle_keys[ip], 1)[0]
                particle_keys[ip] = update_key
                x[ip], v[ip] = update_particle(
                    update_key, x[ip], v[ip], self.xmin, self.xmax,
                    loc_x_best[ip], swarm_x_best, self.inertial_weight,
                    self.cognitive_weight, self.social_weight, self.vmax_frac
                )
                istep_loss[ip] = lossfunc(x[ip])

            istep_x_best, istep_loss_best = self._get_global_best(
                x, istep_loss)

            # The swarm best is the same for every particle, so it only
            # needs to be updated once per step
            if istep_loss_best <= swarm_loss_best:
                swarm_loss_best = istep_loss_best
                swarm_x_best = istep_x_best

            for ip in range(nlocal):
                # Personal bests are tracked per particle: compare and
                # store element [ip] rather than the whole lists, so one
                # particle's improvement never overwrites another's memory
                if istep_loss[ip] <= loc_loss_best[ip]:
                    loc_loss_best[ip] = istep_loss[ip]
                    loc_x_best[ip] = x[ip]
                loc_x_history[ip].append(x[ip])
                loc_v_history[ip].append(v[ip])
                loc_loss_history[ip].append(istep_loss[ip])

        runtime = time() - start

        # Gathered arrays are particle-major; swap so time is the first axis
        swarm_x_history = self._allgather_particles(
            loc_x_history).swapaxes(0, 1)
        swarm_v_history = self._allgather_particles(
            loc_v_history).swapaxes(0, 1)
        swarm_loss_history = self._allgather_particles(
            loc_loss_history).swapaxes(0, 1)

        return {
            "swarm_x_history": swarm_x_history,
            "swarm_v_history": swarm_v_history,
            "swarm_loss_history": swarm_loss_history,
            "runtime": runtime
        }

    def _allgather_particles(self, per_particle):
        """Allgather a per-particle sequence and concatenate along axis 0.

        Ranks that are not the root of their subcommunicator contribute a
        zero-length array, so only the ROOT of each subcomm adds entries
        to the concatenated result.
        """
        if self.subcomm is not None and self.subcomm.rank > 0:
            per_particle = np.zeros(shape=(0, *np.shape(per_particle[0])))
        return np.concatenate(self.comm.allgather(per_particle), axis=0)

    def _get_global_best(self, x, loss):
        """Return (position, loss) of the lowest-loss particle in the swarm.

        Parameters
        ----------
        x : sequence of arrays
            Positions of the particles held by this rank
        loss : sequence
            Loss of each particle in `x`, in the same order

        Returns
        -------
        best_x : np.ndarray
            Position with the smallest loss among all gathered particles
        best_loss
            The corresponding loss value
        """
        all_x = self._allgather_particles(x)
        all_loss = self._allgather_particles(loss)
        best_particle = np.argmin(all_loss)
        best_x = all_x[best_particle, :]
        best_loss = all_loss[best_particle]
        return best_x, best_loss
def get_subcomm(nparticles, ranks_per_particle=None, comm=None,
                return_particles_on_this_rank=False):
    """
    Initialize MPI communicators to be used for PSO

    Parameters
    ----------
    nparticles : int
        Number of particles
    ranks_per_particle : int, optional
        Set this to manually control intra-particle parallelization, even
        if there are not enough ranks for nparticles * ranks_per_particle.
        By default (None), inter-particle parallelization is prioritized
    comm : MPI.Comm, optional
        MPI Communicator, by default COMM_WORLD
    return_particles_on_this_rank : bool, optional
        If true, return tuple (subcomm, particles_on_this_rank).
        By default, only subcomm is returned

    Returns
    -------
    subcomm : MPI.Comm
        This rank's subcommunicator, which can only talk to its "group".
        None when there are more particles than ranks and
        `ranks_per_particle` is not given
    particles_on_this_rank : list
        If `return_particles_on_this_rank=True` this list will be returned,
        specifying the indices of particles this group is responsible for

    Raises
    ------
    ValueError
        If no communicator is given and mpi4py is not available
    AssertionError
        If `comm.size` is not a multiple of `ranks_per_particle`
    """
    if comm is None:
        comm = COMM_WORLD
    if comm is None:
        raise ValueError("MPI communicator is not available. "
                         "Please install mpi4py.")
    rank, nranks = comm.Get_rank(), comm.Get_size()

    if ranks_per_particle is not None:
        # Manual control of intra-particle vs inter-particle
        # parallelization: ranks are split into equally sized groups and
        # the particles are divided among those groups.
        num_groups = comm.size / ranks_per_particle
        if num_groups % 1:
            # Raised explicitly (rather than via `assert`) so the check is
            # not stripped under `python -O`; the exception type is kept
            # for backward compatibility.
            raise AssertionError(
                "comm.size must be a multiple of ranks_per_particle")
        num_groups = int(num_groups)
        subcomm, _, group_rank = split_subcomms(num_groups, comm=comm)
        particles_on_this_rank = list(np.array_split(
            np.arange(nparticles), num_groups)[group_rank])
    elif nparticles > nranks:
        # More particles than ranks: every rank works alone on its own
        # contiguous chunk of particles, so no subcommunicator is needed
        particles_on_this_rank = list(np.array_split(
            np.arange(nparticles), nranks)[rank])
        subcomm = None
    else:
        # At least one rank per particle: one group of ranks per particle
        subcomm, _, particle_index = split_subcomms(nparticles, comm=comm)
        particles_on_this_rank = [particle_index]

    if return_particles_on_this_rank:
        return subcomm, particles_on_this_rank
    return subcomm
def get_best_loss_and_params(loss_history, params_history):
    """
    Pick the single best evaluation out of the histories from run_pso()

    Parameters
    ----------
    loss_history : Array[float] of shape (nsteps, nparticles)
        Loss of all particles at each time, given by "swarm_loss_history"
    params_history : Array[float] of shape (nsteps, nparticles, ndim)
        Position of all particles at each time, given by "swarm_x_history"

    Returns
    -------
    float
        Minimum loss value
    np.ndarray[float]
        Parameters that produced the minimum loss
    """
    # Flatten (step, particle) into one axis so a single argmin covers
    # every evaluation; the parameters are flattened the same way
    flat_loss = np.ravel(loss_history)
    flat_params = np.reshape(params_history, (flat_loss.size, -1))
    winner = np.argmin(flat_loss)
    return flat_loss[winner], flat_params[winner, :]
def update_particle(
    ran_key, x, v, xmin, xmax, b_loc, b_swarm,
    w=INERTIAL_WEIGHT, acc_loc=COGNITIVE_WEIGHT, acc_swarm=SOCIAL_WEIGHT,
    vmax_frac=VMAX_FRAC
):
    """Advance one particle by a single time step

    The particle first moves by its current velocity, is reflected off the
    box boundaries if needed, and then receives a new velocity computed at
    its new position.

    Parameters
    ----------
    ran_key : jax.random.PRNGKey
        JAX random key used for the stochastic velocity update
    x : ndarray of shape (n_params, )
        Current position of particle
    v : ndarray of shape (n_params, )
        Current velocity of particle
    xmin : ndarray of shape (n_params, )
        Minimum position of particle
    xmax : ndarray of shape (n_params, )
        Maximum position of particle
    b_loc : ndarray of shape (n_params, )
        best point in history of particle
    b_swarm : ndarray of shape (n_params, )
        best point in history of swarm
    w : float, optional
        inertial weight, by default INERTIAL_WEIGHT
    acc_loc : float, optional
        local acceleration, by default COGNITIVE_WEIGHT
    acc_swarm : float, optional
        swarm acceleration, by default SOCIAL_WEIGHT
    vmax_frac : float, optional
        Speed limit per dimension as a fraction of (xmax - xmin),
        by default VMAX_FRAC

    Returns
    -------
    xnew : ndarray of shape (n_params, )
        New position of particle
    vnew : ndarray of shape (n_params, )
        New velocity of particle
    """
    xnew = x + v
    xnew, v = _impose_reflecting_boundary_condition(xnew, v, xmin, xmax)
    vnew = mc_update_velocity(
        ran_key, xnew, v, xmin, xmax, b_loc, b_swarm,
        w, acc_loc, acc_swarm, vmax_frac
    )
    return xnew, vnew


def mc_update_velocity(
    ran_key, x, v, xmin, xmax, b_loc, b_swarm,
    w=INERTIAL_WEIGHT, acc_loc=COGNITIVE_WEIGHT, acc_swarm=SOCIAL_WEIGHT,
    vmax_frac=VMAX_FRAC
):
    """Update the particle velocity

    Parameters
    ----------
    ran_key : jax.random.PRNGKey
        JAX random seed used to generate random speeds
    x : ndarray of shape (n_params, )
        Current position of particle
    v : ndarray of shape (n_params, )
        Current velocity of particle
    xmin : ndarray of shape (n_params, )
        Minimum position of particle
    xmax : ndarray of shape (n_params, )
        Maximum position of particle
    b_loc : ndarray of shape (n_params, )
        best point in history of particle
    b_swarm : ndarray of shape (n_params, )
        best point in history of swarm
    w : float, optional
        inertial weight
        Default is INERTIAL_WEIGHT defined at top of module
    acc_loc : float, optional
        local acceleration
        Default is COGNITIVE_WEIGHT defined at top of module
    acc_swarm : float, optional
        swarm acceleration
        Default is SOCIAL_WEIGHT defined at top of module
    vmax_frac : float, optional
        Speed limit per dimension as a fraction of (xmax - xmin)
        Default is VMAX_FRAC defined at top of module

    Returns
    -------
    vnew : ndarray of shape (n_params, )
        New velocity of particle
    """
    # One random scalar per attraction term, shared by all dimensions
    u_loc, u_swarm = jran.uniform(ran_key, shape=(2,))
    return _update_velocity_kern(
        x, v, xmin, xmax, b_loc, b_swarm, w, acc_loc, acc_swarm, vmax_frac,
        u_loc, u_swarm)


def _update_velocity_kern(
    x, v, xmin, xmax, b_loc, b_swarm, w, acc_loc, acc_swarm, vmax_frac,
    u_loc, u_swarm
):
    """Deterministic part of the velocity update, given the random draws

    The new velocity is the sum of an inertial term, a pull toward the
    particle's own best position and a pull toward the swarm's best
    position, clipped per dimension to the maximum allowed speed.
    """
    inertia = w * v
    cognitive = u_loc * acc_loc * (b_loc - x)
    social = u_swarm * acc_swarm * (b_swarm - x)
    v = inertia + cognitive + social
    vmax = _get_vmax(xmin, xmax, vmax_frac)
    return _get_clipped_velocity(v, vmax)


def _get_vmax(xmin, xmax, vmax_frac=VMAX_FRAC):
    """Return the per-dimension speed limit, `vmax_frac` of the box width"""
    return vmax_frac * (xmax - xmin)


def _get_clipped_velocity(v, vmax):
    """Clip each velocity component independently to [-vmax, vmax]"""
    v = np.where(v > vmax, vmax, v)
    v = np.where(v < -vmax, -vmax, v)
    return v


def _get_v_init(numpart, ran_key, xmin, xmax, vmax_frac=VMAX_FRAC):
    """Draw initial velocities of shape (numpart, n_dim)

    Each component is a uniform random fraction of that dimension's
    maximum velocity.
    """
    n_dim = xmin.size
    vmax = _get_vmax(xmin, xmax, vmax_frac)
    # NOTE(review): jran.uniform draws from [0, 1) by default, so initial
    # velocities are never negative -- confirm this is intended
    u_init = jran.uniform(ran_key, shape=(numpart, n_dim))
    return np.array(u_init * vmax)


def _impose_reflecting_boundary_condition(x, v, xmin, xmax):
    """Keep a particle inside the box, bouncing it off the walls

    Any coordinate outside [xmin, xmax] is placed on the boundary it
    crossed and the matching velocity component changes sign.

    Returns
    -------
    x, v : ndarray
        Corrected position and velocity
    """
    msk_lo = x < xmin
    msk_hi = x > xmax
    x = np.where(msk_lo, xmin, x)
    x = np.where(msk_hi, xmax, x)
    v = np.where(msk_lo | msk_hi, -v, v)
    return x, v


def get_lhs_initial_conditions(numpart, ndim, xlo=0, xhi=1, random_cd=True,
                               vmax_frac=VMAX_FRAC, ran_key=None):
    """Generate initial positions (Latin hypercube) and velocities

    Parameters
    ----------
    numpart : int
        Number of particles
    ndim : int
        Number of dimensions (model parameters)
    xlo : float | array of shape (ndim, ), optional
        Lower bound(s) of the box, by default 0
    xhi : float | array of shape (ndim, ), optional
        Upper bound(s) of the box, by default 1
    random_cd : bool, optional
        If True (default), use the "random-cd" optimization of scipy's
        LatinHypercube sampler; otherwise no optimization is applied
    vmax_frac : float, optional
        Speed limit per dimension as a fraction of (xmax - xmin),
        by default VMAX_FRAC
    ran_key : jax.random.PRNGKey, optional
        Random key; a fixed key is used when None

    Returns
    -------
    xmin, xmax : ndarray of shape (ndim, )
        Lower and upper bounds broadcast to one value per dimension
    x_init : ndarray of shape (numpart, ndim)
        Initial positions
    v_init : ndarray of shape (numpart, ndim)
        Initial velocities
    """
    opt = "random-cd" if random_cd else None
    if ran_key is None:
        ran_key = jran.PRNGKey(987654321)
    xmin = np.zeros(ndim) + xlo
    xmax = np.zeros(ndim) + xhi
    x_init_key, v_init_key = jran.split(ran_key, 2)

    # scipy's sampler needs an integer seed, so derive one from the JAX key
    x_seed = int(jran.randint(
        x_init_key, (), 0, 1000000000, dtype=np.uint32))
    sampler = qmc.LatinHypercube(ndim, optimization=opt, rng=x_seed)
    x_init = sampler.random(numpart)
    x_init = qmc.scale(x_init, xmin, xmax)
    v_init = _get_v_init(numpart, v_init_key, xmin, xmax, vmax_frac)
    return xmin, xmax, x_init, v_init


def init_randkey(randkey) -> jax.Array:
    """Check that randkey is a PRNG key or create one from an int

    Parameters
    ----------
    randkey : int | PRNG key
        An integer seed, or an array whose dtype is a JAX PRNG key dtype

    Returns
    -------
    jax.Array
        A new key built from the int, or `randkey` itself if it already
        is a key

    Raises
    ------
    AssertionError
        If `randkey` is neither an int nor a PRNG key
    """
    if isinstance(randkey, int):
        return jran.key(randkey)

    is_prng_key = hasattr(randkey, "dtype") and jnp.issubdtype(
        randkey.dtype, jax.dtypes.prng_key)
    if not is_prng_key:
        # Raised explicitly (rather than via `assert`) so the check is not
        # stripped under `python -O`; the exception type is kept for
        # backward compatibility.
        raise AssertionError(
            f"Invalid {type(randkey)=}: Must be int or PRNG Key")
    return randkey