aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author Joshua V. Dillon <jvdillon@google.com> 2017-01-10 16:36:33 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-01-10 16:45:51 -0800
commit 973d5afdb68addd1315ceda1c536c88232699756 (patch)
tree efa56f07f922f558845a37d836211cf32790e59a
parent 56b74296829c30aa341a2d8ee5b4e2dbb48bc274 (diff)
Implement KL-divergence between two Gamma distributions.
Change: 144144041
-rw-r--r-- tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py 33
-rw-r--r-- tensorflow/contrib/distributions/python/ops/gamma.py 28
2 files changed, 61 insertions, 0 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py b/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py
index 31027736fa..2195ed0749 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py
@@ -18,11 +18,15 @@ from __future__ import division
from __future__ import print_function
import numpy as np
+from scipy import special
from scipy import stats
+
from tensorflow.contrib.distributions.python.ops import gamma as gamma_lib
+from tensorflow.contrib.distributions.python.ops import kullback_leibler
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test
@@ -317,6 +321,35 @@ class GammaTest(test.TestCase):
self.assertAllEqual(nn_ops.softplus(alpha_v).eval(), gamma.alpha.eval())
self.assertAllEqual(nn_ops.softplus(beta_v).eval(), gamma.beta.eval())
+ def testGammaGammaKL(self):
+ """Checks KL(g0 || g1) for two Gammas against a closed form and samples."""
+ # alpha0/alpha1 hold one element each while beta0/beta1 hold six; the shape
+ # assertion at the end expects the KL result to have beta0's shape.
+ alpha0 = np.array([3.])
+ beta0 = np.array([1., 2., 3., 1.5, 2.5, 3.5])
+
+ alpha1 = np.array([0.4])
+ beta1 = np.array([0.5, 1., 1.5, 2., 2.5, 3.])
+
+ # Build graph.
+ with self.test_session() as sess:
+ g0 = gamma_lib.Gamma(alpha=alpha0, beta=beta0)
+ g1 = gamma_lib.Gamma(alpha=alpha1, beta=beta1)
+ # Sample-based estimate: draw 1e4 samples from g0 with a fixed seed and
+ # average log g0(x) - log g1(x) over the sample axis (axis 0).
+ x = g0.sample(int(1e4), seed=0)
+ kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0)
+ # Value under test: the KL looked up through kullback_leibler.kl.
+ kl_actual = kullback_leibler.kl(g0, g1)
+
+ # Execute graph.
+ [kl_sample_, kl_actual_] = sess.run([kl_sample, kl_actual])
+
+ # Reference value: the same closed-form expression evaluated with
+ # NumPy/SciPy on the raw parameter arrays.
+ kl_expected = ((alpha0 - alpha1) * special.digamma(alpha0)
+ + special.gammaln(alpha1)
+ - special.gammaln(alpha0)
+ + alpha1 * np.log(beta0)
+ - alpha1 * np.log(beta1)
+ + alpha0 * (beta1 / beta0 - 1.))
+
+ # The static shape is read from the graph tensor (kl_actual), while the
+ # numeric comparisons use the evaluated array (kl_actual_).
+ self.assertEqual(beta0.shape, kl_actual.get_shape())
+ # Tight relative tolerance against the closed form; looser (1e-2) against
+ # the sample-based estimate.
+ self.assertAllClose(kl_expected, kl_actual_, atol=0., rtol=1e-6)
+ self.assertAllClose(kl_sample_, kl_actual_, atol=0., rtol=1e-2)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/gamma.py b/tensorflow/contrib/distributions/python/ops/gamma.py
index 23ee15e432..977ea75f00 100644
--- a/tensorflow/contrib/distributions/python/ops/gamma.py
+++ b/tensorflow/contrib/distributions/python/ops/gamma.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.contrib.distributions.python.ops import distribution_util
+from tensorflow.contrib.distributions.python.ops import kullback_leibler
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -240,3 +241,30 @@ class GammaWithSoftplusAlphaBeta(Gamma):
allow_nan_stats=allow_nan_stats,
name=ns)
self._parameters = parameters
+
+
+@kullback_leibler.RegisterKL(Gamma, Gamma)
+def _kl_gamma_gamma(g0, g1, name=None):
+  """Compute the batched KL divergence KL(g0 || g1) between two Gammas.
+
+  Args:
+    g0: instance of a Gamma distribution object (left argument of the KL).
+    g1: instance of a Gamma distribution object (right argument of the KL).
+    name: (optional) Name to use for created operations.
+      Default is "kl_gamma_gamma".
+
+  Returns:
+    kl_gamma_gamma: `Tensor`. The batchwise KL(g0 || g1).
+  """
+  a0, b0 = g0.alpha, g0.beta
+  a1, b1 = g1.alpha, g1.beta
+  with ops.name_scope(name, "kl_gamma_gamma", values=[a0, b0, a1, b1]):
+    # Result from:
+    #   http://www.fil.ion.ucl.ac.uk/~wpenny/publications/densities.ps
+    # For derivation see:
+    #   http://stats.stackexchange.com/questions/11646/kullback-leibler-divergence-between-two-gamma-distributions pylint: disable=line-too-long
+    #
+    # The closed form is accumulated one term at a time, left to right:
+    #   (a0 - a1) * digamma(a0) + lgamma(a1) - lgamma(a0)
+    #   + a1 * log(b0) - a1 * log(b1) + a0 * (b1 / b0 - 1)
+    kl = (a0 - a1) * math_ops.digamma(a0)
+    kl = kl + math_ops.lgamma(a1)
+    kl = kl - math_ops.lgamma(a0)
+    kl = kl + a1 * math_ops.log(b0)
+    kl = kl - a1 * math_ops.log(b1)
+    kl = kl + a0 * (b1 / b0 - 1.)
+    return kl