diff options
author | 2017-02-15 13:56:58 -0800 | |
---|---|---|
committer | 2017-02-15 14:24:34 -0800 | |
commit | 23b28f02e551c4adbb9fac08c8968caa894cab69 (patch) | |
tree | 65ea0714c7b315f8a5b2ed3db4e35095b11d8574 | |
parent | 9490027f2dd44ec367dc0e675921bd2404fb71c0 (diff) |
Update the KL divergence calculation to allow NaNs by default.
Change: 147640694
-rw-r--r-- | tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py | 28 | ||||
-rw-r--r-- | tensorflow/contrib/distributions/python/ops/kullback_leibler.py | 15 |
2 files changed, 21 insertions, 22 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py b/tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py index 2eddb1bd66..6b3d886e01 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py @@ -43,8 +43,7 @@ class KLTest(test.TestCase): return name a = MyDist(loc=0.0, scale=1.0) - # Run kl() with allow_nan=True because strings can't go through is_nan. - self.assertEqual("OK", kullback_leibler.kl(a, a, allow_nan=True, name="OK")) + self.assertEqual("OK", kullback_leibler.kl(a, a, name="OK")) def testDomainErrorExceptions(self): @@ -61,11 +60,11 @@ class KLTest(test.TestCase): with self.test_session(): a = MyDistException(loc=0.0, scale=1.0) - kl = kullback_leibler.kl(a, a) + kl = kullback_leibler.kl(a, a, allow_nan_stats=False) with self.assertRaisesOpError( "KL calculation between .* and .* returned NaN values"): kl.eval() - kl_ok = kullback_leibler.kl(a, a, allow_nan=True) + kl_ok = kullback_leibler.kl(a, a) self.assertAllEqual([float("nan")], kl_ok.eval()) def testRegistrationFailures(self): @@ -117,17 +116,16 @@ class KLTest(test.TestCase): sub2 = Sub2(loc=0.0, scale=1.0) sub11 = Sub11(loc=0.0, scale=1.0) - self.assertEqual("sub1-1", kullback_leibler.kl(sub1, sub1, allow_nan=True)) - self.assertEqual("sub1-2", kullback_leibler.kl(sub1, sub2, allow_nan=True)) - self.assertEqual("sub2-1", kullback_leibler.kl(sub2, sub1, allow_nan=True)) - self.assertEqual( - "sub1-1", kullback_leibler.kl(sub11, sub11, allow_nan=True)) - self.assertEqual("sub1-1", kullback_leibler.kl(sub11, sub1, allow_nan=True)) - self.assertEqual("sub1-2", kullback_leibler.kl(sub11, sub2, allow_nan=True)) - self.assertEqual("sub1-1", kullback_leibler.kl(sub11, sub1, allow_nan=True)) - self.assertEqual("sub1-2", kullback_leibler.kl(sub11, sub2, allow_nan=True)) - self.assertEqual("sub2-1", kullback_leibler.kl(sub2, sub11, allow_nan=True)) - self.assertEqual("sub1-1", kullback_leibler.kl(sub1, sub11, allow_nan=True)) + self.assertEqual("sub1-1", kullback_leibler.kl(sub1, sub1)) + self.assertEqual("sub1-2", kullback_leibler.kl(sub1, sub2)) + self.assertEqual("sub2-1", kullback_leibler.kl(sub2, sub1)) + self.assertEqual("sub1-1", kullback_leibler.kl(sub11, sub11)) + self.assertEqual("sub1-1", kullback_leibler.kl(sub11, sub1)) + self.assertEqual("sub1-2", kullback_leibler.kl(sub11, sub2)) + self.assertEqual("sub1-1", kullback_leibler.kl(sub11, sub1)) + self.assertEqual("sub1-2", kullback_leibler.kl(sub11, sub2)) + self.assertEqual("sub2-1", kullback_leibler.kl(sub2, sub11)) + self.assertEqual("sub1-1", kullback_leibler.kl(sub1, sub11)) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/ops/kullback_leibler.py b/tensorflow/contrib/distributions/python/ops/kullback_leibler.py index f24f01235a..bb94a87680 100644 --- a/tensorflow/contrib/distributions/python/ops/kullback_leibler.py +++ b/tensorflow/contrib/distributions/python/ops/kullback_leibler.py @@ -45,7 +45,7 @@ def _registered_kl(type_a, type_b): return kl_fn -def kl(dist_a, dist_b, allow_nan=False, name=None): +def kl(dist_a, dist_b, allow_nan_stats=True, name=None): """Get the KL-divergence KL(dist_a || dist_b). If there is no KL method registered specifically for `type(dist_a)` and @@ -64,10 +64,11 @@ def kl(dist_a, dist_b, allow_nan=False, name=None): Args: dist_a: The first distribution. dist_b: The second distribution. - allow_nan: If `False` (default), a runtime error is raised - if the KL returns NaN values for any batch entry of the given - distributions. If `True`, the KL may return a NaN for the given entry. - name: (optional) Name scope to use for created operations. + allow_nan_stats: Python `bool`, default `True`. When `True`, + statistics (e.g., mean, mode, variance) use the value "`NaN`" to + indicate the result is undefined. When `False`, an exception is raised + if one or more of the statistic's batch members are undefined. + name: Python `str` name prefixed to Ops created by this class. Returns: A Tensor with the batchwise KL-divergence between dist_a and dist_b. @@ -84,7 +85,7 @@ def kl(dist_a, dist_b, allow_nan=False, name=None): with ops.name_scope("KullbackLeibler"): kl_t = kl_fn(dist_a, dist_b, name=name) - if allow_nan: + if allow_nan_stats: return kl_t # Check KL for NaNs @@ -95,7 +96,7 @@ def kl(dist_a, dist_b, allow_nan=False, name=None): math_ops.logical_not( math_ops.reduce_any(math_ops.is_nan(kl_t))), ["KL calculation between %s and %s returned NaN values " - "(and was called with allow_nan=False). Values:" + "(and was called with allow_nan_stats=False). Values:" % (dist_a.name, dist_b.name), kl_t])]): return array_ops.identity(kl_t, name="checked_kl") |