update from the main fork #1

Merged
merged 1 commit on Jun 29, 2022
Changes from all commits
297 changes: 268 additions & 29 deletions probml_utils/dp_mixgauss_utils.py
@@ -1,12 +1,8 @@
import jax.numpy as jnp

from jax import jit, vmap, random, lax
from collections import namedtuple
from functools import partial
from jax.scipy.special import gammaln
from jax.numpy.linalg import slogdet, solve


# The implementation of Normal Inverse Wishart distribution is directly copied from
# the note book of Scott Linderman:
# 'Implementing a Normal Inverse Wishart Distribution in Tensorflow Probability'
@@ -120,11 +116,8 @@ def from_parameters(cls, params, **kwargs):
def sufficient_statistics(datapoint):
return (1.0, datapoint, jnp.outer(datapoint, datapoint), 1.0)


##############################################################################


@partial(jit, static_argnums=(1,))
def dp_mixgauss_sample(key, num_of_samples, dp_concentration, dp_base_measure):
"""Sampling from the Dirichlet process (DP) Gaussian mixture model.

@@ -139,31 +132,277 @@ def dp_mixgauss_sample(key, num_of_samples, dp_concentration, dp_base_measure):
base measure of the Dirichlet process

Returns:
dict: means and covariance matrices for the Gaussian distribution of each cluster
array(num_of_samples): cluster index of each datum
array(num_of_samples, dimension): samples from the DP mixture model
"""
def sample_cluster_index(carry, _):
key, cluster_sizes, num_of_clusters = carry
key, subkey = random.split(key)
logits = jnp.log(cluster_sizes.at[num_of_clusters].set(dp_concentration))
new_cluster_index = random.categorical(subkey, logits)
cluster_sizes = cluster_sizes.at[new_cluster_index].add(1)
num_of_clusters = lax.cond(new_cluster_index==num_of_clusters,
lambda x: x + 1,
lambda x: x,
num_of_clusters)
return (key, cluster_sizes, num_of_clusters), new_cluster_index
carry = (key, jnp.full(num_of_samples, 0.0), 0)  # float sizes so dp_concentration is not truncated when set as a weight
carry, cluster_indices = lax.scan(sample_cluster_index, carry, None, length=num_of_samples)
key, _, num_of_clusters = carry
# Generating distribution parameters for each observation index from the base measure
# (parameters drawn for indices that never appear as a cluster label are simply unused)
key, subkey = random.split(key)
cluster_parameters = dp_base_measure.sample(seed=subkey, sample_shape=(num_of_samples,))
# Sampling
subkeys = random.split(key, num_of_samples)
cluster_means = cluster_parameters['mu'][cluster_indices]
cluster_covs = cluster_parameters['Sigma'][cluster_indices]
samples = vmap(random.multivariate_normal)(subkeys, cluster_means, cluster_covs)
return cluster_parameters, cluster_indices, samples
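The scan above is a Chinese-restaurant-process draw of the cluster labels. The following standalone sketch is not part of the original file: it repeats that draw for a hypothetical concentration value so the growth in the number of occupied clusters can be checked in isolation (the helper name crp_labels_sketch and all constants are illustrative).

# Hedged sketch: a stand-alone CRP label draw mirroring the scan in dp_mixgauss_sample.
def crp_labels_sketch(key, num_of_samples, concentration):
    def step(carry, _):
        key, cluster_sizes, num_of_clusters = carry
        key, subkey = random.split(key)
        # occupied clusters keep weight n_k; one extra slot gets weight = concentration
        logits = jnp.log(cluster_sizes.at[num_of_clusters].set(concentration))
        new_cluster_index = random.categorical(subkey, logits)
        cluster_sizes = cluster_sizes.at[new_cluster_index].add(1)
        num_of_clusters = lax.cond(new_cluster_index == num_of_clusters,
                                   lambda x: x + 1, lambda x: x, num_of_clusters)
        return (key, cluster_sizes, num_of_clusters), new_cluster_index
    carry = (key, jnp.full(num_of_samples, 0.0), 0)
    (_, _, num_of_clusters), labels = lax.scan(step, carry, None, length=num_of_samples)
    return labels, num_of_clusters

# labels, k = crp_labels_sketch(random.PRNGKey(0), 500, 2.0)  # k grows roughly like 2*log(500)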


@jit
def log_pdf_multi_student(datum, nu, mu, sigma):
"""Log probability density function (pdf) of the multivariate T distribution

https://en.wikipedia.org/wiki/Multivariate_t-distribution

Args:
datum (array(dim)): data point at which to evaluate the log probability density
nu (int): degree of freedom of the multivariate T distribution
mu (array(dim)): location parameter of the multivariate T distribution
sigma (array(dim,dim)): positive-definite real scale matrix of the multivariate T distribution

Returns:
float: log probability density of the multivariate T distribution at datum
"""
dim = mu.shape[0]
# logarithm of the normalizing constant
log_norm = gammaln((nu + dim)/2.0) - (gammaln(nu/2.0)
+ dim/2.0*(jnp.log(nu) + jnp.log(jnp.pi))
+ 0.5*slogdet(sigma)[1])
# logarithm of the unnormalized pdf
log_like = -(nu + dim)/2.0 * jnp.log(1 + 1/nu*(datum-mu).dot(solve(sigma, datum-mu)))
return log_norm + log_like
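As a quick consistency check (not in the original file): in one dimension this density reduces to a univariate Student-t with scale sqrt(sigma), so it can be compared against jax.scipy.stats.t.logpdf; the numerical values below are arbitrary.

# Hedged sanity check of log_pdf_multi_student against the univariate Student-t.
from jax.scipy.stats import t as student_t

x_check, nu_check, loc_check, scale_check = 0.7, 5.0, 0.2, 1.3
lp_multi = log_pdf_multi_student(jnp.array([x_check]), nu_check,
                                 jnp.array([loc_check]), jnp.array([[scale_check**2]]))
lp_uni = student_t.logpdf(x_check, nu_check, loc=loc_check, scale=scale_check)
assert jnp.allclose(lp_multi, lp_uni, atol=1e-4)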


@jit
def log_pdf_posterior_predictive_mvn(datum, sufficient_stat, num_of_observ, prior_params):
"""Log pdf of the posterior predictive multivariate T distribution

The distribution of data given the parameter is Gaussian,
and the prior distribution of parameters is the normal inverse Wishart (NIW).

Args:
datum (array(dim)): data point at which to evaluate the log density
sufficient_stat (dict): sufficient statistics ('mean' and 'cov') of the observations to be conditioned on
num_of_observ (int): number of observations being conditioned on
prior_params (dict): parameters of the NIW prior

Returns:
float: log pdf of the posterior predictive distribution
"""
# Parameters of the prior normal inverse Wishart distribution
mean_prior = prior_params['loc']
kappa_prior = prior_params['mean_precision']
df_prior = prior_params['df']
cov_prior = prior_params['scale']
# Sufficient statistics of the Gaussian likelihood
mean_data = sufficient_stat['mean']
cov_data = sufficient_stat['cov']
# Computing the posterior parameters
nu_pos = df_prior + num_of_observ
kappa_pos = kappa_prior + num_of_observ
mu_pos = kappa_prior/kappa_pos*mean_prior + num_of_observ/kappa_pos*mean_data
sub = mean_data-mean_prior
lambda_pos = cov_prior + cov_data + kappa_prior*num_of_observ/kappa_pos*jnp.outer(sub, sub)
dim = len(mean_data)
# Computing parameters of the posterior predictive distribution
nu_predict = nu_pos - dim + 1
mu_predict = mu_pos
sigma_predict = lambda_pos*(kappa_pos + 1)/(kappa_pos*nu_predict)
return log_pdf_multi_student(datum, nu_predict, mu_predict, sigma_predict)
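For orientation, an illustrative call (not in the original file): prior_params uses the key names read above ('loc', 'mean_precision', 'df', 'scale'), and the sufficient-statistics dict uses 'mean' and 'cov' as produced by cluster_sufficient_stats below; with zero observations this simply evaluates the prior predictive density.

# Hedged example: prior predictive log density of a 2-D NIW prior evaluated at the origin.
dim_example = 2
prior_example = {'loc': jnp.zeros(dim_example), 'mean_precision': 1.0,
                 'df': dim_example + 2.0, 'scale': jnp.eye(dim_example)}
empty_stat = {'mean': jnp.zeros(dim_example), 'cov': jnp.zeros((dim_example, dim_example))}
lp_prior_pred = log_pdf_posterior_predictive_mvn(jnp.zeros(dim_example), empty_stat, 0, prior_example)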


def gmm_gibbs(key, num_of_samples, data, precision, num_of_clusters, prior_params):
"""Gibbs sampling of the cluster assignment using Gaussian finite mixture model

The prior of the parameters of Gaussian likelihood is normal inverse Wishart (NIW)

Args:
key (jax.random.PRNGKey): random key for the Gibbs sampler
num_of_samples (int): number of samples (Gibbs iterations)
data (num_of_data, dimension): array of observations
precision (float): precision of a symmetric Dirichlet distribution,
which is the prior of the mixture weights
num_of_clusters (int): number of components in the mixture distribution
prior_params (dict): parameters of the NIW prior

Returns:
array(num_of_samples, num_of_data): samples of cluster assignment
"""
num_of_data = data.shape[0]
@jit
def cluster_sufficient_stats(cluster_members):
# Computing the sufficient statistics of one cluster,
# 'cluster_members' is a boolean array indicating
# whether a datum is a member of the cluster
_mean = jnp.mean(data*jnp.atleast_2d(cluster_members).T, axis=0)
_sub = data*jnp.atleast_2d(cluster_members).T - _mean
_cov = _sub.T @ _sub
return {'mean':_mean, 'cov':_cov}
# Kernel for updating the cluster assignment of each single datum
def cluster_assign_per_datum(carry, datum_index):
key, cluster_assign, cluster_sizes, sufficient_stats = carry
cluster_index = cluster_assign[datum_index]
cluster_members = cluster_assign==cluster_index
# Removing the current datum from its current cluster
# before updating the sufficient statistics of this cluster
cluster_members = cluster_members.at[datum_index].set(False)
cluster_sizes = cluster_sizes.at[cluster_index].add(-1)
# Updating the sufficient statistics of the current cluster of the datum
stat = cluster_sufficient_stats(cluster_members)
sufficient_stats['mean'] = sufficient_stats['mean'].at[cluster_index].set(stat['mean'])
sufficient_stats['cov'] = sufficient_stats['cov'].at[cluster_index].set(stat['cov'])
# Assigning the data point to its new cluster
log_likes_per_cluster = vmap(log_pdf_posterior_predictive_mvn,
in_axes=(None, {'mean': 0, 'cov': 0}, 0, None))(
data[datum_index], sufficient_stats, cluster_sizes, prior_params)
logits = log_likes_per_cluster + jnp.log(cluster_sizes + precision/num_of_clusters)
key, subkey = random.split(key)
new_cluster_index = random.categorical(subkey, logits)
cluster_assign = cluster_assign.at[datum_index].set(new_cluster_index)
cluster_sizes = cluster_sizes.at[new_cluster_index].add(1)
# Updating the sufficient statistics for the new cluster
cluster_members = cluster_assign==new_cluster_index
new_stat = cluster_sufficient_stats(cluster_members)
sufficient_stats['mean'] = sufficient_stats['mean'].at[new_cluster_index].set(new_stat['mean'])
sufficient_stats['cov'] = sufficient_stats['cov'].at[new_cluster_index].set(new_stat['cov'])
return (key, cluster_assign, cluster_sizes, sufficient_stats), None
# Kernel for each gibbs iteration
def update_per_itr(carry, key):
# Shuffling the order of the dataset
shuffled_indices = random.permutation(key, jnp.arange(num_of_data))
carry, _ = lax.scan(cluster_assign_per_datum,
carry,
shuffled_indices)
return carry, carry[1]
# Initialization by assigning data using prior distribution
key, *subkey = random.split(key, 3)
cluster_weights = random.dirichlet(subkey[0], precision/num_of_clusters*jnp.ones(num_of_clusters))
cluster_assign = random.categorical(subkey[1], jnp.log(cluster_weights), shape=(num_of_data,))
cluster_sizes = vmap(lambda x: jnp.sum(cluster_assign==x))(jnp.arange(num_of_clusters))
sufficient_stats = vmap(lambda x: cluster_sufficient_stats(cluster_assign==x))(
jnp.arange(num_of_clusters))
carry = key, cluster_assign, cluster_sizes, sufficient_stats
# Sampling
subkeys = random.split(key, num_of_samples)
carry, samples_of_cluster_assign = lax.scan(update_per_itr, carry, subkeys)
return samples_of_cluster_assign
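A usage sketch (not in the original file), assuming two well-separated Gaussian blobs and a weak NIW prior; the data, prior values, and number of sweeps are illustrative only.

# Hedged usage sketch for gmm_gibbs on toy data.
example_key = random.PRNGKey(0)
example_key, k1, k2, k3 = random.split(example_key, 4)
toy_data = jnp.concatenate([random.normal(k1, (50, 2)) + 5.0,
                            random.normal(k2, (50, 2)) - 5.0])
toy_prior = {'loc': jnp.zeros(2), 'mean_precision': 1.0, 'df': 4.0, 'scale': jnp.eye(2)}
assign_samples = gmm_gibbs(k3, 20, toy_data, precision=1.0,
                           num_of_clusters=2, prior_params=toy_prior)
# assign_samples has shape (20, 100); the last row is the final cluster assignment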


def dp_mixgauss_gibbs(key, num_of_samples, data, concentration, prior_params):
"""Gibbs sampling of the cluster assignment using Dirichlet process (DP) Gaussian mixture model

This is also known as collapsed Gibbs sampling of the DP mixture model.
The prior of the parameters of the Gaussian likelihood is normal inverse Wishart (NIW).

Args:
key (jax.random.PRNGKey): random key for the Gibbs sampler
num_of_samples (int): number of samples (Gibbs iterations)
data (num_of_data, dimension): array of observations
concentration (float): concentration parameter of the DP
prior_params (dict): parameters of the NIW prior

Returns:
array(num_of_samples, num_of_data): samples of cluster assignment
"""
num_of_data, dim = data.shape
@jit
def cluster_sufficient_stats(cluster_members):
_mean = jnp.mean(data*jnp.atleast_2d(cluster_members).T, axis=0)
_sub = data*jnp.atleast_2d(cluster_members).T - _mean
_cov = _sub.T @ _sub
return {'mean':_mean, 'cov':_cov}
# Kernel for updating the cluster assignment for each single datum
def cluster_assign_per_datum(carry, datum_index):
key, cluster_assign, cluster_sizes, sufficient_stats = carry
cluster_index = cluster_assign[datum_index]
cluster_members = cluster_assign==cluster_index
# Removing the current datum from its current cluster
# before updating the sufficient statistics of this cluster
cluster_members = cluster_members.at[datum_index].set(False)
cluster_sizes = cluster_sizes.at[cluster_index].add(-1)
# Updating the sufficient statistics of the current cluster of the datum
stat = cluster_sufficient_stats(cluster_members)
sufficient_stats['mean'] = sufficient_stats['mean'].at[cluster_index].set(stat['mean'])
sufficient_stats['cov'] = sufficient_stats['cov'].at[cluster_index].set(stat['cov'])
# Updating the weights of each cluster
log_likes_per_cluster = vmap(_log_pdf_of_nonempty_cluster,
in_axes=(None, {'mean': 0, 'cov': 0}, 0)
)(data[datum_index], sufficient_stats, cluster_sizes)
logits = log_likes_per_cluster + jnp.log(cluster_sizes)
# Adding (temporarily) the next cluster that could be introduced,
# setting the cluster index to be the index of the first empty cluster
log_like_next_cluster = log_pdf_posterior_predictive_mvn(data[datum_index],
{'mean': jnp.zeros(dim),
'cov': jnp.zeros((dim, dim))},
0,
prior_params)
next_cluster = jnp.asarray(cluster_sizes==0).nonzero(size=1)[0][0]
logits = logits.at[next_cluster].set(log_like_next_cluster + jnp.log(concentration))
# Sampling the cluster index of the datum
key, subkey = random.split(key)
new_cluster_index = random.categorical(subkey, logits)
cluster_assign = cluster_assign.at[datum_index].set(new_cluster_index)
cluster_sizes = cluster_sizes.at[new_cluster_index].add(1)
# Updating the sufficient statistics for the new cluster
new_cluster_members = cluster_assign==new_cluster_index
new_stat = cluster_sufficient_stats(new_cluster_members)
sufficient_stats['mean'] = sufficient_stats['mean'].at[new_cluster_index].set(new_stat['mean'])
sufficient_stats['cov'] = sufficient_stats['cov'].at[new_cluster_index].set(new_stat['cov'])
return (key, cluster_assign, cluster_sizes, sufficient_stats), None
@jit
def _log_pdf_of_nonempty_cluster(datum, suff_stat, cluster_size):
# Return -inf if the cluster is empty,
# otherwise run the log_pdf_posterior_predictive_mvn
return lax.cond(cluster_size>0,
lambda _: log_pdf_posterior_predictive_mvn(datum,
suff_stat,
cluster_size,
prior_params),
lambda _: -jnp.inf,
None)
# Kernel for each gibbs iteration
def update_per_itr(carry, key):
shuffled_indices = random.permutation(key, jnp.arange(num_of_data))
carry, _ = lax.scan(cluster_assign_per_datum,
carry,
shuffled_indices)
return carry, carry[1]
# Initialization using the prior Chinese restaurant process
def chinese_restaurant_process(carry, key):
cluster_sizes, num_of_clusters = carry
logits = jnp.log(cluster_sizes.at[num_of_clusters].set(concentration))
new_cluster_index = random.categorical(key, logits)
cluster_sizes = cluster_sizes.at[new_cluster_index].add(1)
num_of_clusters = lax.cond(new_cluster_index==num_of_clusters,
lambda x: x + 1,
lambda x: x,
num_of_clusters)
return (cluster_sizes, num_of_clusters), new_cluster_index
key, subkey = random.split(key)
carry_crp = (jnp.full(num_of_data, 0.0), 0)  # float sizes so the concentration weight is not truncated
carry_crp, cluster_assign = lax.scan(chinese_restaurant_process,
carry_crp,
random.split(subkey, num_of_data))
cluster_sizes, num_of_clusters = carry_crp
sufficient_stats = {'mean':jnp.zeros((num_of_data, dim)),
'cov':jnp.zeros((num_of_data, dim, dim))}
stats = vmap(lambda x: cluster_sufficient_stats(cluster_assign==x))(jnp.arange(num_of_clusters))
sufficient_stats['mean'] = sufficient_stats['mean'].at[jnp.arange(num_of_clusters)].set(stats['mean'])
sufficient_stats['cov'] = sufficient_stats['cov'].at[jnp.arange(num_of_clusters)].set(stats['cov'])
carry = key, cluster_assign, cluster_sizes, sufficient_stats
subkeys = random.split(key, num_of_samples)
# Sampling
carry, samples_of_cluster_assign = lax.scan(update_per_itr, carry, subkeys)
return samples_of_cluster_assign
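The toy data and NIW prior from the gmm_gibbs sketch above can be reused to exercise the DP version (again illustrative only); the number of distinct labels in the last sweep gives a rough read on the inferred number of clusters.

# Hedged usage sketch for dp_mixgauss_gibbs, reusing toy_data and toy_prior from above.
dp_assign_samples = dp_mixgauss_gibbs(random.PRNGKey(1), 20, toy_data,
                                      concentration=1.0, prior_params=toy_prior)
num_occupied = len(jnp.unique(dp_assign_samples[-1]))  # distinct labels in the final sweep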