Source code for diffopt.kdescent.kstats

import dataclasses
from functools import partial
from typing import Any, Literal, Tuple, overload

import jax.numpy as jnp
import jax.random
import numpy as np


@dataclasses.dataclass
class KPretrainer:
    """
    Stores precomputed kernel and Fourier counts for training data,
    with kernel centers sampled from the training data PDF (via
    gaussian_kde). Provides save/load functionality.
    """
    # Candidate kernel centers and Fourier (k-space) evaluation positions
    kernel_centers: np.ndarray
    fourier_positions: np.ndarray
    # Covariance matrix shared by every KDE kernel
    kernelcov: np.ndarray
    # Training-data counts (and their uncertainties) at each candidate
    kde_counts: np.ndarray
    kde_err: np.ndarray
    fourier_counts: np.ndarray
    fourier_err: np.ndarray
    # Configuration the pretrainer was built with
    num_eval_kernels: int
    num_eval_fourier_positions: int
    num_pretrain_kernels: int
    num_pretrain_fourier_positions: int
    bandwidth_factor: float
    fourier_range_factor: float
    covariant_kernels: bool
    inverse_density_weight_power: float
    training_sum_of_weights: float
    seed: int

    @classmethod
    def from_training_data(
        cls,
        training_x,
        training_weights=None,
        num_eval_kernels=None,
        num_eval_fourier_positions=None,
        num_pretrain_kernels=None,
        num_pretrain_fourier_positions=None,
        bandwidth_factor=1.0,
        fourier_range_factor=1.0,
        covariant_kernels=False,
        inverse_density_weight_power=0.0,
        num_idw_draws=None,
        chunk_size=None,
        seed=0,
        comm=None,
    ):
        """
        Create a pre-trained KPretrainer object from training data.

        Parameters
        ----------
        training_x : array-like
            Training data of shape (n_data, n_features)
        training_weights : array-like, optional
            Training weights of shape (n_data,), by default None
        num_eval_kernels : int, optional
            Number of KDE kernels to approximate the PDF, by default 10*ndim
        num_eval_fourier_positions : int, optional
            Number of points to evaluate the ECF, by default 10*ndim
        num_pretrain_kernels : int, optional
            Number of KDE kernels to precompute training data PDF,
            by default 300*num_eval_kernels
        num_pretrain_fourier_positions : int, optional
            Number of points to precompute training data ECF,
            by default 300*num_eval_fourier_positions
        bandwidth_factor : float, optional
            Increase or decrease the kernel bandwidth, by default 1.0
        fourier_range_factor : float, optional
            Increase or decrease the Fourier search space, by default 1.0
        covariant_kernels : bool, optional
            If True, kernels will align with the principal components of
            the training data, which can blow up kernel count values if
            cov matrix has near-zero eigenvalues. By default False
        inverse_density_weight_power : float, optional
            At 1.0, this will weight the kernel selection by the inverse
            density of the training data. This is useful for selecting
            kernels in low-density regions. No selection weighting by
            default
        num_idw_draws : int, optional
            Number of KDE draws + evaluations in total for the importance
            resampling to determine kernel selection with inverse density
            weighting. By default 100*num_pretrain_kernels
        chunk_size : int, optional
            Chunk size for pre-computation of training KDE counts, to
            prevent memory overflow. If None, chunk_size will default to
            `max(num_eval_kernels, num_eval_fourier_positions, 1)`
        seed : int, optional
            Random seed for reproducibility, by default 0
        comm : MPI Communicator, optional
            Distribute pre-computation of training kernel counts across
            ranks, assuming full training data is loaded and identical
            across ranks.

        Returns
        -------
        KPretrainer
            Object holding the candidate kernel centers / Fourier
            positions and the training counts precomputed at each of them.
        """
        seed = int(seed)
        randkeys = jax.random.split(jax.random.key(seed + 987), 3)
        # Promote 1-D input to a single-feature column: (n,) -> (n, 1)
        training_x = jnp.atleast_2d(jnp.asarray(training_x).T).T
        assert training_x.ndim == 2, "x must have shape (ndata, ndim)"
        ndim = training_x.shape[1]
        if training_weights is not None:
            # Accept any array-like (e.g. lists) so `.sum()` below works
            training_weights = jnp.asarray(training_weights)

        # By default, use 10 * ndim evaluation kernels and fourier positions
        if num_eval_kernels is None:
            num_eval_kernels = 10 * ndim
        if num_eval_fourier_positions is None:
            num_eval_fourier_positions = 10 * ndim
        num_eval_kernels = int(num_eval_kernels)
        num_eval_fourier_positions = int(num_eval_fourier_positions)

        # By default, pretrain on 300 * number of evaluation kernels
        if num_pretrain_kernels is None:
            num_pretrain_kernels = 300 * num_eval_kernels
        if num_pretrain_fourier_positions is None:
            num_pretrain_fourier_positions = 300 * num_eval_fourier_positions
        if chunk_size is None:
            chunk_size = max(
                num_eval_kernels, num_eval_fourier_positions, 1)
        num_pretrain_kernels = int(num_pretrain_kernels)
        num_pretrain_fourier_positions = int(num_pretrain_fourier_positions)
        chunk_size = int(chunk_size)

        # Bandwidth and kernel covariance
        bandwidth = _set_bandwidth(
            num_eval_kernels, ndim, bandwidth_factor)
        kernelcov = _bandwidth_to_kernelcov(
            training_x, bandwidth, training_weights, covariant_kernels
        )
        k_max = fourier_range_factor / training_x.std(ddof=1, axis=0)

        # KDE for sampling kernel centers
        kde = jax.scipy.stats.gaussian_kde(
            training_x.T, weights=training_weights)

        if inverse_density_weight_power > 0:
            # Importance resampling for inverse density weighting
            if num_idw_draws is None:
                num_idw_draws = 100 * num_pretrain_kernels
            num_idw_draws = int(num_idw_draws)
            idw_chunk_size = chunk_size * 100
            num_chunks = num_idw_draws // idw_chunk_size + (
                num_idw_draws % idw_chunk_size > 0)

            # Might as well distribute chunks across MPI ranks
            if comm is not None:
                chunk_inds = np.array_split(
                    np.arange(num_chunks), comm.size)[comm.rank]
            else:
                chunk_inds = range(num_chunks)

            draw_keys = jax.random.split(randkeys[2], num_chunks)
            draw_raw_samples = jax.jit(
                lambda key: kde.resample(key, (idw_chunk_size,)))
            compute_pdf_vals = jax.jit(lambda pts: kde.pdf(pts))

            # Keep EVERY chunk of samples next to its PDF values: the
            # indices chosen below address the full set of draws, not
            # just the most recent chunk.
            sample_chunks = []
            pdf_chunks = []
            for i in chunk_inds:
                chunk_samples = draw_raw_samples(draw_keys[i])
                sample_chunks.append(np.asarray(chunk_samples))
                pdf_chunks.append(
                    np.asarray(compute_pdf_vals(chunk_samples)))

            if comm is not None:
                # array_split hands out contiguous chunk ranges, so
                # flattening in rank order restores global chunk order.
                # Gathering lists also tolerates ranks with no chunks.
                sample_chunks = [
                    chunk for rank_chunks in comm.allgather(sample_chunks)
                    for chunk in rank_chunks]
                pdf_chunks = [
                    chunk for rank_chunks in comm.allgather(pdf_chunks)
                    for chunk in rank_chunks]

            # Samples are (ndim, n_draws); trim the overshoot of the
            # final chunk so samples and weights stay aligned.
            raw_samples = jnp.concatenate(
                sample_chunks, axis=1)[:, :num_idw_draws]
            idw = jnp.concatenate(pdf_chunks)[:num_idw_draws] ** (
                -inverse_density_weight_power)

            # Choose kernel centers with importance weights
            chosen_idx = jax.random.choice(
                randkeys[0], num_idw_draws, (num_pretrain_kernels,),
                p=idw, replace=False)
            kernel_centers = jnp.asarray(raw_samples[:, chosen_idx].T)
        else:
            kernel_centers = jax.jit(
                lambda key: kde.resample(key, (num_pretrain_kernels,)))(
                randkeys[2]).T

        # Sample fourier positions uniformly in k-space
        fourier_positions = jax.random.uniform(
            randkeys[1], (num_pretrain_fourier_positions, ndim)
        ) * k_max[None, :]

        if comm is not None:
            # Distribute kernel centers and fourier positions across ranks
            kernel_centers = np.array_split(
                kernel_centers, comm.size)[comm.rank]
            fourier_positions = np.array_split(
                fourier_positions, comm.size)[comm.rank]

        # Precompute KDE and Fourier counts for training data, one chunk
        # of kernels at a time to bound memory use. Each helper call
        # returns (counts, err); concatenating along axis=1 stacks the
        # chunks while keeping counts and errors on separate rows.
        chunk_inds = list(range(
            chunk_size, len(kernel_centers), chunk_size))
        kde_counts, kde_err = np.concatenate([
            _predict_kde_counts(
                training_x, training_weights, cens, kernelcov,
                return_err=True)
            for cens in np.array_split(kernel_centers, chunk_inds)
        ], axis=1)

        chunk_inds = list(range(
            chunk_size, len(fourier_positions), chunk_size))
        fourier_counts, fourier_err = np.concatenate([
            _predict_fourier(
                training_x, training_weights, pos, return_err=True)
            for pos in np.array_split(fourier_positions, chunk_inds)
        ], axis=1)

        kernel_centers = np.asarray(kernel_centers)
        fourier_positions = np.asarray(fourier_positions)
        kde_counts = np.asarray(kde_counts)
        kde_err = np.asarray(kde_err)
        fourier_counts = np.asarray(fourier_counts)
        fourier_err = np.asarray(fourier_err)

        if comm is not None:
            # Gather all precomputed counts across all ranks
            kde_counts = np.concatenate(comm.allgather(kde_counts))
            kde_err = np.concatenate(comm.allgather(kde_err))
            fourier_counts = np.concatenate(comm.allgather(fourier_counts))
            fourier_err = np.concatenate(comm.allgather(fourier_err))
            kernel_centers = np.concatenate(comm.allgather(kernel_centers))
            fourier_positions = np.concatenate(
                comm.allgather(fourier_positions))

        if training_weights is None:
            training_sum_of_weights = float(len(training_x))
        else:
            training_sum_of_weights = float(training_weights.sum())

        return cls(
            kernel_centers=kernel_centers,
            fourier_positions=fourier_positions,
            kernelcov=kernelcov,
            kde_counts=kde_counts,
            kde_err=kde_err,
            fourier_counts=fourier_counts,
            fourier_err=fourier_err,
            num_eval_kernels=num_eval_kernels,
            num_eval_fourier_positions=num_eval_fourier_positions,
            num_pretrain_kernels=num_pretrain_kernels,
            num_pretrain_fourier_positions=num_pretrain_fourier_positions,
            bandwidth_factor=bandwidth_factor,
            fourier_range_factor=fourier_range_factor,
            covariant_kernels=covariant_kernels,
            inverse_density_weight_power=inverse_density_weight_power,
            training_sum_of_weights=training_sum_of_weights,
            seed=seed,
        )

    def save(self, filename):
        """Save the pre-trained object to disk as a .npz numpy zip file"""
        data = {field.name: getattr(self, field.name)
                for field in dataclasses.fields(self)}
        np.savez(filename, **data)

    @classmethod
    def load(cls, filename):
        """
        Load a pre-trained object from disk .npz file

        Fields declared as ``int``, ``float`` or ``bool`` are converted
        back to Python scalars, since ``np.savez`` stores them as 0-d
        arrays.
        """
        with np.load(filename, allow_pickle=False) as data:
            kwargs = {key: data[key] for key in data.files}

        scalar_casts = {int: int, float: float, bool: bool,
                        "int": int, "float": float, "bool": bool}
        for field in dataclasses.fields(cls):
            cast = scalar_casts.get(field.type)
            if cast is not None and field.name in kwargs:
                kwargs[field.name] = cast(kwargs[field.name])
        return cls(**kwargs)
class KCalc:
    """
    Compares randomized evaluations of the training data PDF and ECF,
    precomputed by a KPretrainer, against the same statistics measured
    on model data.
    """

    def __init__(self, pretrainer):
        """
        This KDE object is the fundamental building block of kdescent. It
        can be used to compare randomized evaluations of the PDF and ECF
        by training data to model predictions.

        Parameters
        ----------
        pretrainer : KPretrainer
            A pre-trained KPretrainer object that precomputes possible
            kernel centers and their associated training data counts.

        Raises
        ------
        TypeError
            If `pretrainer` is not a KPretrainer instance.
        """
        if not isinstance(pretrainer, KPretrainer):
            raise TypeError("pretrainer must be an instance of KPretrainer")

        # Array-valued fields are copied onto the JAX side once, up front
        self.kernel_centers = jnp.array(pretrainer.kernel_centers)
        self.fourier_positions = jnp.array(pretrainer.fourier_positions)
        self.kernelcov = jnp.array(pretrainer.kernelcov)
        self.kde_counts = jnp.array(pretrainer.kde_counts)
        self.kde_err = jnp.array(pretrainer.kde_err)
        self.fourier_counts = jnp.array(pretrainer.fourier_counts)
        self.fourier_err = jnp.array(pretrainer.fourier_err)

        # Scalars are coerced to builtin types (a loaded pretrainer may
        # carry them as 0-d arrays)
        self.num_eval_kernels = int(pretrainer.num_eval_kernels)
        self.num_eval_fourier_positions = int(
            pretrainer.num_eval_fourier_positions)
        self.num_pretrain_kernels = int(pretrainer.num_pretrain_kernels)
        self.num_pretrain_fourier_positions = int(
            pretrainer.num_pretrain_fourier_positions)
        self.bandwidth_factor = float(pretrainer.bandwidth_factor)
        self.fourier_range_factor = float(pretrainer.fourier_range_factor)
        self.covariant_kernels = bool(pretrainer.covariant_kernels)
        self.inverse_density_weight_power = float(
            pretrainer.inverse_density_weight_power)
        self.training_sum_of_weights = float(
            pretrainer.training_sum_of_weights)

    def reduced_chisq_loss(self, randkey, x, weights=None, density=False):
        """
        Mean squared error-normalized residual between model and training
        data, combining KDE counts and Fourier counts.

        Parameters
        ----------
        randkey : jax random key
            Split in two: one half selects the KDE kernels, the other
            selects the Fourier positions
        x : array-like
            Model data, passed to `compare_kde_counts` and
            `compare_fourier_counts`
        weights : array-like, optional
            Model weights, passed through to both comparisons
        density : bool, optional
            If True, rescale the model counts by
            `training_sum_of_weights / (sum of model weights)`, removing
            dependence on the overall normalization. By default False

        Returns
        -------
        loss : jnp.ndarray
            Mean of the squared residuals `(model - truth) / err`, taken
            over the KDE counts and the real and imaginary parts of the
            Fourier counts
        """
        kde_key, fourier_key = jax.random.split(randkey, 2)
        model_k, truth_k, err_k = self.compare_kde_counts(
            kde_key, x, weights=weights, return_err=True)
        model_f, truth_f, err_f = self.compare_fourier_counts(
            fourier_key, x, weights=weights, return_err=True)

        # Remove dependence of overall normalization if density=True
        if density:
            model_n = len(x) if weights is None else weights.sum()
            scale = self.training_sum_of_weights / model_n
            model_k = model_k * scale
            model_f = model_f * scale

        normalized_residuals = jnp.concatenate([
            (model_k - truth_k) / err_k,
            (model_f.real - truth_f.real) / err_f.real,
            (model_f.imag - truth_f.imag) / err_f.imag,
        ])
        return jnp.mean(normalized_residuals**2)

    # Specify signatures to make linters happy. Each pair of overloads
    # directly precedes its implementation, as type checkers require.
    @overload
    def compare_kde_counts(
        self, randkey: Any, x: Any, weights: Any = None,
        return_err: Literal[False] = False, comm: Any = None
    ) -> Tuple[jnp.ndarray, jnp.ndarray]: ...

    @overload
    def compare_kde_counts(
        self, randkey: Any, x: Any, weights: Any = None,
        return_err: Literal[True] = True, comm: Any = None
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: ...

    def compare_kde_counts(self, randkey, x, weights=None,
                           return_err=False, comm=None):
        """
        Realize kernel centers and return all kernel-weighted counts

        Parameters
        ----------
        randkey : jax random key
            Key used to select which pretrained kernels are evaluated
        x : array-like
            Model data of shape (n_model_data, n_features)
        weights : array-like, optional
            Effective counts with shape (n_model_data,). If supplied,
            function will return sum(weights * kernel_weights) within
            each kernel instead of simply sum(kernel_weights)
        return_err: bool
            If true, also return the uncertainty of all training KDE
            counts values according to the effective sample size (ESS)
            in each kernel
        comm : MPI Communicator, optional
            For parallel computing, this guarantees consistent kernel
            placements by all MPI ranks within the comm, by default None.
            WARNING: Do not pass in an MPI communicator here if you plan
            on JIT compiling; just pass identical randkeys for each MPI
            rank

        Returns
        -------
        prediction : jnp.ndarray
            KDE counts measured on `x`. Has shape (num_kernels,)
        truth : jnp.ndarray
            KDE counts measured on `training_x`. This is always different
            due to the random kernel placements. Has shape (num_kernels,)
        err : jnp.ndarray
            Returned if return_err=True, uncertainties of each KDE count
            in `truth` equal to truth/sqrt(ESS)
        """
        kernel_inds = self.realize_kernel_inds(randkey, comm)
        prediction = _predict_kde_counts(
            x, weights, self.kernel_centers[kernel_inds], self.kernelcov,
            return_err=False)
        truth = self.kde_counts[kernel_inds]
        if not return_err:
            return prediction, truth
        return prediction, truth, self.kde_err[kernel_inds]

    @overload
    def compare_fourier_counts(
        self, randkey: Any, x: Any, weights: Any = None,
        return_err: Literal[False] = False, comm: Any = None
    ) -> Tuple[jnp.ndarray, jnp.ndarray]: ...

    @overload
    def compare_fourier_counts(
        self, randkey: Any, x: Any, weights: Any = None,
        return_err: Literal[True] = True, comm: Any = None
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: ...

    def compare_fourier_counts(self, randkey, x, weights=None,
                               return_err=False, comm=None):
        """
        Return randomly-placed evaluations of the ECF (Empirical
        Characteristic Function = Fourier-transformed PDF)

        Parameters
        ----------
        randkey : jax random key
            Key used to select which pretrained Fourier positions are
            evaluated
        x : array-like
            Model data of shape (n_model_data, n_features)
        weights : array-like, optional
            Effective counts with shape (n_model_data,). If supplied,
            the ECF will be weighted as sum(weights * exp^(...)) at each
            evaluation in k-space instead of simply sum(exp^(...))
        return_err: bool
            If true, also return the uncertainty of all training Fourier
            counts values according to the effective sample size (ESS)
            in each kernel
        comm : MPI Communicator, optional
            For parallel computing, this guarantees consistent kernel
            placements by all MPI ranks within the comm, by default None.
            WARNING: Do not pass in an MPI communicator here if you plan
            on JIT compiling; just pass identical randkeys for each MPI
            rank

        Returns
        -------
        prediction : jnp.ndarray (complex-valued)
            CF evaluations measured on `x`. Has shape (num_kernels,)
        truth : jnp.ndarray (complex-valued)
            CF evaluations measured on `training_x`. This is always
            different due to the random evaluation kernels. Has shape
            (num_kernels,)
        err : jnp.ndarray
            Returned if return_err=True, uncertainties of each Fourier
            count in `truth` equal to truth/sqrt(ESS)
        """
        fourier_inds = self.realize_fourier_inds(randkey, comm)
        prediction = _predict_fourier(
            x, weights, self.fourier_positions[fourier_inds],
            return_err=False)
        truth = self.fourier_counts[fourier_inds]
        if not return_err:
            return prediction, truth
        return prediction, truth, self.fourier_err[fourier_inds]

    def realize_kernel_inds(self, randkey, comm=None):
        """
        Draw `num_eval_kernels` indices into the pretrained kernel
        centers. With `comm`, rank 0 draws and broadcasts its indices.
        """
        return self._realize_inds(
            self.num_eval_kernels, self.num_pretrain_kernels,
            randkey, comm)

    def realize_fourier_inds(self, randkey, comm=None):
        """
        Draw `num_eval_fourier_positions` indices into the pretrained
        Fourier positions. With `comm`, rank 0 draws and broadcasts its
        indices.
        """
        return self._realize_inds(
            self.num_eval_fourier_positions,
            self.num_pretrain_fourier_positions, randkey, comm)

    @staticmethod
    def _realize_inds(num_samples, num_choices, randkey, comm=None):
        """Sample indices locally, or on rank 0 followed by a broadcast."""
        if comm is None:
            return _sample_kernel_inds(num_samples, num_choices, randkey)
        inds = []
        if comm.rank == 0:
            # Only the root samples; other ranks contribute a placeholder
            inds = _sample_kernel_inds(num_samples, num_choices, randkey)
        return comm.bcast(inds, root=0)
@jax.jit
def _set_bandwidth(n, d, bandwidth_factor):
    """Return `n ** (-1 / (d + 4)) * bandwidth_factor`."""
    return n ** (-1.0 / (d + 4)) * bandwidth_factor


@partial(jax.jit, static_argnums=[3])
def _bandwidth_to_kernelcov(training_x, bandwidth, weights=None,
                            covariant_kernels=True):
    """
    Return the kernel covariance: the (optionally weighted) empirical
    covariance of `training_x` scaled by `bandwidth**2`.

    If `covariant_kernels` is False, off-diagonal terms are zeroed. The
    result is always 2-D, including for single-feature data.
    """
    # jnp.cov squeezes a 1x1 result down to 0-d, which jnp.diag rejects;
    # restore the matrix shape so single-feature data is supported.
    empirical_cov = jnp.atleast_2d(
        jnp.cov(training_x, rowvar=False, aweights=weights))
    if not covariant_kernels:
        empirical_cov = jnp.diag(jnp.diag(empirical_cov))
    return empirical_cov * bandwidth**2


@partial(jax.jit, static_argnums=[0, 1])
def _sample_kernel_inds(num_samples, num_kernels, randkey):
    """Draw `num_samples` indices uniformly from `range(num_kernels)`."""
    inds = jax.random.choice(
        randkey, num_kernels, (num_samples,), p=None)
    return inds


@jax.jit
def _weights_in_kernel(x, kernel_cen, cov):
    """Evaluate one Gaussian kernel (mean `kernel_cen`) at every `x`."""
    return jax.scipy.stats.multivariate_normal.pdf(
        x, mean=kernel_cen, cov=cov)


# Vectorized over kernel centers: one row of weights per kernel
_vmap_weights_in_kernel = jax.jit(jax.vmap(
    _weights_in_kernel, in_axes=(None, 0, None)))


@jax.jit
def _get_kernel_probs(x, kernel_cens, cov):
    """Return kernel weights of every `x`, one row per kernel center."""
    ind_weights = _vmap_weights_in_kernel(x, kernel_cens, cov)
    return jnp.asarray(ind_weights)


@jax.jit
def _get_fourier_exponentials(x, fourier_positions):
    """Return `exp(i k.x)`, one row per Fourier position `k`."""
    return jnp.exp(
        1j * jnp.sum(fourier_positions[:, None, :] * x[None, :, :], axis=-1))


@jax.jit
def _weighted_sum_over_samples(kernel_probs, x_weights):
    """Sum over axis 1, weighting each column by `x_weights` if given."""
    if x_weights is None:
        return jnp.sum(kernel_probs, axis=1)
    else:
        return jnp.sum(x_weights[None, :] * kernel_probs, axis=1)


@partial(jax.jit, static_argnames=["return_err"])
def _predict_kde_counts(x, x_weights, kernel_cens, cov, return_err=False):
    """
    Return the (weighted) kernel counts of `x` in each kernel.

    If `return_err` is True, also return `counts / sqrt(ESS)`, where the
    effective sample size is `counts**2 / sum(weights**2 * probs**2)`.
    """
    kernel_probs = _get_kernel_probs(x, kernel_cens, cov)
    kde_counts = _weighted_sum_over_samples(kernel_probs, x_weights)
    if return_err:
        x_weights_squared = None
        if x_weights is not None:
            x_weights_squared = x_weights ** 2
        ess = kde_counts ** 2 / _weighted_sum_over_samples(
            kernel_probs ** 2, x_weights_squared)
        err = kde_counts / jnp.sqrt(ess)
        return kde_counts, err
    else:
        return kde_counts


@partial(jax.jit, static_argnames=["return_err"])
def _predict_fourier(x, x_weights, fourier_positions, return_err=False):
    """
    Return the (weighted) sum of `exp(i k.x)` at each Fourier position.

    If `return_err` is True, also return a complex array whose real and
    imaginary parts hold `|counts| / sqrt(ESS)` computed separately for
    the real and imaginary components.
    """
    exponentials = _get_fourier_exponentials(x, fourier_positions)
    fourier_counts = _weighted_sum_over_samples(
        exponentials, x_weights)
    if return_err:
        x_weights_squared = None
        if x_weights is not None:
            x_weights_squared = x_weights ** 2
        ess_real = fourier_counts.real**2 / _weighted_sum_over_samples(
            exponentials.real**2, x_weights_squared)
        ess_imag = fourier_counts.imag**2 / _weighted_sum_over_samples(
            exponentials.imag**2, x_weights_squared)
        err_real = jnp.abs(fourier_counts.real) / jnp.sqrt(ess_real)
        err_imag = jnp.abs(fourier_counts.imag) / jnp.sqrt(ess_imag)
        return fourier_counts, err_real + 1j * err_imag
    else:
        return fourier_counts