diff options
author | 2016-07-27 08:11:19 -0800 | |
---|---|---|
committer | 2016-07-27 09:17:54 -0700 | |
commit | 4e3d98baac3144b2997c576985d5409bf48fd8db (patch) | |
tree | caa0660d80a7a507a69960c25e4daae0452e8c8f | |
parent | 497dfa6b80abf304e0ef86fc09e1f3f2f8c69a7c (diff) |
Minor fixes to distributions.
- Use randomint instead of geometric (for tests).
- Ensure Bernoulli works for float64.
- Use seed for exponential sampling.
- Remove allow_arbitrary_counts in favor of validate_args.
- Remove casting between parameter dtypes, in favor of user passing in parameters of same dtype.
Change: 128592132
4 files changed, 63 insertions, 67 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py index 866fb45524..1a3f5eaf66 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py @@ -61,14 +61,14 @@ class DirichletMultinomialTest(tf.test.TestCase): n = [[5.]] with self.test_session(): dist = tf.contrib.distributions.DirichletMultinomial(n, alpha) - dist.pmf([2, 3, 0]).eval() - dist.pmf([3, 0, 2]).eval() + dist.pmf([2., 3, 0]).eval() + dist.pmf([3., 0, 2]).eval() with self.assertRaisesOpError('Condition x >= 0.*'): - dist.pmf([-1, 4, 2]).eval() + dist.pmf([-1., 4, 2]).eval() with self.assertRaisesOpError('Condition x == y.*'): - dist.pmf([3, 3, 0]).eval() + dist.pmf([3., 3, 0]).eval() - def testPmfArbitraryCounts(self): + def testPmf_non_integer_counts(self): alpha = [[1., 2, 3]] n = [[5.]] with self.test_session(): @@ -80,8 +80,10 @@ class DirichletMultinomialTest(tf.test.TestCase): with self.assertRaisesOpError('Condition x == y.*'): dist.pmf([1.0, 2.5, 1.5]).eval() dist = tf.contrib.distributions.DirichletMultinomial( - n, alpha, allow_arbitrary_counts=True) - dist.pmf(np.array([1.0, 2.5, 1.5])).eval() + n, alpha, validate_args=False) + dist.pmf([1., 2., 3.]).eval() + # Non-integer arguments work. + dist.pmf([1.0, 2.5, 1.5]).eval() def testPmfBothZeroBatches(self): # The probabilities of one vote falling into class k is the mean for class @@ -90,7 +92,7 @@ class DirichletMultinomialTest(tf.test.TestCase): # Both zero-batches. No broadcast alpha = [1., 2] counts = [1., 0] - dist = tf.contrib.distributions.DirichletMultinomial(1, alpha) + dist = tf.contrib.distributions.DirichletMultinomial(1., alpha) pmf = dist.pmf(counts) self.assertAllClose(1 / 3., pmf.eval()) self.assertEqual((), pmf.get_shape()) @@ -102,7 +104,7 @@ class DirichletMultinomialTest(tf.test.TestCase): # Both zero-batches. No broadcast alpha = [1., 2] counts = [3., 2] - dist = tf.contrib.distributions.DirichletMultinomial(5, alpha) + dist = tf.contrib.distributions.DirichletMultinomial(5., alpha) pmf = dist.pmf(counts) self.assertAllClose(1 / 7., pmf.eval()) self.assertEqual((), pmf.get_shape()) @@ -113,7 +115,7 @@ class DirichletMultinomialTest(tf.test.TestCase): with self.test_session(): alpha = [1., 2] counts = [3., 2] - n = np.full([4, 3], 5.) + n = np.full([4, 3], 5., dtype=np.float32) dist = tf.contrib.distributions.DirichletMultinomial(n, alpha) pmf = dist.pmf(counts) self.assertAllClose([[1 / 7., 1 / 7., 1 / 7.]] * 4, pmf.eval()) @@ -125,7 +127,7 @@ class DirichletMultinomialTest(tf.test.TestCase): with self.test_session(): alpha = [[1., 2]] counts = [[1., 0], [0., 1]] - dist = tf.contrib.distributions.DirichletMultinomial([1], alpha) + dist = tf.contrib.distributions.DirichletMultinomial([1.], alpha) pmf = dist.pmf(counts) self.assertAllClose([1 / 3., 2 / 3.], pmf.eval()) self.assertEqual((2), pmf.get_shape()) @@ -231,12 +233,12 @@ class DirichletMultinomialTest(tf.test.TestCase): def testVariance_n_alpha_broadcast(self): alpha_v = [1., 2, 3] - alpha_0 = np.sum(alpha_v) + alpha_0 = 6. # Shape [4, 3] - alpha = np.array(4 * [alpha_v]) + alpha = np.array(4 * [alpha_v], dtype=np.float32) # Shape [4, 1] - ns = np.array([[2.], [3.], [4.], [5.]]) + ns = np.array([[2.], [3.], [4.], [5.]], dtype=np.float32) variance_entry = lambda a, a_sum: a / a_sum * (1 - a / a_sum) covariance_entry = lambda a, b, a_sum: -a * b/ a_sum**2 @@ -250,7 +252,7 @@ class DirichletMultinomialTest(tf.test.TestCase): covariance_entry(alpha_v[1], alpha_v[2], alpha_0)], [covariance_entry(alpha_v[2], alpha_v[0], alpha_0), covariance_entry(alpha_v[2], alpha_v[1], alpha_0), - variance_entry(alpha_v[2], alpha_0)]]]) + variance_entry(alpha_v[2], alpha_0)]]], dtype=np.float32) with self.test_session(): # ns is shape [4, 1], and alpha is shape [4, 3]. @@ -263,11 +265,11 @@ class DirichletMultinomialTest(tf.test.TestCase): self.assertAllClose(expected_variance, variance.eval()) def testVariance_multidimensional(self): - alpha = np.random.rand(3, 5, 4) - alpha2 = np.random.rand(6, 3, 3) - # Ensure n > 0. - ns = np.random.geometric(p=0.8, size=[3, 5, 1]) + 1 - ns2 = np.random.geometric(p=0.8, size=[6, 1, 1]) + 1 + alpha = np.random.rand(3, 5, 4).astype(np.float32) + alpha2 = np.random.rand(6, 3, 3).astype(np.float32) + + ns = np.random.randint(low=1, high=11, size=[3, 5, 1]).astype(np.float32) + ns2 = np.random.randint(low=1, high=11, size=[6, 1, 1]).astype(np.float32) with self.test_session(): dist = tf.contrib.distributions.DirichletMultinomial(ns, alpha) @@ -297,7 +299,7 @@ class DirichletMultinomialTest(tf.test.TestCase): # One (three sided) coin flip. Prob[coin 3] = 0.8. # Note that since it was one flip, value of tau didn't matter. - counts = [0, 0, 1] + counts = [0., 0, 1] with self.test_session(): dist = tf.contrib.distributions.DirichletMultinomial(1., alpha) pmf = dist.pmf(counts) @@ -305,9 +307,9 @@ class DirichletMultinomialTest(tf.test.TestCase): self.assertEqual((), pmf.get_shape()) # Two (three sided) coin flips. Prob[coin 3] = 0.8. - counts = [0, 0, 2] + counts = [0., 0, 2] with self.test_session(): - dist = tf.contrib.distributions.DirichletMultinomial(2, alpha) + dist = tf.contrib.distributions.DirichletMultinomial(2., alpha) pmf = dist.pmf(counts) self.assertAllClose(0.8**2, pmf.eval(), atol=1e-2) self.assertEqual((), pmf.get_shape()) @@ -315,7 +317,7 @@ class DirichletMultinomialTest(tf.test.TestCase): # Three (three sided) coin flips. counts = [1., 0, 2] with self.test_session(): - dist = tf.contrib.distributions.DirichletMultinomial(3, alpha) + dist = tf.contrib.distributions.DirichletMultinomial(3., alpha) pmf = dist.pmf(counts) self.assertAllClose(3 * 0.1 * 0.8 * 0.8, pmf.eval(), atol=1e-2) self.assertEqual((), pmf.get_shape()) @@ -336,10 +338,10 @@ class DirichletMultinomialTest(tf.test.TestCase): self.assertEqual((), pmf.get_shape()) # If there are two draws, it is much more likely that they are the same. - counts_same = [2, 0] + counts_same = [2., 0] counts_different = [1, 1.] with self.test_session(): - dist = tf.contrib.distributions.DirichletMultinomial(2, alpha) + dist = tf.contrib.distributions.DirichletMultinomial(2., alpha) pmf_same = dist.pmf(counts_same) pmf_different = dist.pmf(counts_different) self.assertLess(5 * pmf_different.eval(), pmf_same.eval()) diff --git a/tensorflow/contrib/distributions/python/ops/bernoulli.py b/tensorflow/contrib/distributions/python/ops/bernoulli.py index b3259b2867..fe5826e491 100644 --- a/tensorflow/contrib/distributions/python/ops/bernoulli.py +++ b/tensorflow/contrib/distributions/python/ops/bernoulli.py @@ -20,12 +20,14 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import distribution from tensorflow.contrib.distributions.python.ops import kullback_leibler # pylint: disable=line-too-long +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops @@ -89,9 +91,11 @@ class Bernoulli(distribution.Distribution): elif logits is None: with ops.name_scope(name): with ops.name_scope("p"): - with ops.control_dependencies([check_op(p, 1.), check_op(0., p)] if - validate_args else []): - self._p = array_ops.identity(p) + p = array_ops.identity(p) + one = constant_op.constant(1., p.dtype) + zero = constant_op.constant(0., p.dtype) + self._p = control_flow_ops.with_dependencies( + [check_op(p, one), check_op(zero, p)] if validate_args else [], p) with ops.name_scope("logits"): self._logits = math_ops.log(self._p) - math_ops.log(1. - self._p) with ops.name_scope(name): diff --git a/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py b/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py index 6982a73381..7c779fff06 100644 --- a/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py +++ b/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py @@ -132,25 +132,21 @@ class DirichletMultinomial(distribution.Distribution): def __init__(self, n, alpha, - allow_arbitrary_counts=False, validate_args=True, allow_nan_stats=False, name='DirichletMultinomial'): """Initialize a batch of DirichletMultinomial distributions. Args: - n: Non-negative `float` or `double` tensor with shape - broadcastable to `[N1,..., Nm]` with `m >= 0`. Defines this as a batch - of `N1 x ... x Nm` different Dirichlet multinomial distributions. Its - components should be equal to integral values. - alpha: Positive `float` or `double` tensor with shape broadcastable to - `[N1,..., Nm, k]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm` - different `k` class Dirichlet multinomial distributions. - allow_arbitrary_counts: Boolean. This represents whether the pmf/cdf - allows for the `counts` tensor to be non-integral values. - The pmf/cdf are functions that can be evaluated at non-integral values, - but are only a distribution over non-negative integers. If - `validate_args` is `False`, this assertion is turned off. + n: Non-negative `float` or `double` tensor, whose dtype is the same as + `alpha`. The shape is broadcastable to `[N1,..., Nm]` with `m >= 0`. + Defines this as a batch of `N1 x ... x Nm` different Dirichlet + multinomial distributions. Its components should be equal to integral + values. + alpha: Positive `float` or `double` tensor, whose dtype is the same as + `n` with shape broadcastable to `[N1,..., Nm, k]` `m >= 0`. Defines + this as a batch of `N1 x ... x Nm` different `k` class Dirichlet + multinomial distributions. validate_args: Whether to assert valid values for parameters `alpha` and `n`, and `x` in `prob` and `log_prob`. If False, correct behavior is not guaranteed. @@ -174,7 +170,6 @@ class DirichletMultinomial(distribution.Distribution): self._allow_nan_stats = allow_nan_stats self._validate_args = validate_args self._name = name - self._allow_arbitrary_counts = allow_arbitrary_counts with ops.op_scope([n, alpha], name): # Broadcasting works because: # * The broadcasting convention is to prepend dimensions of size [1], and @@ -186,8 +181,7 @@ class DirichletMultinomial(distribution.Distribution): # * All calls involving `counts` eventually require a broadcast between # `counts` and alpha. self._alpha = self._check_alpha(alpha) - n = self._check_n(n) - self._n = math_ops.cast(n, self._alpha.dtype) + self._n = self._check_n(n) self._alpha_sum = math_ops.reduce_sum( self._alpha, reduction_indices=[-1], keep_dims=False) @@ -346,12 +340,12 @@ class DirichletMultinomial(distribution.Distribution): probability includes a combinatorial coefficient. Args: - counts: Non-negative `float` or `double` tensor whose shape can - be broadcast with `self.alpha`. For fixed leading dimensions, the last - dimension represents counts for the corresponding Dirichlet Multinomial - distribution in `self.alpha`. `counts` is only legal if it sums up to - `n` and its components are equal to integral values. The second - condition is relaxed if `allow_arbitrary_counts` is set. + counts: Non-negative `float` or `double` tensor whose dtype is the same + `self` and whose shape can be broadcast with `self.alpha`. For fixed + leading dimensions, the last dimension represents counts for the + corresponding Dirichlet Multinomial distribution in `self.alpha`. + `counts` is only legal if it sums up to `n` and its components are + equal to integral values. name: Name to give this Op, defaults to "log_prob". Returns: @@ -362,8 +356,6 @@ class DirichletMultinomial(distribution.Distribution): with ops.name_scope(self.name): with ops.op_scope([n, alpha, counts], name): counts = self._check_counts(counts) - # Use the same dtype as alpha for computations. - counts = math_ops.cast(counts, self.dtype) ordered_prob = (special_math_ops.lbeta(alpha + counts) - special_math_ops.lbeta(alpha)) @@ -390,12 +382,12 @@ class DirichletMultinomial(distribution.Distribution): probability includes a combinatorial coefficient. Args: - counts: Non-negative `float`, `double` tensor whose shape can - be broadcast with `self.alpha`. For fixed leading dimensions, the last - dimension represents counts for the corresponding Dirichlet Multinomial - distribution in `self.alpha`. `counts` is only legal if it sums up to - `n` and its components are equal to integral values. The second - condition is relaxed if `allow_arbitrary_counts` is set. + counts: Non-negative `float` or `double` tensor whose dtype is the same + `self` and whose shape can be broadcast with `self.alpha`. For fixed + leading dimensions, the last dimension represents counts for the + corresponding Dirichlet Multinomial distribution in `self.alpha`. + `counts` is only legal if it sums up to `n` and its components are + equal to integral values. name: Name to give this Op, defaults to "prob". Returns: @@ -409,14 +401,11 @@ class DirichletMultinomial(distribution.Distribution): if not self.validate_args: return counts candidate_n = math_ops.reduce_sum(counts, reduction_indices=[-1]) - dependencies = [check_ops.assert_non_negative(counts), - check_ops.assert_equal(self._n, - math_ops.cast(candidate_n, - self._n.dtype))] - if not self._allow_arbitrary_counts: - dependencies += [_assert_integer_form(counts)] - - return control_flow_ops.with_dependencies(dependencies, counts) + + return control_flow_ops.with_dependencies([ + check_ops.assert_non_negative(counts), + check_ops.assert_equal(self._n, candidate_n), + _assert_integer_form(counts)], counts) def _check_alpha(self, alpha): alpha = ops.convert_to_tensor(alpha, name='alpha') diff --git a/tensorflow/contrib/distributions/python/ops/exponential.py b/tensorflow/contrib/distributions/python/ops/exponential.py index 13b26a11db..c49b3eeba8 100644 --- a/tensorflow/contrib/distributions/python/ops/exponential.py +++ b/tensorflow/contrib/distributions/python/ops/exponential.py @@ -102,6 +102,7 @@ class Exponential(gamma.Gamma): shape, minval=np.nextafter( self.dtype.as_numpy_dtype(0.), self.dtype.as_numpy_dtype(1.)), maxval=constant_op.constant(1.0, dtype=self.dtype), + seed=seed, dtype=self.dtype) n_val = tensor_util.constant_value(n) |