diff options
author | Joshua V. Dillon <jvdillon@google.com> | 2017-01-10 16:36:33 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-01-10 16:45:51 -0800 |
commit | 973d5afdb68addd1315ceda1c536c88232699756 (patch) | |
tree | efa56f07f922f558845a37d836211cf32790e59a | |
parent | 56b74296829c30aa341a2d8ee5b4e2dbb48bc274 (diff) |
Implement KL-divergence between two Gamma distributions.
Change: 144144041
-rw-r--r-- | tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py | 33 | ||||
-rw-r--r-- | tensorflow/contrib/distributions/python/ops/gamma.py | 28 |
2 files changed, 61 insertions, 0 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py b/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py index 31027736fa..2195ed0749 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py @@ -18,11 +18,15 @@ from __future__ import division from __future__ import print_function import numpy as np +from scipy import special from scipy import stats + from tensorflow.contrib.distributions.python.ops import gamma as gamma_lib +from tensorflow.contrib.distributions.python.ops import kullback_leibler from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.platform import test @@ -317,6 +321,35 @@ class GammaTest(test.TestCase): self.assertAllEqual(nn_ops.softplus(alpha_v).eval(), gamma.alpha.eval()) self.assertAllEqual(nn_ops.softplus(beta_v).eval(), gamma.beta.eval()) + def testGammaGammaKL(self): + alpha0 = np.array([3.]) + beta0 = np.array([1., 2., 3., 1.5, 2.5, 3.5]) + + alpha1 = np.array([0.4]) + beta1 = np.array([0.5, 1., 1.5, 2., 2.5, 3.]) + + # Build graph. + with self.test_session() as sess: + g0 = gamma_lib.Gamma(alpha=alpha0, beta=beta0) + g1 = gamma_lib.Gamma(alpha=alpha1, beta=beta1) + x = g0.sample(int(1e4), seed=0) + kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0) + kl_actual = kullback_leibler.kl(g0, g1) + + # Execute graph. + [kl_sample_, kl_actual_] = sess.run([kl_sample, kl_actual]) + + kl_expected = ((alpha0 - alpha1) * special.digamma(alpha0) + + special.gammaln(alpha1) + - special.gammaln(alpha0) + + alpha1 * np.log(beta0) + - alpha1 * np.log(beta1) + + alpha0 * (beta1 / beta0 - 1.)) + + self.assertEqual(beta0.shape, kl_actual.get_shape()) + self.assertAllClose(kl_expected, kl_actual_, atol=0., rtol=1e-6) + self.assertAllClose(kl_sample_, kl_actual_, atol=0., rtol=1e-2) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/ops/gamma.py b/tensorflow/contrib/distributions/python/ops/gamma.py index 23ee15e432..977ea75f00 100644 --- a/tensorflow/contrib/distributions/python/ops/gamma.py +++ b/tensorflow/contrib/distributions/python/ops/gamma.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import distribution from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.contrib.distributions.python.ops import kullback_leibler from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -240,3 +241,30 @@ class GammaWithSoftplusAlphaBeta(Gamma): allow_nan_stats=allow_nan_stats, name=ns) self._parameters = parameters + + +@kullback_leibler.RegisterKL(Gamma, Gamma) +def _kl_gamma_gamma(g0, g1, name=None): + """Calculate the batched KL divergence KL(g0 || g1) with g0 and g1 Gamma. + + Args: + g0: instance of a Gamma distribution object. + g1: instance of a Gamma distribution object. + name: (optional) Name to use for created operations. + Default is "kl_gamma_gamma". + + Returns: + kl_gamma_gamma: `Tensor`. The batchwise KL(g0 || g1). + """ + with ops.name_scope(name, "kl_gamma_gamma", + values=[g0.alpha, g0.beta, g1.alpha, g1.beta]): + # Result from: + # http://www.fil.ion.ucl.ac.uk/~wpenny/publications/densities.ps + # For derivation see: + # http://stats.stackexchange.com/questions/11646/kullback-leibler-divergence-between-two-gamma-distributions pylint: disable=line-too-long + return ((g0.alpha - g1.alpha) * math_ops.digamma(g0.alpha) + + math_ops.lgamma(g1.alpha) + - math_ops.lgamma(g0.alpha) + + g1.alpha * math_ops.log(g0.beta) + - g1.alpha * math_ops.log(g1.beta) + + g0.alpha * (g1.beta / g0.beta - 1.)) |