Integration with multigrad
This notebook will show an example and discuss some of the complexity arising from performing kdescent in parallel with the aid of the multigrad package. We will be following an identical example to that found in the Advanced Usage section of the tutorial. All procedural differences will be flagged with a # NOTE: ... comment.
In the following script, the generate_model() function will do the same thing as before, except it will divide the sample size by the number of MPI ranks available, so that a fully-sized sample will be generated across all the ranks (and we will split the randkey between the ranks so the subsamples are not identical). Let’s save this script as kdescent-multigrad-integration.py:
"""
kdescent-multigrad-integration.py
"""
import functools
from dataclasses import dataclass
import numpy as np
import jax.numpy as jnp
import jax.random
import matplotlib.pyplot as plt
import seaborn as sns
from mpi4py import MPI
from diffopt import kdescent
from diffopt import multigrad
# MPI communicator shared by every rank running this script
comm = MPI.COMM_WORLD
# Total number of model draws per evaluation; generate_model() divides this
# by comm.size so that each rank only draws its own share
model_nsample = 20_000
# Total number of draws used to build the training data (passed to
# generate_model() as nsample together with undersample=True)
data_nsample = 10_000 # same volume, but undersampled at logM* < 10.5
# Generate data weighted from two mass-dependent multivariate normals
@functools.partial(jax.jit, static_argnames=["undersample", "nsample"])
def generate_model(params, randkey, undersample=False, nsample=model_nsample):
    """Draw this MPI rank's share of a weighted (r-z, g-r, logM*) sample.

    Parameters
    ----------
    params : array of 20 values
        ``[0:2]`` lower log-mass bound and width of the log-mass range,
        ``[2:11]`` population parameters at the lower mass bound and
        ``[11:20]`` the same nine parameters at the upper mass bound
        (SF mean (2), variances (2), correlation, quenched fraction,
        quenched mean offset from the SF mean (2), quenched scale).
    randkey : JAX PRNG key
        Split ``comm.size`` ways; each rank keeps the entry at ``comm.rank``.
    undersample : bool, static
        Selects the triangular distribution used to draw masses:
        ``(0, 0.5, 1)`` when True, ``(0, 0, 1)`` otherwise.
    nsample : int, static
        Sample size summed over all ranks; each rank draws
        ``nsample // comm.size`` masses.

    Returns
    -------
    data : array, shape (2 * (nsample // comm.size), 3)
        Columns are (r-z, g-r, logM*); the star-forming rows come first,
        followed by the quenched rows drawn at the same masses.
    weights : array, shape (2 * (nsample // comm.size),)
        ``1 - qfrac`` for the star-forming rows and ``qfrac`` for the
        quenched rows.
    """
    # NOTE: Divide nsample and split randkey across MPI ranks:
    nsample = nsample // comm.size
    randkey = jax.random.split(randkey, comm.size)[comm.rank]

    def covariance(var_rz, var_gr, corr):
        # Off-diagonal term is the correlation scaled by sqrt(var1 * var2)
        offdiag = corr * jnp.sqrt(var_rz * var_gr)
        return jnp.array([[var_rz, offdiag],
                          [offdiag, var_gr]])

    # Parse all 20 parameters
    # =======================
    # Lower and upper bounds on log stellar mass (the second parameter is
    # a width; 0.001 keeps the upper bound above the lower one)
    logm_lo = params[0]
    logm_hi = params[1] + (logm_lo + 0.001)

    # Distribution parameters at lower mass bound
    sf_mean_lo = params[2:4]
    cov_lo = covariance(params[4], params[5], params[6])
    qfrac_lo = params[7]
    q_mean_lo = sf_mean_lo + params[8:10]
    qscale_lo = params[10]

    # Distribution parameters at upper mass bound
    sf_mean_hi = params[11:13]
    cov_hi = covariance(params[13], params[14], params[15])
    qfrac_hi = params[16]
    q_mean_hi = sf_mean_hi + params[17:19]
    qscale_hi = params[19]

    # Generate distribution from parameters
    # =====================================
    mass_key, color_key = jax.random.split(randkey, num=2)
    triangle_vals = (0, 0.5, 1) if undersample else (0, 0, 1)
    unit_logm = jax.random.triangular(
        mass_key, *triangle_vals, shape=(nsample,))
    dlogm = logm_hi - logm_lo
    logm = logm_lo + unit_logm * dlogm

    # Every parameter varies linearly with log mass between the two bounds
    depth = logm - logm_lo

    def interpolate(at_lo, at_hi, offset):
        slope = (at_hi - at_lo) / dlogm
        return at_lo + slope * offset

    mean_sf = interpolate(sf_mean_lo, sf_mean_hi, depth[:, None])
    cov_sf = interpolate(cov_lo, cov_hi, depth[:, None, None])
    mean_q = interpolate(q_mean_lo, q_mean_hi, depth[:, None])
    qscale = interpolate(qscale_lo, qscale_hi, depth)
    cov_q = cov_sf * qscale[:, None, None] ** 2
    qfrac = interpolate(qfrac_lo, qfrac_hi, depth)

    # Generate colors from two separate multivariate normals
    # NOTE(review): both draws are made with the same key, so the two
    # populations are presumably not independent draws - confirm intended
    colors_sf = jax.random.multivariate_normal(color_key, mean_sf, cov_sf)
    colors_q = jax.random.multivariate_normal(color_key, mean_q, cov_q)

    # Concatenate the quenched + star-forming values and assign weights
    rows_sf = jnp.array([colors_sf[:, 0], colors_sf[:, 1], logm]).T
    rows_q = jnp.array([colors_q[:, 0], colors_q[:, 1], logm]).T
    data = jnp.concatenate([rows_sf, rows_q])
    weights = jnp.concatenate([1 - qfrac, qfrac])
    return data, weights
# Define "true" parameters to generate training data
# The values below are packed into a 20-element vector in the order that
# generate_model() parses it: [logmmin, logmrange], then nine values at the
# lower mass bound and the same nine at the upper mass bound:
# mean (2), variances (2), correlation, quenched fraction,
# quenched mean offset from the star-forming mean (2), quenched scale
truth_logmmin = 9.0
truth_logmrange = 3.0
truth_mean_mmin = jnp.array([1.4, 1.1])
truth_var_mmin = jnp.array([0.7, 0.4])
truth_corr_mmin = 0.3
truth_qfrac_mmin = 0.2
truth_qmean_mmin = jnp.array([-0.1, 1.6])
truth_qscale_mmin = 0.3
truth_mean_mmax = jnp.array([2.0, 1.6])
truth_var_mmax = jnp.array([0.5, 0.5])
truth_corr_mmax = 0.75
truth_qfrac_mmax = 0.95
truth_qmean_mmax = jnp.array([-0.6, 1.2])
truth_qscale_mmax = 1.1
# [lower, upper] limits for the parameters that are constrained in the fit
bounds_truth_logmrange = [0.001, jnp.inf]
bounds_var = ([0.001, jnp.inf], [0.001, jnp.inf])
bounds_corr = [-0.999, 0.999]
bounds_qfrac = [0.0, 1.0]
bounds_qmean_gr = [0.001, jnp.inf]
bounds_qscale = [0.001, jnp.inf]
# Pack the truth values in the order expected by generate_model()
truth = jnp.array([
truth_logmmin, truth_logmrange,
*truth_mean_mmin, *truth_var_mmin, truth_corr_mmin, truth_qfrac_mmin,
*truth_qmean_mmin, truth_qscale_mmin,
*truth_mean_mmax, *truth_var_mmax, truth_corr_mmax, truth_qfrac_mmax,
*truth_qmean_mmax, truth_qscale_mmax
])
# Starting point for the fit: two mass parameters followed by the same nine
# starting values repeated for the lower and upper mass bounds
guess = jnp.array([
9.25, 2.5, *[0.0, 0.0, 1.0, 1.0, 0.0, 0.5, 0.0, 1.0, 1.0]*2
])
# One entry per parameter, in the same order as `truth` and `guess`
# (None presumably means "unbounded" to run_adam - TODO confirm)
bounds = [
None, bounds_truth_logmrange,
*[None, None, *bounds_var, bounds_corr, bounds_qfrac,
None, bounds_qmean_gr, bounds_qscale]*2
]
# Generate training data from the truth parameters we just defined
truth_randkey = jax.random.key(43)
# Each rank only draws its own share of the data_nsample training points here
training_x_weighted, training_w = generate_model(
truth, truth_randkey, undersample=True, nsample=data_nsample)
# NOTE: Every rank must be aware of the FULL training data, so we must gather:
training_x_weighted = jnp.concatenate(comm.allgather(training_x_weighted))
training_w = jnp.concatenate(comm.allgather(training_w))
# KDescent allows weighted training data, but to make this more realistic,
# let's use weighted sampling instead
# Keep a row whenever a uniform draw falls below its weight. The key is the
# same on every rank, so every rank selects the same rows.
# NOTE(review): this key is derived from truth_randkey, which was already
# passed to generate_model() above - confirm the reuse is acceptable
training_selection = jax.random.uniform(
jax.random.split(truth_randkey)[0], (len(training_w),)) < training_w
training_x = training_x_weighted[training_selection]
# Define the stellar-mass bins shared by the plots and the loss function
lowmass_cut = [9.0, 9.5]
midmass_cut = [10.25, 10.75]
highmass_cut = [11.5, 12.0]


def _in_mass_bin(logm, cut):
    """Return a boolean mask of the entries strictly inside `cut`."""
    lower, upper = cut
    return (lower < logm) & (logm < upper)


# Masks (and masked weights) over the full weighted training sample
_weighted_logm = training_x_weighted[:, 2]
is_lowmass = _in_mass_bin(_weighted_logm, lowmass_cut)
is_midmass = _in_mass_bin(_weighted_logm, midmass_cut)
is_highmass = _in_mass_bin(_weighted_logm, highmass_cut)
training_w_lowmass = training_w * is_lowmass
training_w_midmass = training_w * is_midmass
training_w_highmass = training_w * is_highmass

# Masks over the unweighted (selected) training sample
_selected_logm = training_x[:, 2]
is_noweight_lowmass = _in_mass_bin(_selected_logm, lowmass_cut)
is_noweight_midmass = _in_mass_bin(_selected_logm, midmass_cut)
is_noweight_highmass = _in_mass_bin(_selected_logm, highmass_cut)
def generate_model_into_mass_bins(params, randkey):
    """Generate the full model sample and split it into the three mass bins.

    This performs MPI collectives (allgather), so it must be called on
    every rank with matching arguments.

    Returns an 8-tuple: the gathered data followed by its low-, mid- and
    high-mass subsets, then the gathered weights followed by the weights
    of the same three subsets.
    """
    # NOTE: Gather data from each rank (since this is for plotting only)
    local_x, local_w = generate_model(params, randkey=randkey)
    model_x = jnp.concatenate(comm.allgather(local_x))
    model_w = jnp.concatenate(comm.allgather(local_w))

    # Hard (non-differentiable) binning in log stellar mass
    logm = model_x[:, 2]
    in_bin = [(lower < logm) & (logm < upper)
              for lower, upper in (lowmass_cut, midmass_cut, highmass_cut)]
    binned_x = [model_x[mask] for mask in in_bin]
    binned_w = [model_w[mask] for mask in in_bin]
    return (model_x, *binned_x, model_w, *binned_w)
def _draw_color_panel(ax, model_xy, model_w, training_mask, mass_cut,
                      draw_training):
    """Draw one (r-z, g-r) panel and return its removable (hexbin, label)."""
    hexbin = ax.hexbin(*model_xy.T, mincnt=1,
                       C=model_w, reduce_C_function=np.sum,
                       norm=plt.matplotlib.colors.LogNorm())
    if draw_training:
        # Training contours are static, so they are only drawn once
        sns.kdeplot(
            {"$r - z$": training_x_weighted[training_mask][:, 0],
             "$g - r$": training_x_weighted[training_mask][:, 1]},
            weights=training_w[training_mask],
            x="$r - z$", y="$g - r$", color="red", levels=7, ax=ax)
    ax.set_xlabel("$r - z$", fontsize=14)
    ax.set_ylabel("$g - r$", fontsize=14)
    label = ax.text(
        0.02, 0.02, f"${mass_cut[0]} < \\log M_\\ast < {mass_cut[1]}$",
        fontsize=14, transform=ax.transAxes)
    return hexbin, label


def make_sumstat_plot(params, txt="", fig=None, prev_layers=None):
    """Plot the model generated from `params` against the training data.

    The top-left panel shows weighted log-mass histograms (training data in
    red, full model in grey, the three mass bins in C0). The other three
    panels show the model's (r-z, g-r) hexbin for the low, mid and high
    mass bins with the training data's KDE contours in red.

    This calls generate_model_into_mass_bins(), which performs MPI
    collectives, so other ranks must make the matching calls.

    Parameters
    ----------
    params : array
        Model parameters passed to generate_model_into_mass_bins().
    txt : str
        Label written in blue in the histogram panel.
    fig : figure or subfigure, optional
        Target to draw into; a new 10x9 figure is created when None. Its
        first four axes are reused if it already has at least four.
    prev_layers : list of artists, optional
        Return value of a previous call on the same figure. These artists
        are removed before redrawing, and the static training-data layers
        are not drawn again.

    Returns
    -------
    list
        The model-dependent artists (hexbins, model histograms and text
        labels), suitable for passing back as `prev_layers`.
    """
    (modall, modlow, modmid, modhigh,
     w_all, w_low, w_mid, w_high) = generate_model_into_mass_bins(
        params, jax.random.key(13))
    first_draw = prev_layers is None
    if not first_draw:
        for layer in prev_layers:
            layer.remove()

    fig = plt.figure(figsize=(10, 9)) if fig is None else fig
    ax = fig.add_subplot(221) if len(fig.axes) < 4 else fig.axes[0]
    if first_draw:
        # The training histogram is not among the returned layers, so it
        # is never removed; redrawing it on every refresh would stack
        # duplicate artists. Draw it once, like the training contours.
        ax.hist(training_x_weighted[:, 2], bins=50, color="red",
                weights=training_w)
    _, bins, hist1 = ax.hist(
        modall[:, 2], color="grey", bins=50, alpha=0.9, weights=w_all)
    # Reuse the full model's bin edges so the per-bin histograms line up
    hist2 = ax.hist(modlow[:, 2], bins=list(bins), color="C0",
                    alpha=0.9, weights=w_low)[-1]
    hist3 = ax.hist(modmid[:, 2], bins=list(bins), color="C0",
                    alpha=0.9, weights=w_mid)[-1]
    hist4 = ax.hist(modhigh[:, 2], bins=list(bins), color="C0",
                    alpha=0.9, weights=w_high)[-1]
    ax.set_xlabel("$\\log M_\\ast$", fontsize=14)
    text1 = ax.text(
        0.98, 0.98, "Training data", color="red", va="top", ha="right",
        fontsize=14, transform=ax.transAxes)
    text2 = ax.text(
        0.98, 0.91, txt, color="blue", va="top", ha="right",
        fontsize=14, transform=ax.transAxes)

    ax = fig.add_subplot(222) if len(fig.axes) < 4 else fig.axes[1]
    hex1, text3 = _draw_color_panel(
        ax, modlow[:, :2], w_low, is_lowmass, lowmass_cut, first_draw)

    # Newly created color panels share axes with the previous color panel
    ax = fig.add_subplot(223, sharex=ax, sharey=ax) if len(
        fig.axes) < 4 else fig.axes[2]
    hex2, text4 = _draw_color_panel(
        ax, modmid[:, :2], w_mid, is_midmass, midmass_cut, first_draw)

    ax = fig.add_subplot(224, sharex=ax, sharey=ax) if len(
        fig.axes) < 4 else fig.axes[3]
    hex3, text5 = _draw_color_panel(
        ax, modhigh[:, :2], w_high, is_highmass, highmass_cut, first_draw)

    ax.set_xlim(-4, 7)
    ax.set_ylim(-4, 7)
    return [hex1, hex2, hex3, hist1, hist2, hist3, hist4,
            text1, text2, text3, text4, text5]
# Define loss function comparing PDF(g-r, r-z | M*) and its Fourier pair
# NOTE: Since we plan on jitting, we can't pass comm=comm to our KCalcs.
# Instead, we will be careful to call the compare_*_counts() methods with
# identical randkeys on each MPI rank!
# The same pretraining options are used for all three mass bins
pretrain_options = dict(bandwidth_factor=0.3, fourier_range_factor=3.0)
# Pretrain on the (r-z, g-r) columns of the selected training data per bin
ktrain_lowmass = kdescent.KPretrainer.from_training_data(
    training_x[is_noweight_lowmass, :2], **pretrain_options)
ktrain_midmass = kdescent.KPretrainer.from_training_data(
    training_x[is_noweight_midmass, :2], **pretrain_options)
ktrain_highmass = kdescent.KPretrainer.from_training_data(
    training_x[is_noweight_highmass, :2], **pretrain_options)
kcalc_lowmass = kdescent.KCalc(ktrain_lowmass)
kcalc_midmass = kdescent.KCalc(ktrain_midmass)
kcalc_highmass = kdescent.KCalc(ktrain_highmass)
# Differentiable alternative to hard binning in the loss function:
@jax.jit
def soft_tophat(x, low, high, squish=25.0):
    """Smooth window that is close to 1 for `low < x < high`, else close to 0.

    The window is the product of a rising sigmoid centred on `low` and a
    falling sigmoid centred on `high`. Both edges have a transition width
    of `(high - low) / squish`, so a larger `squish` gives sharper edges.
    """
    edge_width = (high - low) / squish
    rising_edge = jax.nn.sigmoid((x - low) / edge_width)
    falling_edge = jax.nn.sigmoid((high - x) / edge_width)
    return rising_edge * falling_edge
# NOTE: For multigrad, we have to explicitly define sumstats_from_params()
# and loss_from_sumstats() to replace the old lossfunc()
@jax.jit
def sumstats_from_params(params, randkey):
    """Compute this rank's raw model counts and the matching truth counts.

    Returns
    -------
    sumstats : array
        For each mass bin (low, mid, high): the model KDE counts, the model
        Fourier counts, then the summed soft-binned model weight.
    sumstats_aux : array
        For each mass bin (low, mid, high): the truth KDE counts followed
        by the truth Fourier counts.
    """
    model_key, *kcalc_keys = jax.random.split(randkey, 7)
    model_x, model_w = generate_model(params, randkey=model_key)
    colors, logm = model_x[:, :2], model_x[:, 2]

    mass_bins = (
        (kcalc_lowmass, lowmass_cut),
        (kcalc_midmass, midmass_cut),
        (kcalc_highmass, highmass_cut),
    )
    nbins = len(mass_bins)

    model_stats, truth_stats = [], []
    for n, (kcalc, cut) in enumerate(mass_bins):
        # Soft membership of each model point in this mass bin
        bin_weight = soft_tophat(logm, *cut) * model_w
        # The first three keys drive the KDE comparisons and the last
        # three drive the Fourier comparisons
        model_counts, truth_counts = kcalc.compare_kde_counts(
            kcalc_keys[n], colors, bin_weight)
        model_fcounts, truth_fcounts = kcalc.compare_fourier_counts(
            kcalc_keys[nbins + n], colors, bin_weight)
        model_stats += [*model_counts, *model_fcounts, bin_weight.sum()]
        truth_stats += [*truth_counts, *truth_fcounts]

    # NOTE: "Sumstats" are raw counts so that they can be summed across ranks
    sumstats = jnp.array(model_stats)
    # NOTE: To prevent *truth* counts being summed, pass them as "auxiliary"
    sumstats_aux = jnp.array(truth_stats)
    return sumstats, sumstats_aux
@jax.jit
def loss_from_sumstats(sumstats, sumstats_aux):
    """Mean squared error between model and truth summary statistics.

    `sumstats` and `sumstats_aux` must follow the layout produced by
    sumstats_from_params(): per mass bin (low, mid, high), KDE counts then
    Fourier counts, with `sumstats` also carrying the bin's total model
    weight after each pair.
    """
    tiny = 1e-10  # added to every normalisation before dividing
    mass_bins = (
        (kcalc_lowmass, training_w_lowmass.sum()),
        (kcalc_midmass, training_w_midmass.sum()),
        (kcalc_highmass, training_w_highmass.sum()),
    )

    pdf_sqerrs, ecf_sqerrs = [], []
    model_at = truth_at = 0  # read positions in sumstats / sumstats_aux
    for kcalc, truth_norm in mass_bins:
        n_kde = kcalc.num_eval_kernels
        n_ecf = kcalc.num_eval_fourier_positions

        # NOTE: Unpack sumstats (raw model counts per kernel + total
        # weight sum)
        model_counts = sumstats[model_at:model_at + n_kde]
        model_at += n_kde
        model_fcounts = sumstats[model_at:model_at + n_ecf]
        model_at += n_ecf
        model_norm = sumstats[model_at:model_at + 1][0]
        model_at += 1

        # NOTE: Unpack sumstats_aux (raw truth counts per kernel)
        truth_counts = sumstats_aux[truth_at:truth_at + n_kde]
        truth_at += n_kde
        truth_fcounts = sumstats_aux[truth_at:truth_at + n_ecf]
        truth_at += n_ecf

        # Convert counts to conditional prob:
        # P(krnl | M*) = N(krnl & M*) / N(M*)
        model_condprob = model_counts / (model_norm + tiny)
        truth_condprob = truth_counts / (truth_norm + tiny)
        pdf_sqerrs.append((model_condprob - truth_condprob)**2)

        # Convert Fourier counts to "conditional" ECF analogously
        model_ecf = model_fcounts / (model_norm + tiny)
        truth_ecf = truth_fcounts / (truth_norm + tiny)
        ecf_sqerrs.append((model_ecf - truth_ecf)**2)

    # One constraint on number density at the highest stellar mass bin
    # (after the loop, model_norm and truth_norm belong to that last bin)
    volume = 100.0
    model_massfunc = jnp.array([model_norm,]) / volume
    truth_massfunc = jnp.array([truth_norm,]) / volume

    # Must abs() the Fourier residuals so that the loss is real
    # NOTE: We even have to abs() the PDF residuals due to multigrad
    # combining all sumstats into a single complex-typed array
    sqerrs = jnp.abs(jnp.concatenate([
        *pdf_sqerrs,
        *ecf_sqerrs,
        (model_massfunc - truth_massfunc)**2,
    ]))
    return jnp.mean(sqerrs)
# NOTE: Define multigrad class using the sumstats + loss funcs we just defined
@dataclass
class MyModel(multigrad.OnePointModel):
    """multigrad model wrapping sumstats_from_params() + loss_from_sumstats()."""

    sumstats_func_has_aux: bool = True  # override param default set by parent

    def calc_partial_sumstats_from_params(self, params, randkey):
        """Return this rank's (sumstats, sumstats_aux) pair."""
        # NOTE: sumstats will automatically be summed over all MPI ranks
        # before getting passed to calc_loss_from_sumstats. However,
        # sumstats_aux will be passed directly without MPI communication
        return sumstats_from_params(params, randkey)

    def calc_loss_from_sumstats(self, sumstats, sumstats_aux, randkey=None):
        """Return the scalar loss; `randkey` is accepted but not used."""
        # NOTE: randkey kwarg must be accepted by BOTH functions or NEITHER
        # However, we have no need for it in the loss function
        return loss_from_sumstats(sumstats, sumstats_aux)
if __name__ == "__main__":
    # Run gradient descent (nearly identical to pure kdescent)
    nsteps = 600
    model = MyModel()
    adam_params, _ = model.run_adam(
        guess, nsteps=nsteps, param_bounds=bounds,
        learning_rate=0.05, randkey=12345)
    best_fit = adam_params[-1]

    if comm.rank == 0:
        # Print results and save figure on root rank only
        print("Best fit params =", best_fit)
        fig = plt.figure(figsize=(20, 9), layout="constrained")
        fig.set_facecolor("0.05")
        left_fig, right_fig = fig.subfigures(1, 2, wspace=0.004)
        left_fig.set_facecolor("white")
        right_fig.set_facecolor("white")
        make_sumstat_plot(guess, txt="Initial guess", fig=left_fig)
        make_sumstat_plot(
            best_fit, txt=f"Solution after {nsteps} evaluations",
            fig=right_fig)
        plt.savefig("kdescent-multigrad-results.png")
    else:
        # All other ranks need to do this for make_sumstat_plot() to work:
        # each call below pairs with the gather performed inside one of the
        # root rank's make_sumstat_plot() calls, in the same order
        generate_model_into_mass_bins(guess, jax.random.key(13))
        generate_model_into_mass_bins(best_fit, jax.random.key(13))
Now let’s run this with two MPI ranks. Executing mpiexec -n 2 python kdescent-multigrad-integration.py yields the following results (about 2x speedup on my laptop):
Adam Gradient Descent Progress: 100%|██████████| 601/601 [04:06<00:00, 2.44it/s]
Best fit params = [ 9.258453 2.6762085 1.486586 1.1471022 0.6167415 0.43723965
0.38919714 0.21205594 -0.2331485 1.5793908 0.40174854 1.7807314
1.7946088 0.56626403 0.58664 0.73132914 0.8944174 -0.34618914
1.080119 0.9135933 ]
