{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Upweighting Example: The Halo Mass Function (HMF)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook loosely follows the example in the [Advanced Usage](./intro.ipynb#Advanced-Usage) section of the tutorial. However, the stellar mass parameters $M_{\\rm min}$ and $M_{\\rm max}$ are replaced by $b_{\\rm SMHR}$ and $m_{\\rm SMHR}$, the intercept and slope of a linear stellar-to-halo mass relation. Stellar masses are now assigned to a fixed dataset of halo masses (drawn so that the HMF follows a triangular distribution over $11 < \\log M_h < 14$) via the relation $\\log M_\\star = b_{\\rm SMHR} + m_{\\rm SMHR} \\log M_h$. Setting the \"true\" parameters to $b_{\\rm SMHR} = -2$ and $m_{\\rm SMHR} = 1$ therefore recovers the same distribution of stellar masses as the original model (a triangular distribution over $9 < \\log M_\\star < 12$).\n", "\n", "This example model lets us demonstrate \"HMF upweighting\": our `generate_model()` function accepts an `hmf_upweights` array giving the *effective* number of halos that each entry of `logmh_table` stands for. Since our model already assigns weights to its quenched and star-forming predictions, we can incorporate the HMF upweights by adding the following lines near the end of the function:\n", "\n", "```python\n", "def generate_model(...):\n", " ...\n", " # Propagate hmf_upweights to both the Q and SF predictions\n", " hmf_upweights_duped = jnp.concatenate([hmf_upweights, hmf_upweights])\n", " # Incorporate hmf upweighting into the existing Q-vs-SF weights\n", " weights = weights * hmf_upweights_duped\n", " ...\n", "```\n", "\n", "At the end of this tutorial, we will show that we arrive at essentially identical results with and without upweighting." 
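, "\n", "\n", "In equation form (a sketch in our own notation, not taken from `diffopt`): let $u_i$ be the `hmf_upweights` entry of halo $i$ and $f_Q$ the mass-dependent quenched fraction `qfrac`. Each halo then contributes one star-forming and one quenched point, with weights\n", "\n", "$$\n", "w_i^{\\rm SF} = u_i \\left[1 - f_Q(\\log M_{\\star,i})\\right], \\qquad w_i^{\\rm Q} = u_i \\, f_Q(\\log M_{\\star,i}),\n", "$$\n", "\n", "so halo $i$ carries a total weight of $w_i^{\\rm SF} + w_i^{\\rm Q} = u_i$. With the default $u_i = 1$, every halo counts once. For the binned upweights used later in this notebook, $u_i = N_{\\rm full}(b_i) / N_{\\rm grid}(b_i)$, where $b_i$ is the halo mass bin containing halo $i$, $N_{\\rm full}$ counts halos in the full 10,000-halo table, and $N_{\\rm grid}$ counts halos in the smaller, evenly spaced table. Summing over the smaller table gives $\\sum_i u_i = \\sum_b N_{\\rm full}(b)$, so the upweighted grid stands in for every halo in the full table." 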
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import functools\n", "import numpy as np\n", "import jax.numpy as jnp\n", "import jax.random\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "from diffopt import kdescent" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define the model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model_nsample = 10_000\n", "data_nsample = 5_000 # same volume as the model, but undersampled below logM* = 10.5\n", "\n", "\n", "# Generate a fixed-seed sample of halos\n", "def load_logmh_table(undersample=False, nsample=model_nsample):\n", " triangle_vals = (0, 0.5, 1) if undersample else (0, 0, 1)\n", " logmh = jax.random.triangular(\n", " jax.random.key(12345), *triangle_vals, shape=(nsample,))\n", "\n", " logmh_min, logmh_max = 11.0, 14.0\n", " logmh = logmh_min + logmh * (logmh_max - logmh_min)\n", " return logmh\n", "\n", "\n", "# Generate weighted data from two mass-dependent multivariate normals\n", "@jax.jit\n", "def generate_model(params, randkey, logmh_table, hmf_upweights=1.0):\n", " # Parse all 20 parameters\n", " # =======================\n", " # Intercept & slope of our linear stellar-to-halo mass relation\n", " b_smhr, m_smhr = params[:2]\n", "\n", " # Distribution parameters at lower mass bound\n", " mean_mmin = params[2:4]\n", " sigma11, sigma22 = params[4:6]\n", " maxcov = jnp.sqrt(sigma11 * sigma22)\n", " sigma12 = params[6] * maxcov\n", " cov_mmin = jnp.array([[sigma11, sigma12],\n", " [sigma12, sigma22]])\n", " qfrac_mmin = params[7]\n", " qmean_mmin = mean_mmin + params[8:10]\n", " qscale_mmin = params[10]\n", "\n", " # Distribution parameters at upper mass bound\n", " mean_mmax = params[11:13]\n", " sigma11, sigma22 = params[13:15]\n", " maxcov = jnp.sqrt(sigma11 * sigma22)\n", " sigma12 = params[15] * maxcov\n", " cov_mmax = jnp.array([[sigma11, sigma12],\n", " [sigma12, sigma22]])\n", 
" qfrac_mmax = params[16]\n", " qmean_mmax = mean_mmax + params[17:19]\n", " qscale_mmax = params[19]\n", "\n", " # Generate distribution from parameters\n", " # =====================================\n", " logm = b_smhr + m_smhr * logmh_table\n", " # Calculate slope of mass dependence\n", " logmlim = logm.min(), logm.max()\n", " dlogm = logmlim[1] - logmlim[0]\n", " dmean = (mean_mmax - mean_mmin) / dlogm\n", " dcov = (cov_mmax - cov_mmin) / dlogm\n", " dqfrac = (qfrac_mmax - qfrac_mmin) / dlogm\n", " dqmean = (qmean_mmax - qmean_mmin) / dlogm\n", " dqscale = (qscale_mmax - qscale_mmin) / dlogm\n", " # Apply mass dependence\n", " mean_sf = mean_mmin + dmean * (logm[:, None] - logmlim[0])\n", " cov_sf = cov_mmin + dcov * (logm[:, None, None] - logmlim[0])\n", " mean_q = qmean_mmin + dqmean * (logm[:, None] - logmlim[0])\n", " qscale = qscale_mmin + dqscale * (logm - logmlim[0])\n", " cov_q = cov_sf * qscale[:, None, None] ** 2\n", " qfrac = qfrac_mmin + dqfrac * (logm - logmlim[0])\n", "\n", " # Generate colors from two separate multivariate normals, using\n", " # independent keys so that the SF and Q draws are independent\n", " key_sf, key_q = jax.random.split(randkey)\n", " rz_sf, gr_sf = jax.random.multivariate_normal(key_sf, mean_sf, cov_sf).T\n", " rz_q, gr_q = jax.random.multivariate_normal(key_q, mean_q, cov_q).T\n", " # Concatenate the quenched + star-forming values and assign weights\n", " data_sf = jnp.array([rz_sf, gr_sf, logm]).T\n", " data_q = jnp.array([rz_q, gr_q, logm]).T\n", " data = jnp.concatenate([data_sf, data_q])\n", " weights = jnp.concatenate([1 - qfrac, qfrac])\n", "\n", " # In case hmf_upweights is scalar, broadcast it to an array\n", " hmf_upweights = jnp.broadcast_to(hmf_upweights, logmh_table.shape)\n", " # Propagate hmf_upweights to both the Q and SF predictions\n", " hmf_upweights_duped = jnp.concatenate([hmf_upweights, hmf_upweights])\n", " # Incorporate hmf upweighting into the existing Q-vs-SF weights\n", " weights = weights * hmf_upweights_duped\n", " return data, weights\n", "\n", "true_logmh_table = load_logmh_table()\n", "undersampled_logmh_table 
= load_logmh_table(\n", " undersample=True, nsample=data_nsample)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define \"true\" parameters to generate training data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "truth_b_smhr = -2.0\n", "truth_m_smhr = 1.0\n", "\n", "truth_mean_mmin = jnp.array([1.4, 1.1])\n", "truth_var_mmin = jnp.array([0.7, 0.4])\n", "truth_corr_mmin = 0.3\n", "truth_qfrac_mmin = 0.2\n", "truth_qmean_mmin = jnp.array([-0.1, 1.6])\n", "truth_qscale_mmin = 0.3\n", "\n", "truth_mean_mmax = jnp.array([2.0, 1.6])\n", "truth_var_mmax = jnp.array([0.5, 0.5])\n", "truth_corr_mmax = 0.75\n", "truth_qfrac_mmax = 0.95\n", "truth_qmean_mmax = jnp.array([-0.6, 1.2])\n", "truth_qscale_mmax = 1.1\n", "\n", "bounds_var = ([0.001, jnp.inf], [0.001, jnp.inf])\n", "bounds_corr = [-0.999, 0.999]\n", "bounds_qfrac = [0.0, 1.0]\n", "bounds_qmean_gr = [0.001, jnp.inf]\n", "bounds_qscale = [0.001, jnp.inf]\n", "\n", "truth = jnp.array([\n", " truth_b_smhr, truth_m_smhr,\n", " *truth_mean_mmin, *truth_var_mmin, truth_corr_mmin, truth_qfrac_mmin,\n", " *truth_qmean_mmin, truth_qscale_mmin,\n", " *truth_mean_mmax, *truth_var_mmax, truth_corr_mmax, truth_qfrac_mmax,\n", " *truth_qmean_mmax, truth_qscale_mmax\n", "])\n", "guess = jnp.array([\n", " -3.0, 1.1, *[0.0, 0.0, 1.0, 1.0, 0.0, 0.5, 0.0, 1.0, 1.0]*2\n", "])\n", "bounds = [\n", " None, None,\n", " *[None, None, *bounds_var, bounds_corr, bounds_qfrac,\n", " None, bounds_qmean_gr, bounds_qscale]*2\n", "]\n", "\n", "# Generate training data from the truth parameters we just defined\n", "truth_randkey = jax.random.key(43)\n", "training_x_weighted, training_w = generate_model(\n", " truth, truth_randkey, undersampled_logmh_table)\n", "\n", "# KDescent allows weighted training data, but to make this more realistic,\n", "# let's generate our actual training data by randomized weighted sampling\n", "training_selection = jax.random.uniform(\n", " 
jax.random.split(truth_randkey)[0], training_w.shape) < training_w\n", "training_x = training_x_weighted[training_selection]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define plotting function\n", "\n", "- Plot the mass distribution + the color-color distribution in three separate mass bins" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lowmass_cut = [9.0, 9.5]\n", "midmass_cut = [10.25, 10.75]\n", "highmass_cut = [11.5, 12.0]\n", "is_lowmass = ((lowmass_cut[0] < training_x_weighted[:, 2])\n", " & (training_x_weighted[:, 2] < lowmass_cut[1]))\n", "is_midmass = ((midmass_cut[0] < training_x_weighted[:, 2])\n", " & (training_x_weighted[:, 2] < midmass_cut[1]))\n", "is_highmass = ((highmass_cut[0] < training_x_weighted[:, 2])\n", " & (training_x_weighted[:, 2] < highmass_cut[1]))\n", "training_w_lowmass = training_w * is_lowmass\n", "training_w_midmass = training_w * is_midmass\n", "training_w_highmass = training_w * is_highmass\n", "is_noweight_lowmass = (\n", " (lowmass_cut[0] < training_x[:, 2])\n", " & (training_x[:, 2] < lowmass_cut[1]))\n", "is_noweight_midmass = (\n", " (midmass_cut[0] < training_x[:, 2])\n", " & (training_x[:, 2] < midmass_cut[1]))\n", "is_noweight_highmass = (\n", " (highmass_cut[0] < training_x[:, 2])\n", " & (training_x[:, 2] < highmass_cut[1]))\n", "\n", "\n", "def generate_model_into_mass_bins(params, randkey):\n", " key1 = jax.random.split(randkey)[0]\n", " model_x, model_w = generate_model(params, key1, true_logmh_table)\n", " is_low = ((lowmass_cut[0] < model_x[:, 2])\n", " & (model_x[:, 2] < lowmass_cut[1]))\n", " is_mid = ((midmass_cut[0] < model_x[:, 2])\n", " & (model_x[:, 2] < midmass_cut[1]))\n", " is_high = ((highmass_cut[0] < model_x[:, 2])\n", " & (model_x[:, 2] < highmass_cut[1]))\n", " return (model_x, model_x[is_low], model_x[is_mid], model_x[is_high],\n", " model_w, model_w[is_low], model_w[is_mid], model_w[is_high])\n", "\n", "\n", "def 
make_sumstat_plot(params, txt=\"\", fig=None, prev_layers=None):\n", " (modall, modlow, modmid, modhigh,\n", " w_all, w_low, w_mid, w_high) = generate_model_into_mass_bins(\n", " params, jax.random.key(13))\n", " if prev_layers is not None:\n", " for layer in prev_layers:\n", " layer.remove()\n", "\n", " fig = plt.figure(figsize=(10, 9)) if fig is None else fig\n", " ax = fig.add_subplot(221) if len(fig.axes) < 4 else fig.axes[0]\n", " ax.hist(training_x_weighted[:, 2], bins=50, color=\"red\",\n", " weights=training_w)\n", " _, bins, hist1 = ax.hist(\n", " modall[:, 2], color=\"grey\", bins=50, alpha=0.9, weights=w_all)\n", " hist2 = ax.hist(modlow[:, 2], bins=list(bins), color=\"C0\",\n", " alpha=0.9, weights=w_low)[-1]\n", " hist3 = ax.hist(modmid[:, 2], bins=list(bins), color=\"C0\",\n", " alpha=0.9, weights=w_mid)[-1]\n", " hist4 = ax.hist(modhigh[:, 2], bins=list(bins), color=\"C0\",\n", " alpha=0.9, weights=w_high)[-1]\n", " ax.set_xlabel(\"$\\\\log M_\\\\ast$\", fontsize=14)\n", " text1 = ax.text(\n", " 0.98, 0.98, \"Training data\", color=\"red\", va=\"top\", ha=\"right\",\n", " fontsize=14, transform=ax.transAxes)\n", " text2 = ax.text(\n", " 0.98, 0.91, txt, color=\"blue\", va=\"top\", ha=\"right\",\n", " fontsize=14, transform=ax.transAxes)\n", "\n", " ax = fig.add_subplot(222) if len(fig.axes) < 4 else fig.axes[1]\n", " hex1 = ax.hexbin(*modlow[:, :2].T, mincnt=1,\n", " C=w_low, reduce_C_function=np.sum,\n", " norm=plt.matplotlib.colors.LogNorm())\n", " if prev_layers is None:\n", " sns.kdeplot(\n", " {\"$r - z$\": training_x_weighted[is_lowmass][:, 0],\n", " \"$g - r$\": training_x_weighted[is_lowmass][:, 1]},\n", " weights=training_w[is_lowmass],\n", " x=\"$r - z$\", y=\"$g - r$\", color=\"red\", levels=7, ax=ax)\n", " ax.set_xlabel(\"$r - z$\", fontsize=14)\n", " ax.set_ylabel(\"$g - r$\", fontsize=14)\n", " text3 = ax.text(\n", " 0.02, 0.02, f\"${lowmass_cut[0]} < \\\\log M_\\\\ast < {lowmass_cut[1]}$\",\n", " fontsize=14, 
transform=ax.transAxes)\n", "\n", " ax = fig.add_subplot(223, sharex=ax, sharey=ax) if len(\n", " fig.axes) < 4 else fig.axes[2]\n", " hex2 = ax.hexbin(*modmid[:, :2].T, mincnt=1,\n", " C=w_mid, reduce_C_function=np.sum,\n", " norm=plt.matplotlib.colors.LogNorm())\n", " if prev_layers is None:\n", " sns.kdeplot(\n", " {\"$r - z$\": training_x_weighted[is_midmass][:, 0],\n", " \"$g - r$\": training_x_weighted[is_midmass][:, 1]},\n", " weights=training_w[is_midmass],\n", " x=\"$r - z$\", y=\"$g - r$\", color=\"red\", levels=7, ax=ax)\n", " ax.set_xlabel(\"$r - z$\", fontsize=14)\n", " ax.set_ylabel(\"$g - r$\", fontsize=14)\n", " text4 = ax.text(\n", " 0.02, 0.02, f\"${midmass_cut[0]} < \\\\log M_\\\\ast < {midmass_cut[1]}$\",\n", " fontsize=14, transform=ax.transAxes)\n", "\n", " ax = fig.add_subplot(224, sharex=ax, sharey=ax) if len(\n", " fig.axes) < 4 else fig.axes[3]\n", " hex3 = ax.hexbin(*modhigh[:, :2].T, mincnt=1,\n", " C=w_high, reduce_C_function=np.sum,\n", " norm=plt.matplotlib.colors.LogNorm())\n", " if prev_layers is None:\n", " sns.kdeplot(\n", " {\"$r - z$\": training_x_weighted[is_highmass][:, 0],\n", " \"$g - r$\": training_x_weighted[is_highmass][:, 1]},\n", " weights=training_w[is_highmass],\n", " x=\"$r - z$\", y=\"$g - r$\", color=\"red\", levels=7, ax=ax)\n", " ax.set_xlabel(\"$r - z$\", fontsize=14)\n", " ax.set_ylabel(\"$g - r$\", fontsize=14)\n", " text5 = ax.text(\n", " 0.02, 0.02, f\"${highmass_cut[0]} < \\\\log M_\\\\ast < {highmass_cut[1]}$\",\n", " fontsize=14, transform=ax.transAxes)\n", " ax.set_xlim(-4, 7)\n", " ax.set_ylim(-4, 7)\n", " return [hex1, hex2, hex3, hist1, hist2, hist3, hist4,\n", " text1, text2, text3, text4, text5]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "make_sumstat_plot(truth, txt=\"Truth\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define loss function comparing ${\\rm PDF}(g-r, r-z | M_\\ast)$ *and* its Fourier pair" ] 
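}, { "cell_type": "markdown", "metadata": {}, "source": [ "A sketch of what the loss function below computes, in our own notation (the symbols are ours, not `kdescent`'s). For each stellar mass bin $b \\in \\{{\\rm low}, {\\rm mid}, {\\rm high}\\}$, model galaxies are assigned to the bin with a differentiable soft top-hat instead of a hard cut:\n", "\n", "$$\n", "S(x; \\ell, h) = \\sigma\\!\\left(\\frac{x - \\ell}{\\Delta}\\right) \\, \\sigma\\!\\left(\\frac{h - x}{\\Delta}\\right), \\qquad \\Delta = \\frac{h - \\ell}{25},\n", "$$\n", "\n", "where $\\sigma$ is the logistic sigmoid and 25 is the default `squish` of `soft_tophat`. Each model weight is multiplied by $S(\\log M_\\star; \\ell_b, h_b)$. The kernel-weighted counts $N_{b,k}$ from `compare_kde_counts` and the Fourier counts $F_{b,j}$ from `compare_fourier_counts` are then divided by the total weight $W_b$ in the bin, for both the model and the training data. Adding the number density $n_{\\rm high} = W_{\\rm high} / V$ (with $V = 100$), the loss is\n", "\n", "$$\n", "\\mathcal{L} = \\frac{1}{N_{\\rm terms}} \\left[ \\sum_{b,k} \\left( \\frac{N^{\\rm mod}_{b,k}}{W^{\\rm mod}_b} - \\frac{N^{\\rm dat}_{b,k}}{W^{\\rm dat}_b} \\right)^2 + \\sum_{b,j} \\left| \\frac{F^{\\rm mod}_{b,j}}{W^{\\rm mod}_b} - \\frac{F^{\\rm dat}_{b,j}}{W^{\\rm dat}_b} \\right|^2 + \\left( n^{\\rm mod}_{\\rm high} - n^{\\rm dat}_{\\rm high} \\right)^2 \\right],\n", "$$\n", "\n", "where $N_{\\rm terms}$ is the total number of residuals. The modulus on the Fourier term keeps the loss real, and the code also adds $10^{-10}$ to each $W_b$ to guard against division by zero." ] 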
}, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ktrain_lowmass = kdescent.KPretrainer.from_training_data(\n", " training_x[is_noweight_lowmass, :2],\n", " bandwidth_factor=0.3, fourier_range_factor=3.0,\n", ")\n", "ktrain_midmass = kdescent.KPretrainer.from_training_data(\n", " training_x[is_noweight_midmass, :2],\n", " bandwidth_factor=0.3, fourier_range_factor=3.0,\n", ")\n", "ktrain_highmass = kdescent.KPretrainer.from_training_data(\n", " training_x[is_noweight_highmass, :2],\n", " bandwidth_factor=0.3, fourier_range_factor=3.0,\n", ")\n", "kcalc_lowmass = kdescent.KCalc(ktrain_lowmass)\n", "kcalc_midmass = kdescent.KCalc(ktrain_midmass)\n", "kcalc_highmass = kdescent.KCalc(ktrain_highmass)\n", "\n", "\n", "# Differentiable alternative to hard binning in the loss function:\n", "@jax.jit\n", "def soft_tophat(x, low, high, squish=25.0):\n", " \"\"\"Approximately return 1 when `low < x < high`, else return 0\"\"\"\n", " width = (high - low) / squish\n", " left = jax.nn.sigmoid((x - low) / width)\n", " right = jax.nn.sigmoid((high - x) / width)\n", " return left * right\n", "\n", "\n", "@jax.jit\n", "def lossfunc(params, randkey, logmh_table=None, hmf_upweights=1.0):\n", " if logmh_table is None:\n", " logmh_table = true_logmh_table\n", " key1, *keys = jax.random.split(randkey, 7)\n", " model_x, model_w = generate_model(params, key1, logmh_table, hmf_upweights)\n", " weight_low = soft_tophat(model_x[:, 2], *lowmass_cut) * model_w\n", " weight_mid = soft_tophat(model_x[:, 2], *midmass_cut) * model_w\n", " weight_high = soft_tophat(model_x[:, 2], *highmass_cut) * model_w\n", "\n", " model_low_counts, truth_low_counts = kcalc_lowmass.compare_kde_counts(\n", " keys[0], model_x[:, :2], weight_low)\n", " model_mid_counts, truth_mid_counts = kcalc_midmass.compare_kde_counts(\n", " keys[1], model_x[:, :2], weight_mid)\n", " model_high_counts, truth_high_counts = kcalc_highmass.compare_kde_counts(\n", " keys[2], model_x[:, :2], 
weight_high)\n", "\n", " model_low_fcounts, truth_low_fcounts = kcalc_lowmass.compare_fourier_counts(\n", " keys[3], model_x[:, :2], weight_low)\n", " model_mid_fcounts, truth_mid_fcounts = kcalc_midmass.compare_fourier_counts(\n", " keys[4], model_x[:, :2], weight_mid)\n", " model_high_fcounts, truth_high_fcounts = kcalc_highmass.compare_fourier_counts(\n", " keys[5], model_x[:, :2], weight_high)\n", "\n", " # Convert counts to conditional prob: P(krnl | M*) = N(krnl & M*) / N(M*)\n", " model_low_condprob = model_low_counts / (weight_low.sum() + 1e-10)\n", " model_mid_condprob = model_mid_counts / (weight_mid.sum() + 1e-10)\n", " model_high_condprob = model_high_counts / (weight_high.sum() + 1e-10)\n", " truth_low_condprob = truth_low_counts / (training_w_lowmass.sum() + 1e-10)\n", " truth_mid_condprob = truth_mid_counts / (training_w_midmass.sum() + 1e-10)\n", " truth_high_condprob = truth_high_counts / (\n", " training_w_highmass.sum() + 1e-10)\n", " # Convert Fourier counts to \"conditional\" ECF analogously\n", " model_low_ecf = model_low_fcounts / (weight_low.sum() + 1e-10)\n", " model_mid_ecf = model_mid_fcounts / (weight_mid.sum() + 1e-10)\n", " model_high_ecf = model_high_fcounts / (weight_high.sum() + 1e-10)\n", " truth_low_ecf = truth_low_fcounts / (training_w_lowmass.sum() + 1e-10)\n", " truth_mid_ecf = truth_mid_fcounts / (training_w_midmass.sum() + 1e-10)\n", " truth_high_ecf = truth_high_fcounts / (training_w_highmass.sum() + 1e-10)\n", "\n", " # One constraint on number density at the highest stellar mass bin\n", " volume = 100.0\n", " model_massfunc = jnp.array([weight_high.sum(),]) / volume\n", " truth_massfunc = jnp.array([training_w_highmass.sum(),]) / volume\n", "\n", " # Must abs() the fourier-difference so the loss is real\n", " sqerrs = jnp.concatenate([(model_low_condprob - truth_low_condprob)**2,\n", " (model_mid_condprob - truth_mid_condprob)**2,\n", " (model_high_condprob - truth_high_condprob)**2,\n", " jnp.abs(model_low_ecf - 
truth_low_ecf)**2,\n", " jnp.abs(model_mid_ecf - truth_mid_ecf)**2,\n", " jnp.abs(model_high_ecf - truth_high_ecf)**2,\n", " (model_massfunc - truth_massfunc)**2,\n", " ])\n", "\n", " return jnp.mean(sqerrs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Descend the gradient *without* upweighting" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "adam_params, adam_losses = kdescent.adam(\n", " lossfunc, guess, nsteps=600, param_bounds=bounds,\n", " learning_rate=0.03, randkey=13)\n", "print(\"Best fit params =\", adam_params[-1])\n", "\n", "plt.semilogy(adam_losses)\n", "plt.xlabel(\"Adam step\", fontsize=14)\n", "plt.ylabel(\"Loss\", fontsize=14)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig = plt.figure(figsize=(20, 9), layout=\"constrained\")\n", "fig.set_facecolor(\"0.05\")\n", "figs = fig.subfigures(1, 2, wspace=0.004)\n", "figs[0].set_facecolor(\"white\")\n", "figs[1].set_facecolor(\"white\")\n", "make_sumstat_plot(\n", " adam_params[0], txt=\"Initial guess\", fig=figs[0])\n", "make_sumstat_plot(\n", " adam_params[-1],\n", " txt=f\"Solution after {len(adam_params)-1} evaluations\", fig=figs[1])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Descend the gradient *with* upweighting" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Define a much smaller grid of halo masses, evenly distributed so we aren't\n", "# wasting all our computation on the many low-mass halos\n", "evenly_spaced_logmh_table = np.linspace(11, 14, num=500)\n", "\n", "# Define HMF upweighting using relative histogram counts - alternatively,\n", "# you could simply use an idealized functional form for the HMF\n", "hmf_bins = np.linspace(10.999, 14.001, num=50)\n", "evenly_spaced_hist_counts = np.histogram(\n", " evenly_spaced_logmh_table, hmf_bins)[0]\n", "true_hist_counts = 
np.histogram(\n", " true_logmh_table, hmf_bins)[0]\n", "binned_hmf_upweights = true_hist_counts / evenly_spaced_hist_counts\n", "\n", "# Assign HMF upweights based on the bin each \"halo\" falls in\n", "bin_inds = np.digitize(evenly_spaced_logmh_table, hmf_bins) - 1\n", "hmf_upweights = binned_hmf_upweights[bin_inds]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Plot HMF upweighting vs. halo mass\n", "plt.plot(evenly_spaced_logmh_table, hmf_upweights, label=\"HMF Upweight factor\")\n", "plt.hist(evenly_spaced_logmh_table, bins=hmf_bins, label=\"Halo mass histogram\")\n", "plt.xlabel(\"$\\\\log M_h$\", fontsize=14)\n", "plt.legend(frameon=False)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Specify new halo table and HMF upweights in our new loss function\n", "upweighted_lossfunc = functools.partial(\n", " lossfunc, logmh_table=evenly_spaced_logmh_table,\n", " hmf_upweights=hmf_upweights)\n", "\n", "# Run gradient descent just like before (BUT ~5x FASTER!)\n", "upweighted_adam_params, upweighted_adam_losses = kdescent.adam(\n", " upweighted_lossfunc, guess, nsteps=600, param_bounds=bounds,\n", " learning_rate=0.03, randkey=13)\n", "print(\"Best fit params =\", upweighted_adam_params[-1])\n", "\n", "plt.semilogy(upweighted_adam_losses)\n", "plt.xlabel(\"Adam step\", fontsize=14)\n", "plt.ylabel(\"Loss\", fontsize=14)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig = plt.figure(figsize=(20, 9), layout=\"constrained\")\n", "fig.set_facecolor(\"0.05\")\n", "figs = fig.subfigures(1, 2, wspace=0.004)\n", "figs[0].set_facecolor(\"white\")\n", "figs[1].set_facecolor(\"white\")\n", "make_sumstat_plot(\n", " upweighted_adam_params[0], txt=\"Initial guess\", fig=figs[0])\n", "make_sumstat_plot(\n", " upweighted_adam_params[-1],\n", " txt=f\"Solution after {len(upweighted_adam_params)-1} evaluations\", 
fig=figs[1])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Closing Remarks\n", "\n", "Neither of the fits shown here is perfect (nor fully converged, for that matter), but both qualitatively reproduce distributions that closely resemble the training data by eye. The power of upweighting is that we can reduce computation by thinning out over-represented regions of feature space, such as low halo mass bins, and compensating with weights. Here, HMF upweighting let us go from 10,000 halos down to only 500. This 20x reduction in data led to roughly a 5x reduction in compute time and memory, with very similar results!" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 2 }