{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Quickstart Tutorial" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`kdescent` provides a general framework for comparing an N-dimensional distribution of a model population to that of a training dataset. In short, it draws a mini-batched sample from the training data, computes a kernel density estimate (KDE) of the training distribution, and computes counts weighted by each mini-batched kernel (similar to computing the number density within a randomly drawn bin). These weighted counts can be directly compared to those of the model to compute a loss or likelihood, and even to perform gradient descent using JAX's autodiff functionality. To improve the power of gradient descent even further, `kdescent` also provides an analogous metric for comparison of weighted counts in Fourier space. Combining both the KDE and Fourier metrics into a loss term for stochastic gradient descent has proven to be a very powerful method of parameter optimization." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import functools\n", "import numpy as np\n", "import jax.numpy as jnp\n", "import jax.random\n", "import matplotlib.pyplot as plt\n", "\n", "from diffopt import kdescent" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Example model\n", "\n", "Our example model is a population with variables $x_1$ and $x_2$ distributed as a 2D multivariate normal. We will parameterize this distribution with 5 parameters: [$\\mu_1$, $\\mu_2$, $\\sigma_1^2$, $\\sigma_2^2$, $c_{12}$], where $\\mu_i$ is the mean of each variable, $\\sigma_i^2$ is the variance of each variable, and $c_{12}$ is the correlation coefficient between the two variables, which sets the off-diagonal element of the covariance matrix, $c_{12} \\sigma_1 \\sigma_2$." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data_nsample = 10_000\n", "model_nsample = 5_000\n", "\n", "# Generate data from a 2D multivariate normal distribution given a\n", "# 5-param model [mean1, mean2, sigma1**2, sigma2**2, correlation_coef]\n", "@functools.partial(jax.jit, static_argnames=[\"nsample\"])\n", "def generate_data(params, randkey, nsample=model_nsample):\n", " mean = params[:2]\n", " cov11, cov22 = jnp.abs(params[2:4])\n", " cov12 = params[4] * jnp.sqrt(cov11 * cov22)\n", " cov = jnp.array([[cov11, cov12],\n", " [cov12, cov22]])\n", " return jax.random.multivariate_normal(\n", " randkey, mean, cov, shape=(nsample,))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define \"true\" parameters to generate training data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "truth_params = jnp.array([1.6, 2.9, 0.8, 1.25, -0.2])\n", "truth_randkey = jax.random.key(42)\n", "\n", "training_x = generate_data(truth_params, truth_randkey, nsample=data_nsample)\n", "\n", "plt.hexbin(*training_x.T, mincnt=1, norm=plt.matplotlib.colors.LogNorm(),\n", " linewidth=0.3)\n", "plt.text(0.95, 0.95, \"Training data\", fontsize=14,\n", " transform=plt.gca().transAxes, ha=\"right\", va=\"top\")\n", "plt.xlabel(\"$x_1$\", fontsize=14)\n", "plt.ylabel(\"$x_2$\", fontsize=14)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define loss function comparing ${\\rm PDF}(x_1, x_2)$\n", "\n", "- Characterize the loss as the difference between our training and model distributions\n", "- We will evaluate these distributions around randomized kernel centers using the `compare_kde_counts` method (20 kernels by default)\n", "- To perform this gradient descent stochastically, the random seed for drawing kernel centers must be updated at each step. 
To do this, we will have our loss function accept the `randkey` argument, which it will split and pass to all sources of stochasticity" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ktrain = kdescent.KPretrainer.from_training_data(\n", " training_x, num_eval_kernels=20)\n", "kde = kdescent.KCalc(ktrain)\n", "\n", "\n", "def lossfunc(params, randkey):\n", " # Split random key for (1) multivariate draws and (2) kernel mini-batching\n", " key1, key2 = jax.random.split(randkey)\n", " model_x = generate_data(params, randkey=key1)\n", " model_kde_counts, truth_kde_counts = kde.compare_kde_counts(key2, model_x)\n", " \n", " # Must divide by total number in sample since the training dataset\n", " # is not the same size as the population generated by the model\n", " model_kde_density = model_kde_counts / model_nsample\n", " truth_kde_density = truth_kde_counts / data_nsample\n", "\n", " # Return the mean-squared error of our metrics\n", " return jnp.mean((model_kde_density - truth_kde_density)**2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Optionally, skip the pretraining next time by writing to disk" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ktrain.save(\"pretrainer.npz\")\n", "ktrain = kdescent.KPretrainer.load(\"pretrainer.npz\")\n", "kde = kdescent.KCalc(ktrain)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run gradient descent" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Define initial guess and bounds for our parameters\n", "guess = jnp.array([0., 0., 1., 1., 0.])\n", "bounds = jnp.array([[-jnp.inf, jnp.inf], [-jnp.inf, jnp.inf],\n", " [0.001, jnp.inf], [0.001, jnp.inf], [-0.999, 0.999]])\n", "\n", "# Run gradient descent to approximately recover the truth\n", "adam_params, adam_losses = kdescent.adam(\n", " lossfunc, guess, nsteps=100, param_bounds=bounds,\n", " 
learning_rate=1.0, randkey=12345)\n", "print(\"Final params =\", adam_params[-1])\n", "print(\"True params =\", truth_params)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig, axes = plt.subplots(ncols=2, figsize=(8, 3.5))\n", "\n", "axes[0].hexbin(*training_x.T, mincnt=1, gridsize=100)\n", "axes[0].scatter(*generate_data(guess, truth_randkey).T,\n", " s=1, alpha=0.3, color=\"C1\")\n", "axes[0].text(0.02, 0.02, \"Initial guess\", fontsize=13,\n", " color=\"C1\", transform=axes[0].transAxes)\n", "axes[0].set_xlim(-2.5, 4.5)\n", "axes[0].set_ylim(-3, 7)\n", "\n", "axes[1].hexbin(*training_x.T, mincnt=1, gridsize=100)\n", "axes[1].scatter(*generate_data(adam_params[-1], truth_randkey).T,\n", " s=1, alpha=0.3, color=\"C1\")\n", "axes[1].text(0.02, 0.02, \"Solution\", fontsize=13,\n", " color=\"C1\", transform=axes[1].transAxes)\n", "axes[1].set_xlim(-2.5, 4.5)\n", "axes[1].set_ylim(-3, 7)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Advanced Usage\n", "\n", "## More complex example model\n", "\n", "- 20-parameter model that generates a non-trivial bimodal 3-dimensional distribution (variables: $\\log M_\\star, g-r, r-z$)\n", " - To help our gradient descent navigate such a tricky parameter space, we will introduce Fourier-space terms into our loss\n", "- To add even more complexity: the training dataset is undersampled at $\\log M_\\star < 10.5$\n", " - We must therefore rely on *conditional* probability distributions $P(g-r, r-z | \\log M_\\star)$, with a separate `KCalc` object handling each bin of our conditional variable, $\\log M_\\star$" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import seaborn as sns" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model_nsample = 20_000\n", "data_nsample = 10_000 # same volume, but undersampled at logM* < 10.5\n", "\n", "# 
Generate data weighted from two mass-dependent multivariate normals\n", "@functools.partial(jax.jit, static_argnames=[\"undersample\", \"nsample\"])\n", "def generate_model(params, randkey, undersample=False, nsample=model_nsample):\n", " # Parse all 20 parameters\n", " # =======================\n", " # Lower and upper bounds on log stellar mass\n", " logmlim = params[:2]\n", " logmlim = logmlim.at[1].add(logmlim[0] + 0.001)\n", "\n", " # Distribution parameters at lower mass bound\n", " mean_mmin = params[2:4]\n", " sigma11, sigma22 = params[4:6]\n", " maxcov = jnp.sqrt(sigma11 * sigma22)\n", " sigma12 = params[6] * maxcov\n", " cov_mmin = jnp.array([[sigma11, sigma12],\n", " [sigma12, sigma22]])\n", " qfrac_mmin = params[7]\n", " qmean_mmin = mean_mmin + params[8:10]\n", " qscale_mmin = params[10]\n", "\n", " # Distribution parameters at upper mass bound\n", " mean_mmax = params[11:13]\n", " sigma11, sigma22 = params[13:15]\n", " maxcov = jnp.sqrt(sigma11 * sigma22)\n", " sigma12 = params[15] * maxcov\n", " cov_mmax = jnp.array([[sigma11, sigma12],\n", " [sigma12, sigma22]])\n", " qfrac_mmax = params[16]\n", " qmean_mmax = mean_mmax + params[17:19]\n", " qscale_mmax = params[19]\n", "\n", " # Generate distribution from parameters\n", " # =====================================\n", " key1, key2 = jax.random.split(randkey, num=2)\n", " triangle_vals = (0, 0.5, 1) if undersample else (0, 0, 1)\n", " logm = jax.random.triangular(key1, *triangle_vals, shape=(nsample,))\n", " logm = logmlim[0] + logm * (logmlim[1] - logmlim[0])\n", " # Calculate slope of mass dependence\n", " dlogm = logmlim[1] - logmlim[0]\n", " dmean = (mean_mmax - mean_mmin) / dlogm\n", " dcov = (cov_mmax - cov_mmin) / dlogm\n", " dqfrac = (qfrac_mmax - qfrac_mmin) / dlogm\n", " dqmean = (qmean_mmax - qmean_mmin) / dlogm\n", " dqscale = (qscale_mmax - qscale_mmin) / dlogm\n", " # Apply mass dependence\n", " mean_sf = mean_mmin + dmean * (logm[:, None] - logmlim[0])\n", " cov_sf = cov_mmin + dcov * 
(logm[:, None, None] - logmlim[0])\n", " mean_q = qmean_mmin + dqmean * (logm[:, None] - logmlim[0])\n", " qscale = qscale_mmin + dqscale * (logm - logmlim[0])\n", " cov_q = cov_sf * qscale[:, None, None] ** 2\n", " qfrac = qfrac_mmin + dqfrac * (logm - logmlim[0])\n", "\n", " # Generate colors from two separate multivariate normals\n", " rz_sf, gr_sf = jax.random.multivariate_normal(key2, mean_sf, cov_sf).T\n", " rz_q, gr_q = jax.random.multivariate_normal(key2, mean_q, cov_q).T\n", " # Concatenate the quenched + star-forming values and assign weights\n", " data_sf = jnp.array([rz_sf, gr_sf, logm]).T\n", " data_q = jnp.array([rz_q, gr_q, logm]).T\n", " data = jnp.concatenate([data_sf, data_q])\n", " weights = jnp.concatenate([1 - qfrac, qfrac])\n", " return data, weights" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define \"true\" parameters to generate training data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "truth_logmmin = 9.0\n", "truth_logmrange = 3.0\n", "\n", "truth_mean_mmin = jnp.array([1.4, 1.1])\n", "truth_var_mmin = jnp.array([0.7, 0.4])\n", "truth_corr_mmin = 0.3\n", "truth_qfrac_mmin = 0.2\n", "truth_qmean_mmin = jnp.array([-0.1, 1.6])\n", "truth_qscale_mmin = 0.3\n", "\n", "truth_mean_mmax = jnp.array([2.0, 1.6])\n", "truth_var_mmax = jnp.array([0.5, 0.5])\n", "truth_corr_mmax = 0.75\n", "truth_qfrac_mmax = 0.95\n", "truth_qmean_mmax = jnp.array([-0.6, 1.2])\n", "truth_qscale_mmax = 1.1\n", "\n", "bounds_truth_logmrange = [0.001, jnp.inf]\n", "bounds_var = ([0.001, jnp.inf], [0.001, jnp.inf])\n", "bounds_corr = [-0.999, 0.999]\n", "bounds_qfrac = [0.0, 1.0]\n", "bounds_qmean_gr = [0.001, jnp.inf]\n", "bounds_qscale = [0.001, jnp.inf]\n", "\n", "truth = jnp.array([\n", " truth_logmmin, truth_logmrange,\n", " *truth_mean_mmin, *truth_var_mmin, truth_corr_mmin, truth_qfrac_mmin,\n", " *truth_qmean_mmin, truth_qscale_mmin,\n", " *truth_mean_mmax, *truth_var_mmax, truth_corr_mmax, 
truth_qfrac_mmax,\n", " *truth_qmean_mmax, truth_qscale_mmax\n", "])\n", "guess = jnp.array([\n", " 9.25, 2.5, *[0.0, 0.0, 1.0, 1.0, 0.0, 0.5, 0.0, 1.0, 1.0]*2\n", "])\n", "bounds = [\n", " None, bounds_truth_logmrange,\n", " *[None, None, *bounds_var, bounds_corr, bounds_qfrac,\n", " None, bounds_qmean_gr, bounds_qscale]*2\n", "]\n", "\n", "# Generate training data from the truth parameters we just defined\n", "truth_randkey = jax.random.key(43)\n", "training_x_weighted, training_w = generate_model(\n", " truth, truth_randkey, undersample=True, nsample=data_nsample)\n", "\n", "# KDescent allows weighted training data, but to make this more realistic,\n", "# let's generate our actual training data by randomized weighted sampling\n", "training_selection = jax.random.uniform(\n", " jax.random.split(truth_randkey)[0], training_w.shape) < training_w\n", "training_x = training_x_weighted[training_selection]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define plotting function\n", "\n", "- Plot the mass distribution + the color-color distribution in three separate mass bins" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lowmass_cut = [9.0, 9.5]\n", "midmass_cut = [10.25, 10.75]\n", "highmass_cut = [11.5, 12.0]\n", "is_lowmass = ((lowmass_cut[0] < training_x_weighted[:, 2])\n", " & (training_x_weighted[:, 2] < lowmass_cut[1]))\n", "is_midmass = ((midmass_cut[0] < training_x_weighted[:, 2])\n", " & (training_x_weighted[:, 2] < midmass_cut[1]))\n", "is_highmass = ((highmass_cut[0] < training_x_weighted[:, 2])\n", " & (training_x_weighted[:, 2] < highmass_cut[1]))\n", "training_w_lowmass = training_w * is_lowmass\n", "training_w_midmass = training_w * is_midmass\n", "training_w_highmass = training_w * is_highmass\n", "is_noweight_lowmass = (\n", " (lowmass_cut[0] < training_x[:, 2])\n", " & (training_x[:, 2] < lowmass_cut[1]))\n", "is_noweight_midmass = (\n", " (midmass_cut[0] < training_x[:, 2])\n", " & 
(training_x[:, 2] < midmass_cut[1]))\n", "is_noweight_highmass = (\n", " (highmass_cut[0] < training_x[:, 2])\n", " & (training_x[:, 2] < highmass_cut[1]))\n", "\n", "\n", "def generate_model_into_mass_bins(params, randkey):\n", " key1 = jax.random.split(randkey)[0]\n", " model_x, model_w = generate_model(params, randkey=key1)\n", " is_low = ((lowmass_cut[0] < model_x[:, 2])\n", " & (model_x[:, 2] < lowmass_cut[1]))\n", " is_mid = ((midmass_cut[0] < model_x[:, 2])\n", " & (model_x[:, 2] < midmass_cut[1]))\n", " is_high = ((highmass_cut[0] < model_x[:, 2])\n", " & (model_x[:, 2] < highmass_cut[1]))\n", " return (model_x, model_x[is_low], model_x[is_mid], model_x[is_high],\n", " model_w, model_w[is_low], model_w[is_mid], model_w[is_high])\n", "\n", "\n", "def make_sumstat_plot(params, txt=\"\", fig=None, prev_layers=None):\n", " (modall, modlow, modmid, modhigh,\n", " w_all, w_low, w_mid, w_high) = generate_model_into_mass_bins(\n", " params, jax.random.key(13))\n", " if prev_layers is not None:\n", " for layer in prev_layers:\n", " layer.remove()\n", "\n", " fig = plt.figure(figsize=(10, 9)) if fig is None else fig\n", " ax = fig.add_subplot(221) if len(fig.axes) < 4 else fig.axes[0]\n", " ax.hist(training_x_weighted[:, 2], bins=50, color=\"red\",\n", " weights=training_w)\n", " _, bins, hist1 = ax.hist(\n", " modall[:, 2], color=\"grey\", bins=50, alpha=0.9, weights=w_all)\n", " hist2 = ax.hist(modlow[:, 2], bins=list(bins), color=\"C0\",\n", " alpha=0.9, weights=w_low)[-1]\n", " hist3 = ax.hist(modmid[:, 2], bins=list(bins), color=\"C0\",\n", " alpha=0.9, weights=w_mid)[-1]\n", " hist4 = ax.hist(modhigh[:, 2], bins=list(bins), color=\"C0\",\n", " alpha=0.9, weights=w_high)[-1]\n", " ax.set_xlabel(\"$\\\\log M_\\\\ast$\", fontsize=14)\n", " text1 = ax.text(\n", " 0.98, 0.98, \"Training data\", color=\"red\", va=\"top\", ha=\"right\",\n", " fontsize=14, transform=ax.transAxes)\n", " text2 = ax.text(\n", " 0.98, 0.91, txt, color=\"blue\", va=\"top\", 
ha=\"right\",\n", " fontsize=14, transform=ax.transAxes)\n", "\n", " ax = fig.add_subplot(222) if len(fig.axes) < 4 else fig.axes[1]\n", " hex1 = ax.hexbin(*modlow[:, :2].T, mincnt=1,\n", " C=w_low, reduce_C_function=np.sum,\n", " norm=plt.matplotlib.colors.LogNorm())\n", " if prev_layers is None:\n", " sns.kdeplot(\n", " {\"$r - z$\": training_x_weighted[is_lowmass][:, 0],\n", " \"$g - r$\": training_x_weighted[is_lowmass][:, 1]},\n", " weights=training_w[is_lowmass],\n", " x=\"$r - z$\", y=\"$g - r$\", color=\"red\", levels=7, ax=ax)\n", " ax.set_xlabel(\"$r - z$\", fontsize=14)\n", " ax.set_ylabel(\"$g - r$\", fontsize=14)\n", " text3 = ax.text(\n", " 0.02, 0.02, f\"${lowmass_cut[0]} < \\\\log M_\\\\ast < {lowmass_cut[1]}$\",\n", " fontsize=14, transform=ax.transAxes)\n", "\n", " ax = fig.add_subplot(223, sharex=ax, sharey=ax) if len(\n", " fig.axes) < 4 else fig.axes[2]\n", " hex2 = ax.hexbin(*modmid[:, :2].T, mincnt=1,\n", " C=w_mid, reduce_C_function=np.sum,\n", " norm=plt.matplotlib.colors.LogNorm())\n", " if prev_layers is None:\n", " sns.kdeplot(\n", " {\"$r - z$\": training_x_weighted[is_midmass][:, 0],\n", " \"$g - r$\": training_x_weighted[is_midmass][:, 1]},\n", " weights=training_w[is_midmass],\n", " x=\"$r - z$\", y=\"$g - r$\", color=\"red\", levels=7, ax=ax)\n", " ax.set_xlabel(\"$r - z$\", fontsize=14)\n", " ax.set_ylabel(\"$g - r$\", fontsize=14)\n", " text4 = ax.text(\n", " 0.02, 0.02, f\"${midmass_cut[0]} < \\\\log M_\\\\ast < {midmass_cut[1]}$\",\n", " fontsize=14, transform=ax.transAxes)\n", "\n", " ax = fig.add_subplot(224, sharex=ax, sharey=ax) if len(\n", " fig.axes) < 4 else fig.axes[3]\n", " hex3 = ax.hexbin(*modhigh[:, :2].T, mincnt=1,\n", " C=w_high, reduce_C_function=np.sum,\n", " norm=plt.matplotlib.colors.LogNorm())\n", " if prev_layers is None:\n", " sns.kdeplot(\n", " {\"$r - z$\": training_x_weighted[is_highmass][:, 0],\n", " \"$g - r$\": training_x_weighted[is_highmass][:, 1]},\n", " weights=training_w[is_highmass],\n", " x=\"$r 
- z$\", y=\"$g - r$\", color=\"red\", levels=7, ax=ax)\n", " ax.set_xlabel(\"$r - z$\", fontsize=14)\n", " ax.set_ylabel(\"$g - r$\", fontsize=14)\n", " text5 = ax.text(\n", " 0.02, 0.02, f\"${highmass_cut[0]} < \\\\log M_\\\\ast < {highmass_cut[1]}$\",\n", " fontsize=14, transform=ax.transAxes)\n", " ax.set_xlim(-4, 7)\n", " ax.set_ylim(-4, 7)\n", " return [hex1, hex2, hex3, hist1, hist2, hist3, hist4,\n", " text1, text2, text3, text4, text5]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "make_sumstat_plot(truth, txt=\"Truth\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define loss function comparing ${\\rm PDF}(g-r, r-z | M_\\ast)$ *and* its Fourier pair" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ktrain_lowmass = kdescent.KPretrainer.from_training_data(\n", " training_x[is_noweight_lowmass, :2],\n", " bandwidth_factor=0.3, fourier_range_factor=3.0,\n", ")\n", "ktrain_midmass = kdescent.KPretrainer.from_training_data(\n", " training_x[is_noweight_midmass, :2],\n", " bandwidth_factor=0.3, fourier_range_factor=3.0,\n", ")\n", "ktrain_highmass = kdescent.KPretrainer.from_training_data(\n", " training_x[is_noweight_highmass, :2],\n", " bandwidth_factor=0.3, fourier_range_factor=3.0,\n", ")\n", "kcalc_lowmass = kdescent.KCalc(ktrain_lowmass)\n", "kcalc_midmass = kdescent.KCalc(ktrain_midmass)\n", "kcalc_highmass = kdescent.KCalc(ktrain_highmass)\n", "\n", "\n", "# Differentiable alternative to hard binning in the loss function:\n", "@jax.jit\n", "def soft_tophat(x, low, high, squish=25.0):\n", " \"\"\"Approximately return 1 when `low < x < high`, else return 0\"\"\"\n", " width = (high - low) / squish\n", " left = jax.nn.sigmoid((x - low) / width)\n", " right = jax.nn.sigmoid((high - x) / width)\n", " return left * right\n", "\n", "@jax.jit\n", "def lossfunc(params, randkey):\n", " key1, *keys = jax.random.split(randkey, 7)\n", " 
model_x, model_w = generate_model(params, randkey=key1)\n", " weight_low = soft_tophat(model_x[:, 2], *lowmass_cut) * model_w\n", " weight_mid = soft_tophat(model_x[:, 2], *midmass_cut) * model_w\n", " weight_high = soft_tophat(model_x[:, 2], *highmass_cut) * model_w\n", "\n", " model_low_counts, truth_low_counts = kcalc_lowmass.compare_kde_counts(\n", " keys[0], model_x[:, :2], weight_low)\n", " model_mid_counts, truth_mid_counts = kcalc_midmass.compare_kde_counts(\n", " keys[1], model_x[:, :2], weight_mid)\n", " model_high_counts, truth_high_counts = kcalc_highmass.compare_kde_counts(\n", " keys[2], model_x[:, :2], weight_high)\n", "\n", " model_low_fcounts, truth_low_fcounts = kcalc_lowmass.compare_fourier_counts(\n", " keys[3], model_x[:, :2], weight_low)\n", " model_mid_fcounts, truth_mid_fcounts = kcalc_midmass.compare_fourier_counts(\n", " keys[4], model_x[:, :2], weight_mid)\n", " model_high_fcounts, truth_high_fcounts = kcalc_highmass.compare_fourier_counts(\n", " keys[5], model_x[:, :2], weight_high)\n", "\n", " # Convert counts to conditional prob: P(krnl | M*) = N(krnl & M*) / N(M*)\n", " model_low_condprob = model_low_counts / (weight_low.sum() + 1e-10)\n", " model_mid_condprob = model_mid_counts / (weight_mid.sum() + 1e-10)\n", " model_high_condprob = model_high_counts / (weight_high.sum() + 1e-10)\n", " truth_low_condprob = truth_low_counts / (training_w_lowmass.sum() + 1e-10)\n", " truth_mid_condprob = truth_mid_counts / (training_w_midmass.sum() + 1e-10)\n", " truth_high_condprob = truth_high_counts / (\n", " training_w_highmass.sum() + 1e-10)\n", " # Convert Fourier counts to \"conditional\" ECF analogously\n", " model_low_ecf = model_low_fcounts / (weight_low.sum() + 1e-10)\n", " model_mid_ecf = model_mid_fcounts / (weight_mid.sum() + 1e-10)\n", " model_high_ecf = model_high_fcounts / (weight_high.sum() + 1e-10)\n", " truth_low_ecf = truth_low_fcounts / (training_w_lowmass.sum() + 1e-10)\n", " truth_mid_ecf = truth_mid_fcounts / 
(training_w_midmass.sum() + 1e-10)\n", " truth_high_ecf = truth_high_fcounts / (training_w_highmass.sum() + 1e-10)\n", "\n", " # One constraint on number density at the highest stellar mass bin\n", " volume = 100.0\n", " model_massfunc = jnp.array([weight_high.sum(),]) / volume\n", " truth_massfunc = jnp.array([training_w_highmass.sum(),]) / volume\n", "\n", " # Must abs() the Fourier residuals so that the loss is real\n", " sqerrs = jnp.concatenate([\n", " (model_low_condprob - truth_low_condprob)**2,\n", " (model_mid_condprob - truth_mid_condprob)**2,\n", " (model_high_condprob - truth_high_condprob)**2,\n", " jnp.abs(model_low_ecf - truth_low_ecf)**2,\n", " jnp.abs(model_mid_ecf - truth_mid_ecf)**2,\n", " jnp.abs(model_high_ecf - truth_high_ecf)**2,\n", " (model_massfunc - truth_massfunc)**2,\n", " ])\n", "\n", " return jnp.mean(sqerrs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run gradient descent" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "adam_params, adam_losses = kdescent.adam(\n", " lossfunc, guess, nsteps=600, param_bounds=bounds,\n", " learning_rate=0.05, randkey=12345)\n", "print(\"Best fit params =\", adam_params[-1])\n", "\n", "plt.semilogy(adam_losses)\n", "plt.xlabel(\"Adam step\", fontsize=14)\n", "plt.ylabel(\"Loss\", fontsize=14)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig = plt.figure(figsize=(20, 9), layout=\"constrained\")\n", "fig.set_facecolor(\"0.05\")\n", "figs = fig.subfigures(1, 2, wspace=0.004)\n", "figs[0].set_facecolor(\"white\")\n", "figs[1].set_facecolor(\"white\")\n", "make_sumstat_plot(\n", " adam_params[0], txt=\"Initial guess\", fig=figs[0])\n", "make_sumstat_plot(\n", " adam_params[-1],\n", " txt=f\"Solution after {len(adam_params)-1} evaluations\", fig=figs[1])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is not a perfect fit (although it 
*would* continue to improve a little if we ran gradient descent even longer), but we succeeded in our goal of qualitatively recovering the target distribution! This shows how powerful `kdescent`'s Fourier counts can be as a complementary summary statistic to the PDF counts." ] } ], "metadata": { "kernelspec": { "display_name": "cuda312", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 2 }