{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Integration with `multigrad`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook walks through an example of running `kdescent` in parallel with the aid of the `multigrad` package, and discusses some of the complexity that arises. We will follow the same example as in the [Advanced Usage](./intro.ipynb#Advanced-Usage) section of the tutorial. All procedural differences are flagged with a `# NOTE: ...` comment.\n", "\n", "In the following script, the `generate_model()` function does the same thing as before, except that it divides the sample size by the number of available MPI ranks, so that the full-sized sample is generated collectively across all ranks. It also splits the randkey between ranks so that the subsamples are not identical. Let's save this script as `kdescent-multigrad-integration.py`:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "\"\"\"\n", "kdescent-multigrad-integration.py\n", "\"\"\"\n", "\n", "import functools\n", "from dataclasses import dataclass\n", "import numpy as np\n", "import jax.numpy as jnp\n", "import jax.random\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from mpi4py import MPI\n", "\n", "from diffopt import kdescent\n", "from diffopt import multigrad\n", "\n", "comm = MPI.COMM_WORLD\n", "\n", "model_nsample = 20_000\n", "data_nsample = 10_000 # same volume, but undersampled below logM* < 10.5\n", "\n", "# Generate data weighted from two mass-dependent multivariate normals\n", "\n", "\n", "@functools.partial(jax.jit, static_argnames=[\"undersample\", \"nsample\"])\n", "def generate_model(params, randkey, undersample=False, nsample=model_nsample):\n", " # NOTE: Divide nsample and split randkey across MPI ranks:\n", " nsample = nsample // comm.size\n", " randkey = jax.random.split(randkey, comm.size)[comm.rank]\n", "\n", " # Parse all 20 parameters\n", " # 
=======================\n", " # Lower and upper bounds on log stellar mass\n", " logmlim = params[:2]\n", " logmlim = logmlim.at[1].add(logmlim[0] + 0.001)\n", "\n", " # Distribution parameters at lower mass bound\n", " mean_mmin = params[2:4]\n", " sigma11, sigma22 = params[4:6]\n", " maxcov = jnp.sqrt(sigma11 * sigma22)\n", " sigma12 = params[6] * maxcov\n", " cov_mmin = jnp.array([[sigma11, sigma12],\n", " [sigma12, sigma22]])\n", " qfrac_mmin = params[7]\n", " qmean_mmin = mean_mmin + params[8:10]\n", " qscale_mmin = params[10]\n", "\n", " # Distribution parameters at upper mass bound\n", " mean_mmax = params[11:13]\n", " sigma11, sigma22 = params[13:15]\n", " maxcov = jnp.sqrt(sigma11 * sigma22)\n", " sigma12 = params[15] * maxcov\n", " cov_mmax = jnp.array([[sigma11, sigma12],\n", " [sigma12, sigma22]])\n", " qfrac_mmax = params[16]\n", " qmean_mmax = mean_mmax + params[17:19]\n", " qscale_mmax = params[19]\n", "\n", " # Generate distribution from parameters\n", " # =====================================\n", " key1, key2 = jax.random.split(randkey, num=2)\n", " triangle_vals = (0, 0.5, 1) if undersample else (0, 0, 1)\n", " logm = jax.random.triangular(key1, *triangle_vals, shape=(nsample,))\n", " logm = logmlim[0] + logm * (logmlim[1] - logmlim[0])\n", " # Calculate slope of mass dependence\n", " dlogm = logmlim[1] - logmlim[0]\n", " dmean = (mean_mmax - mean_mmin) / dlogm\n", " dcov = (cov_mmax - cov_mmin) / dlogm\n", " dqfrac = (qfrac_mmax - qfrac_mmin) / dlogm\n", " dqmean = (qmean_mmax - qmean_mmin) / dlogm\n", " dqscale = (qscale_mmax - qscale_mmin) / dlogm\n", " # Apply mass dependence\n", " mean_sf = mean_mmin + dmean * (logm[:, None] - logmlim[0])\n", " cov_sf = cov_mmin + dcov * (logm[:, None, None] - logmlim[0])\n", " mean_q = qmean_mmin + dqmean * (logm[:, None] - logmlim[0])\n", " qscale = qscale_mmin + dqscale * (logm - logmlim[0])\n", " cov_q = cov_sf * qscale[:, None, None] ** 2\n", " qfrac = qfrac_mmin + dqfrac * (logm - logmlim[0])\n", "\n", 
" # Generate colors from two separate multivariate normals\n", " rz_sf, gr_sf = jax.random.multivariate_normal(key2, mean_sf, cov_sf).T\n", " rz_q, gr_q = jax.random.multivariate_normal(key2, mean_q, cov_q).T\n", " # Concatenate the quenched + star-forming values and assign weights\n", " data_sf = jnp.array([rz_sf, gr_sf, logm]).T\n", " data_q = jnp.array([rz_q, gr_q, logm]).T\n", " data = jnp.concatenate([data_sf, data_q])\n", " weights = jnp.concatenate([1 - qfrac, qfrac])\n", " return data, weights\n", "\n", "\n", "# Define \"true\" parameters to generate training data\n", "truth_logmmin = 9.0\n", "truth_logmrange = 3.0\n", "\n", "truth_mean_mmin = jnp.array([1.4, 1.1])\n", "truth_var_mmin = jnp.array([0.7, 0.4])\n", "truth_corr_mmin = 0.3\n", "truth_qfrac_mmin = 0.2\n", "truth_qmean_mmin = jnp.array([-0.1, 1.6])\n", "truth_qscale_mmin = 0.3\n", "\n", "truth_mean_mmax = jnp.array([2.0, 1.6])\n", "truth_var_mmax = jnp.array([0.5, 0.5])\n", "truth_corr_mmax = 0.75\n", "truth_qfrac_mmax = 0.95\n", "truth_qmean_mmax = jnp.array([-0.6, 1.2])\n", "truth_qscale_mmax = 1.1\n", "\n", "bounds_truth_logmrange = [0.001, jnp.inf]\n", "bounds_var = ([0.001, jnp.inf], [0.001, jnp.inf])\n", "bounds_corr = [-0.999, 0.999]\n", "bounds_qfrac = [0.0, 1.0]\n", "bounds_qmean_gr = [0.001, jnp.inf]\n", "bounds_qscale = [0.001, jnp.inf]\n", "\n", "truth = jnp.array([\n", " truth_logmmin, truth_logmrange,\n", " *truth_mean_mmin, *truth_var_mmin, truth_corr_mmin, truth_qfrac_mmin,\n", " *truth_qmean_mmin, truth_qscale_mmin,\n", " *truth_mean_mmax, *truth_var_mmax, truth_corr_mmax, truth_qfrac_mmax,\n", " *truth_qmean_mmax, truth_qscale_mmax\n", "])\n", "guess = jnp.array([\n", " 9.25, 2.5, *[0.0, 0.0, 1.0, 1.0, 0.0, 0.5, 0.0, 1.0, 1.0]*2\n", "])\n", "bounds = [\n", " None, bounds_truth_logmrange,\n", " *[None, None, *bounds_var, bounds_corr, bounds_qfrac,\n", " None, bounds_qmean_gr, bounds_qscale]*2\n", "]\n", "\n", "# Generate training data from the truth parameters we just defined\n", 
"truth_randkey = jax.random.key(43)\n", "training_x_weighted, training_w = generate_model(\n", " truth, truth_randkey, undersample=True, nsample=data_nsample)\n", "\n", "# NOTE: Every rank must be aware of the FULL training data, so we must gather:\n", "training_x_weighted = jnp.concatenate(comm.allgather(training_x_weighted))\n", "training_w = jnp.concatenate(comm.allgather(training_w))\n", "\n", "# KDescent allows weighted training data, but to make this more realistic,\n", "# let's use weighted sampling instead\n", "training_selection = jax.random.uniform(\n", " jax.random.split(truth_randkey)[0], (len(training_w),)) < training_w\n", "training_x = training_x_weighted[training_selection]\n", "\n", "# Define stellar mass bins (used by the plotting and loss functions)\n", "lowmass_cut = [9.0, 9.5]\n", "midmass_cut = [10.25, 10.75]\n", "highmass_cut = [11.5, 12.0]\n", "is_lowmass = ((lowmass_cut[0] < training_x_weighted[:, 2])\n", " & (training_x_weighted[:, 2] < lowmass_cut[1]))\n", "is_midmass = ((midmass_cut[0] < training_x_weighted[:, 2])\n", " & (training_x_weighted[:, 2] < midmass_cut[1]))\n", "is_highmass = ((highmass_cut[0] < training_x_weighted[:, 2])\n", " & (training_x_weighted[:, 2] < highmass_cut[1]))\n", "training_w_lowmass = training_w * is_lowmass\n", "training_w_midmass = training_w * is_midmass\n", "training_w_highmass = training_w * is_highmass\n", "is_noweight_lowmass = (\n", " (lowmass_cut[0] < training_x[:, 2])\n", " & (training_x[:, 2] < lowmass_cut[1]))\n", "is_noweight_midmass = (\n", " (midmass_cut[0] < training_x[:, 2])\n", " & (training_x[:, 2] < midmass_cut[1]))\n", "is_noweight_highmass = (\n", " (highmass_cut[0] < training_x[:, 2])\n", " & (training_x[:, 2] < highmass_cut[1]))\n", "\n", "\n", "def generate_model_into_mass_bins(params, randkey):\n", " # NOTE: Gather data from each rank (since this is for plotting only)\n", " model_x, model_w = generate_model(params, randkey=randkey)\n", " model_x = jnp.concatenate(comm.allgather(model_x))\n", " model_w = 
jnp.concatenate(comm.allgather(model_w))\n", "\n", " is_low = ((lowmass_cut[0] < model_x[:, 2])\n", " & (model_x[:, 2] < lowmass_cut[1]))\n", " is_mid = ((midmass_cut[0] < model_x[:, 2])\n", " & (model_x[:, 2] < midmass_cut[1]))\n", " is_high = ((highmass_cut[0] < model_x[:, 2])\n", " & (model_x[:, 2] < highmass_cut[1]))\n", " return (model_x, model_x[is_low], model_x[is_mid], model_x[is_high],\n", " model_w, model_w[is_low], model_w[is_mid], model_w[is_high])\n", "\n", "\n", "def make_sumstat_plot(params, txt=\"\", fig=None, prev_layers=None):\n", " (modall, modlow, modmid, modhigh,\n", " w_all, w_low, w_mid, w_high) = generate_model_into_mass_bins(\n", " params, jax.random.key(13))\n", " if prev_layers is not None:\n", " for layer in prev_layers:\n", " layer.remove()\n", "\n", " fig = plt.figure(figsize=(10, 9)) if fig is None else fig\n", " ax = fig.add_subplot(221) if len(fig.axes) < 4 else fig.axes[0]\n", " ax.hist(training_x_weighted[:, 2], bins=50, color=\"red\",\n", " weights=training_w)\n", " _, bins, hist1 = ax.hist(\n", " modall[:, 2], color=\"grey\", bins=50, alpha=0.9, weights=w_all)\n", " hist2 = ax.hist(modlow[:, 2], bins=list(bins), color=\"C0\",\n", " alpha=0.9, weights=w_low)[-1]\n", " hist3 = ax.hist(modmid[:, 2], bins=list(bins), color=\"C0\",\n", " alpha=0.9, weights=w_mid)[-1]\n", " hist4 = ax.hist(modhigh[:, 2], bins=list(bins), color=\"C0\",\n", " alpha=0.9, weights=w_high)[-1]\n", " ax.set_xlabel(\"$\\\\log M_\\\\ast$\", fontsize=14)\n", " text1 = ax.text(\n", " 0.98, 0.98, \"Training data\", color=\"red\", va=\"top\", ha=\"right\",\n", " fontsize=14, transform=ax.transAxes)\n", " text2 = ax.text(\n", " 0.98, 0.91, txt, color=\"blue\", va=\"top\", ha=\"right\",\n", " fontsize=14, transform=ax.transAxes)\n", "\n", " ax = fig.add_subplot(222) if len(fig.axes) < 4 else fig.axes[1]\n", " hex1 = ax.hexbin(*modlow[:, :2].T, mincnt=1,\n", " C=w_low, reduce_C_function=np.sum,\n", " norm=plt.matplotlib.colors.LogNorm())\n", " if prev_layers is 
None:\n", " sns.kdeplot(\n", " {\"$r - z$\": training_x_weighted[is_lowmass][:, 0],\n", " \"$g - r$\": training_x_weighted[is_lowmass][:, 1]},\n", " weights=training_w[is_lowmass],\n", " x=\"$r - z$\", y=\"$g - r$\", color=\"red\", levels=7, ax=ax)\n", " ax.set_xlabel(\"$r - z$\", fontsize=14)\n", " ax.set_ylabel(\"$g - r$\", fontsize=14)\n", " text3 = ax.text(\n", " 0.02, 0.02, f\"${lowmass_cut[0]} < \\\\log M_\\\\ast < {lowmass_cut[1]}$\",\n", " fontsize=14, transform=ax.transAxes)\n", "\n", " ax = fig.add_subplot(223, sharex=ax, sharey=ax) if len(\n", " fig.axes) < 4 else fig.axes[2]\n", " hex2 = ax.hexbin(*modmid[:, :2].T, mincnt=1,\n", " C=w_mid, reduce_C_function=np.sum,\n", " norm=plt.matplotlib.colors.LogNorm())\n", " if prev_layers is None:\n", " sns.kdeplot(\n", " {\"$r - z$\": training_x_weighted[is_midmass][:, 0],\n", " \"$g - r$\": training_x_weighted[is_midmass][:, 1]},\n", " weights=training_w[is_midmass],\n", " x=\"$r - z$\", y=\"$g - r$\", color=\"red\", levels=7, ax=ax)\n", " ax.set_xlabel(\"$r - z$\", fontsize=14)\n", " ax.set_ylabel(\"$g - r$\", fontsize=14)\n", " text4 = ax.text(\n", " 0.02, 0.02, f\"${midmass_cut[0]} < \\\\log M_\\\\ast < {midmass_cut[1]}$\",\n", " fontsize=14, transform=ax.transAxes)\n", "\n", " ax = fig.add_subplot(224, sharex=ax, sharey=ax) if len(\n", " fig.axes) < 4 else fig.axes[3]\n", " hex3 = ax.hexbin(*modhigh[:, :2].T, mincnt=1,\n", " C=w_high, reduce_C_function=np.sum,\n", " norm=plt.matplotlib.colors.LogNorm())\n", " if prev_layers is None:\n", " sns.kdeplot(\n", " {\"$r - z$\": training_x_weighted[is_highmass][:, 0],\n", " \"$g - r$\": training_x_weighted[is_highmass][:, 1]},\n", " weights=training_w[is_highmass],\n", " x=\"$r - z$\", y=\"$g - r$\", color=\"red\", levels=7, ax=ax)\n", " ax.set_xlabel(\"$r - z$\", fontsize=14)\n", " ax.set_ylabel(\"$g - r$\", fontsize=14)\n", " text5 = ax.text(\n", " 0.02, 0.02, f\"${highmass_cut[0]} < \\\\log M_\\\\ast < {highmass_cut[1]}$\",\n", " fontsize=14, 
transform=ax.transAxes)\n", " ax.set_xlim(-4, 7)\n", " ax.set_ylim(-4, 7)\n", " return [hex1, hex2, hex3, hist1, hist2, hist3, hist4,\n", " text1, text2, text3, text4, text5]\n", "\n", "\n", "# Define loss function comparing PDF(g-r, r-z | M*) and its Fourier pair\n", "\n", "# NOTE: Since we plan on jitting, we can't pass comm=comm to our KCalcs.\n", "# Instead, we will be careful to call the compare_*_counts() methods with\n", "# identical randkeys on each MPI rank!\n", "ktrain_lowmass = kdescent.KPretrainer.from_training_data(\n", " training_x[is_noweight_lowmass, :2],\n", " bandwidth_factor=0.3, fourier_range_factor=3.0,\n", ")\n", "ktrain_midmass = kdescent.KPretrainer.from_training_data(\n", " training_x[is_noweight_midmass, :2],\n", " bandwidth_factor=0.3, fourier_range_factor=3.0,\n", ")\n", "ktrain_highmass = kdescent.KPretrainer.from_training_data(\n", " training_x[is_noweight_highmass, :2],\n", " bandwidth_factor=0.3, fourier_range_factor=3.0,\n", ")\n", "kcalc_lowmass = kdescent.KCalc(ktrain_lowmass)\n", "kcalc_midmass = kdescent.KCalc(ktrain_midmass)\n", "kcalc_highmass = kdescent.KCalc(ktrain_highmass)\n", "\n", "\n", "# Differentiable alternative to hard binning in the loss function:\n", "@jax.jit\n", "def soft_tophat(x, low, high, squish=25.0):\n", " \"\"\"Approximately return 1 when `low < x < high`, else return 0\"\"\"\n", " width = (high - low) / squish\n", " left = jax.nn.sigmoid((x - low) / width)\n", " right = jax.nn.sigmoid((high - x) / width)\n", " return left * right\n", "\n", "\n", "# NOTE: For multigrad, we have to explicitly define sumstats_from_params()\n", "# and loss_from_sumstats() to replace the old lossfunc()\n", "@jax.jit\n", "def sumstats_from_params(params, randkey):\n", " key1, *keys = jax.random.split(randkey, 7)\n", "\n", " model_x, model_w = generate_model(params, randkey=key1)\n", " weight_low = soft_tophat(model_x[:, 2], *lowmass_cut) * model_w\n", " weight_mid = soft_tophat(model_x[:, 2], *midmass_cut) * model_w\n", " 
weight_high = soft_tophat(model_x[:, 2], *highmass_cut) * model_w\n", "\n", " model_low_counts, truth_low_counts = kcalc_lowmass.compare_kde_counts(\n", " keys[0], model_x[:, :2], weight_low)\n", " model_mid_counts, truth_mid_counts = kcalc_midmass.compare_kde_counts(\n", " keys[1], model_x[:, :2], weight_mid)\n", " model_high_counts, truth_high_counts = kcalc_highmass.compare_kde_counts(\n", " keys[2], model_x[:, :2], weight_high)\n", "\n", " model_low_fcounts, truth_low_fcounts = kcalc_lowmass.compare_fourier_counts(\n", " keys[3], model_x[:, :2], weight_low)\n", " model_mid_fcounts, truth_mid_fcounts = kcalc_midmass.compare_fourier_counts(\n", " keys[4], model_x[:, :2], weight_mid)\n", " model_high_fcounts, truth_high_fcounts = kcalc_highmass.compare_fourier_counts(\n", " keys[5], model_x[:, :2], weight_high)\n", "\n", " # NOTE: \"Sumstats\" are raw counts so that they can be summed across ranks\n", " sumstats = jnp.array([\n", " *model_low_counts, *model_low_fcounts, weight_low.sum(),\n", " *model_mid_counts, *model_mid_fcounts, weight_mid.sum(),\n", " *model_high_counts, *model_high_fcounts, weight_high.sum(),\n", " ])\n", " # NOTE: To prevent *truth* counts being summed, pass them as \"auxiliary\"\n", " sumstats_aux = jnp.array([\n", " *truth_low_counts, *truth_low_fcounts,\n", " *truth_mid_counts, *truth_mid_fcounts,\n", " *truth_high_counts, *truth_high_fcounts,\n", " ])\n", " return sumstats, sumstats_aux\n", "\n", "\n", "@jax.jit\n", "def loss_from_sumstats(sumstats, sumstats_aux):\n", " # NOTE: Unpack sumstats (raw model counts per kernel + total weight sums)\n", " i = 0\n", " model_low_counts = sumstats[ # slice [0:20]\n", " i:(i := i + kcalc_lowmass.num_eval_kernels)]\n", " model_low_fcounts = sumstats[ # slice [20:40]\n", " i:(i := i + kcalc_lowmass.num_eval_fourier_positions)]\n", " weight_low_sum = sumstats[i:(i := i + 1)][0] # slice [40:41][0]\n", "\n", " model_mid_counts = sumstats[ # slice [41:61]\n", " i:(i := i + 
kcalc_midmass.num_eval_kernels)]\n", " model_mid_fcounts = sumstats[ # slice [61:81]\n", " i:(i := i + kcalc_midmass.num_eval_fourier_positions)]\n", " weight_mid_sum = sumstats[i:(i := i + 1)][0] # slice [81:82][0]\n", "\n", " model_high_counts = sumstats[ # slice [82:102]\n", " i:(i := i + kcalc_highmass.num_eval_kernels)]\n", " model_high_fcounts = sumstats[ # slice [102:122]\n", " i:(i := i + kcalc_highmass.num_eval_fourier_positions)]\n", " weight_high_sum = sumstats[i:(i := i + 1)][0] # slice [122:123][0]\n", "\n", " # NOTE: Unpack sumstats_aux (raw truth counts per kernel)\n", " i = 0\n", " truth_low_counts = sumstats_aux[ # slice [0:20]\n", " i:(i := i + kcalc_lowmass.num_eval_kernels)]\n", " truth_low_fcounts = sumstats_aux[ # slice [20:40]\n", " i:(i := i + kcalc_lowmass.num_eval_fourier_positions)]\n", "\n", " truth_mid_counts = sumstats_aux[ # slice [40:60]\n", " i:(i := i + kcalc_midmass.num_eval_kernels)]\n", " truth_mid_fcounts = sumstats_aux[ # slice [60:80]\n", " i:(i := i + kcalc_midmass.num_eval_fourier_positions)]\n", "\n", " truth_high_counts = sumstats_aux[ # slice [80:100]\n", " i:(i := i + kcalc_highmass.num_eval_kernels)]\n", " truth_high_fcounts = sumstats_aux[ # slice [100:120]\n", " i:(i := i + kcalc_highmass.num_eval_fourier_positions)]\n", "\n", " # Convert counts to conditional prob: P(krnl | M*) = N(krnl & M*) / N(M*)\n", " model_low_condprob = model_low_counts / (weight_low_sum + 1e-10)\n", " model_mid_condprob = model_mid_counts / (weight_mid_sum + 1e-10)\n", " model_high_condprob = model_high_counts / (weight_high_sum + 1e-10)\n", " truth_low_condprob = truth_low_counts / (training_w_lowmass.sum() + 1e-10)\n", " truth_mid_condprob = truth_mid_counts / (training_w_midmass.sum() + 1e-10)\n", " truth_high_condprob = truth_high_counts / (\n", " training_w_highmass.sum() + 1e-10)\n", " # Convert Fourier counts to \"conditional\" ECF analogously\n", " model_low_ecf = model_low_fcounts / (weight_low_sum + 1e-10)\n", " model_mid_ecf = 
model_mid_fcounts / (weight_mid_sum + 1e-10)\n", " model_high_ecf = model_high_fcounts / (weight_high_sum + 1e-10)\n", " truth_low_ecf = truth_low_fcounts / (training_w_lowmass.sum() + 1e-10)\n", " truth_mid_ecf = truth_mid_fcounts / (training_w_midmass.sum() + 1e-10)\n", " truth_high_ecf = truth_high_fcounts / (training_w_highmass.sum() + 1e-10)\n", "\n", " # One constraint on number density at the highest stellar mass bin\n", " volume = 100.0\n", " model_massfunc = jnp.array([weight_high_sum,]) / volume\n", " truth_massfunc = jnp.array([training_w_highmass.sum(),]) / volume\n", "\n", " # Must abs() the Fourier residuals so that the loss is real\n", " # NOTE: We even have to abs() the PDF residuals due to multigrad\n", " # combining all sumstats into a single complex-typed array\n", " sqerrs = jnp.abs(jnp.concatenate([\n", " (model_low_condprob - truth_low_condprob)**2,\n", " (model_mid_condprob - truth_mid_condprob)**2,\n", " (model_high_condprob - truth_high_condprob)**2,\n", " (model_low_ecf - truth_low_ecf)**2,\n", " (model_mid_ecf - truth_mid_ecf)**2,\n", " (model_high_ecf - truth_high_ecf)**2,\n", " (model_massfunc - truth_massfunc)**2,\n", " ]))\n", "\n", " return jnp.mean(sqerrs)\n", "\n", "# NOTE: Define multigrad class using the sumstats + loss funcs we just defined\n", "\n", "\n", "@dataclass\n", "class MyModel(multigrad.OnePointModel):\n", " sumstats_func_has_aux: bool = True # override param default set by parent\n", "\n", " def calc_partial_sumstats_from_params(self, params, randkey):\n", " # NOTE: sumstats will automatically be summed over all MPI ranks\n", " # before getting passed to calc_loss_from_sumstats. 
However,\n", " # sumstats_aux will be passed directly without MPI communication\n", " sumstats, sumstats_aux = sumstats_from_params(params, randkey)\n", " return sumstats, sumstats_aux\n", "\n", " def calc_loss_from_sumstats(self, sumstats, sumstats_aux, randkey=None):\n", " # NOTE: randkey kwarg must be accepted by BOTH functions or NEITHER\n", " # However, we have no need for it in the loss function\n", " del randkey\n", "\n", " loss = loss_from_sumstats(sumstats, sumstats_aux)\n", " return loss\n", "\n", "\n", "if __name__ == \"__main__\":\n", " # Run gradient descent (nearly identical to pure kdescent)\n", " model = MyModel()\n", " nsteps = 600\n", " adam_params, _ = model.run_adam(\n", " guess, nsteps=nsteps, param_bounds=bounds,\n", " learning_rate=0.05, randkey=12345)\n", "\n", " if not comm.rank:\n", " # Print results and save figure on root rank only\n", " print(\"Best fit params =\", adam_params[-1])\n", "\n", " fig = plt.figure(figsize=(20, 9), layout=\"constrained\")\n", " fig.set_facecolor(\"0.05\")\n", " figs = fig.subfigures(1, 2, wspace=0.004)\n", " figs[0].set_facecolor(\"white\")\n", " figs[1].set_facecolor(\"white\")\n", " make_sumstat_plot(\n", " guess, txt=\"Initial guess\", fig=figs[0])\n", " make_sumstat_plot(\n", " adam_params[-1],\n", " txt=f\"Solution after {nsteps} evaluations\", fig=figs[1])\n", " plt.savefig(\"kdescent-multigrad-results.png\")\n", " else:\n", " # NOTE: All other ranks must join the allgather() calls that\n", " # make_sumstat_plot() triggers on the root rank, or the root will hang\n", " generate_model_into_mass_bins(guess, jax.random.key(13))\n", " generate_model_into_mass_bins(adam_params[-1], jax.random.key(13))\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's run this with two MPI ranks. 
Executing `mpiexec -n 2 python kdescent-multigrad-integration.py` yields the following results (about 2x speedup on my laptop):" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```txt\n", "Adam Gradient Descent Progress: 100%|██████████| 601/601 [04:06<00:00, 2.44it/s]\n", "Best fit params = [ 9.258453 2.6762085 1.486586 1.1471022 0.6167415 0.43723965\n", " 0.38919714 0.21205594 -0.2331485 1.5793908 0.40174854 1.7807314\n", " 1.7946088 0.56626403 0.58664 0.73132914 0.8944174 -0.34618914\n", " 1.080119 0.9135933 ]\n", "```\n", "![Results plot](./kdescent-multigrad-results.png)" ] } ], "metadata": { "kernelspec": { "display_name": "main311", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" } }, "nbformat": 4, "nbformat_minor": 2 }