author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-02-15 13:56:58 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-02-15 14:24:34 -0800
commit 23b28f02e551c4adbb9fac08c8968caa894cab69 (patch)
tree   65ea0714c7b315f8a5b2ed3db4e35095b11d8574
parent 9490027f2dd44ec367dc0e675921bd2404fb71c0 (diff)
Update the KL divergence calculation to allow NaNs by default.
Change: 147640694
 tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py | 28
 tensorflow/contrib/distributions/python/ops/kullback_leibler.py               | 15
 2 files changed, 21 insertions(+), 22 deletions(-)
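In practice, the rename flips the default: callers now get NaN entries back unless they opt into the runtime check. A minimal usage sketch, assuming TensorFlow ~1.0 with tf.contrib.distributions (the Normal distributions and values here are illustrative; a Normal-vs-Normal KL is always finite, so both graphs evaluate cleanly):

    import tensorflow as tf
    from tensorflow.contrib import distributions

    a = distributions.Normal(loc=0.0, scale=1.0)
    b = distributions.Normal(loc=1.0, scale=2.0)

    # New default (allow_nan_stats=True): no NaN check is added to the
    # graph, so an undefined KL would simply evaluate to NaN.
    kl_default = distributions.kl(a, b)

    # Old strict behavior, now opt-in: an Assert op fails at run time if
    # any batch entry of the KL is NaN.
    kl_strict = distributions.kl(a, b, allow_nan_stats=False)

    with tf.Session() as sess:
        print(sess.run([kl_default, kl_strict]))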
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py b/tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py
index 2eddb1bd66..6b3d886e01 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py
@@ -43,8 +43,7 @@ class KLTest(test.TestCase):
return name
a = MyDist(loc=0.0, scale=1.0)
- # Run kl() with allow_nan=True because strings can't go through is_nan.
- self.assertEqual("OK", kullback_leibler.kl(a, a, allow_nan=True, name="OK"))
+ self.assertEqual("OK", kullback_leibler.kl(a, a, name="OK"))
def testDomainErrorExceptions(self):
@@ -61,11 +60,11 @@ class KLTest(test.TestCase):
with self.test_session():
a = MyDistException(loc=0.0, scale=1.0)
- kl = kullback_leibler.kl(a, a)
+ kl = kullback_leibler.kl(a, a, allow_nan_stats=False)
with self.assertRaisesOpError(
"KL calculation between .* and .* returned NaN values"):
kl.eval()
- kl_ok = kullback_leibler.kl(a, a, allow_nan=True)
+ kl_ok = kullback_leibler.kl(a, a)
self.assertAllEqual([float("nan")], kl_ok.eval())
def testRegistrationFailures(self):
@@ -117,17 +116,16 @@ class KLTest(test.TestCase):
sub2 = Sub2(loc=0.0, scale=1.0)
sub11 = Sub11(loc=0.0, scale=1.0)
- self.assertEqual("sub1-1", kullback_leibler.kl(sub1, sub1, allow_nan=True))
- self.assertEqual("sub1-2", kullback_leibler.kl(sub1, sub2, allow_nan=True))
- self.assertEqual("sub2-1", kullback_leibler.kl(sub2, sub1, allow_nan=True))
- self.assertEqual(
- "sub1-1", kullback_leibler.kl(sub11, sub11, allow_nan=True))
- self.assertEqual("sub1-1", kullback_leibler.kl(sub11, sub1, allow_nan=True))
- self.assertEqual("sub1-2", kullback_leibler.kl(sub11, sub2, allow_nan=True))
- self.assertEqual("sub1-1", kullback_leibler.kl(sub11, sub1, allow_nan=True))
- self.assertEqual("sub1-2", kullback_leibler.kl(sub11, sub2, allow_nan=True))
- self.assertEqual("sub2-1", kullback_leibler.kl(sub2, sub11, allow_nan=True))
- self.assertEqual("sub1-1", kullback_leibler.kl(sub1, sub11, allow_nan=True))
+ self.assertEqual("sub1-1", kullback_leibler.kl(sub1, sub1))
+ self.assertEqual("sub1-2", kullback_leibler.kl(sub1, sub2))
+ self.assertEqual("sub2-1", kullback_leibler.kl(sub2, sub1))
+ self.assertEqual("sub1-1", kullback_leibler.kl(sub11, sub11))
+ self.assertEqual("sub1-1", kullback_leibler.kl(sub11, sub1))
+ self.assertEqual("sub1-2", kullback_leibler.kl(sub11, sub2))
+ self.assertEqual("sub1-1", kullback_leibler.kl(sub11, sub1))
+ self.assertEqual("sub1-2", kullback_leibler.kl(sub11, sub2))
+ self.assertEqual("sub2-1", kullback_leibler.kl(sub2, sub11))
+ self.assertEqual("sub1-1", kullback_leibler.kl(sub1, sub11))
if __name__ == "__main__":
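For context on the registry the tests above exercise: pairwise KL functions are attached with this module's RegisterKL decorator and looked up by (type_a, type_b), walking each argument's class hierarchy (which is why Sub11 resolves to the Sub1 entries above). A minimal sketch under that assumption; MyNormal and its zero-valued KL body are illustrative placeholders, not part of this change:

    import tensorflow as tf
    from tensorflow.contrib.distributions.python.ops import kullback_leibler as kl_lib
    from tensorflow.contrib.distributions.python.ops import normal

    class MyNormal(normal.Normal):
      """Illustrative subclass; inherits loc/scale from Normal."""
      pass

    @kl_lib.RegisterKL(MyNormal, MyNormal)
    def _kl_my_normal(a, b, name=None):
      # A registered function receives both distributions and a name, and
      # returns the batchwise divergence as a Tensor.
      with tf.name_scope(name, "kl_my_normal"):
        return tf.zeros_like(a.loc)  # placeholder value for the sketch

    p = MyNormal(loc=0.0, scale=1.0)
    q = MyNormal(loc=1.0, scale=1.0)
    divergence = kl_lib.kl(p, q)  # resolved via the (MyNormal, MyNormal) entry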
diff --git a/tensorflow/contrib/distributions/python/ops/kullback_leibler.py b/tensorflow/contrib/distributions/python/ops/kullback_leibler.py
index f24f01235a..bb94a87680 100644
--- a/tensorflow/contrib/distributions/python/ops/kullback_leibler.py
+++ b/tensorflow/contrib/distributions/python/ops/kullback_leibler.py
@@ -45,7 +45,7 @@ def _registered_kl(type_a, type_b):
return kl_fn
-def kl(dist_a, dist_b, allow_nan=False, name=None):
+def kl(dist_a, dist_b, allow_nan_stats=True, name=None):
"""Get the KL-divergence KL(dist_a || dist_b).
If there is no KL method registered specifically for `type(dist_a)` and
@@ -64,10 +64,11 @@ def kl(dist_a, dist_b, allow_nan=False, name=None):
Args:
dist_a: The first distribution.
dist_b: The second distribution.
- allow_nan: If `False` (default), a runtime error is raised
- if the KL returns NaN values for any batch entry of the given
- distributions. If `True`, the KL may return a NaN for the given entry.
- name: (optional) Name scope to use for created operations.
+ allow_nan_stats: Python `bool`, default `True`. When `True`,
+ statistics (e.g., mean, mode, variance) use the value "`NaN`" to
+ indicate the result is undefined. When `False`, an exception is raised
+ if one or more of the statistic's batch members are undefined.
+ name: Python `str` name prefixed to Ops created by this class.
Returns:
A Tensor with the batchwise KL-divergence between dist_a and dist_b.
@@ -84,7 +85,7 @@ def kl(dist_a, dist_b, allow_nan=False, name=None):
with ops.name_scope("KullbackLeibler"):
kl_t = kl_fn(dist_a, dist_b, name=name)
- if allow_nan:
+ if allow_nan_stats:
return kl_t
# Check KL for NaNs
@@ -95,7 +96,7 @@ def kl(dist_a, dist_b, allow_nan=False, name=None):
math_ops.logical_not(
math_ops.reduce_any(math_ops.is_nan(kl_t))),
["KL calculation between %s and %s returned NaN values "
- "(and was called with allow_nan=False). Values:"
+ "(and was called with allow_nan_stats=False). Values:"
% (dist_a.name, dist_b.name), kl_t])]):
return array_ops.identity(kl_t, name="checked_kl")
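The checked path above is a general TensorFlow 1.x pattern: gate the result behind an Assert op via control_dependencies so the NaN test must run before the value is consumed. A standalone sketch of that pattern in plain TF ops (checked_no_nan is a hypothetical helper, not part of this module):

    import tensorflow as tf

    def checked_no_nan(t, message):
      # Fail at run time if any element of t is NaN; otherwise return t
      # unchanged, with the check wired in as a control dependency.
      assert_op = tf.Assert(
          tf.logical_not(tf.reduce_any(tf.is_nan(t))),
          [message, t])
      with tf.control_dependencies([assert_op]):
        return tf.identity(t, name="checked")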