aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-07-27 08:11:19 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-27 09:17:54 -0700
commit4e3d98baac3144b2997c576985d5409bf48fd8db (patch)
treecaa0660d80a7a507a69960c25e4daae0452e8c8f
parent497dfa6b80abf304e0ef86fc09e1f3f2f8c69a7c (diff)
Minor fixes to distributions.
- Use randomint instead of geometric (for tests). - Ensure Bernoulli works for float64. - Use seed for exponential sampling. - Remove allow_arbitrary_counts in favor of validate_args. - Remove casting between parameter dtypes, in favor of user passing in parameters of same dtype. Change: 128592132
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py54
-rw-r--r--tensorflow/contrib/distributions/python/ops/bernoulli.py10
-rw-r--r--tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py65
-rw-r--r--tensorflow/contrib/distributions/python/ops/exponential.py1
4 files changed, 63 insertions, 67 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py
index 866fb45524..1a3f5eaf66 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py
@@ -61,14 +61,14 @@ class DirichletMultinomialTest(tf.test.TestCase):
n = [[5.]]
with self.test_session():
dist = tf.contrib.distributions.DirichletMultinomial(n, alpha)
- dist.pmf([2, 3, 0]).eval()
- dist.pmf([3, 0, 2]).eval()
+ dist.pmf([2., 3, 0]).eval()
+ dist.pmf([3., 0, 2]).eval()
with self.assertRaisesOpError('Condition x >= 0.*'):
- dist.pmf([-1, 4, 2]).eval()
+ dist.pmf([-1., 4, 2]).eval()
with self.assertRaisesOpError('Condition x == y.*'):
- dist.pmf([3, 3, 0]).eval()
+ dist.pmf([3., 3, 0]).eval()
- def testPmfArbitraryCounts(self):
+ def testPmf_non_integer_counts(self):
alpha = [[1., 2, 3]]
n = [[5.]]
with self.test_session():
@@ -80,8 +80,10 @@ class DirichletMultinomialTest(tf.test.TestCase):
with self.assertRaisesOpError('Condition x == y.*'):
dist.pmf([1.0, 2.5, 1.5]).eval()
dist = tf.contrib.distributions.DirichletMultinomial(
- n, alpha, allow_arbitrary_counts=True)
- dist.pmf(np.array([1.0, 2.5, 1.5])).eval()
+ n, alpha, validate_args=False)
+ dist.pmf([1., 2., 3.]).eval()
+ # Non-integer arguments work.
+ dist.pmf([1.0, 2.5, 1.5]).eval()
def testPmfBothZeroBatches(self):
# The probabilities of one vote falling into class k is the mean for class
@@ -90,7 +92,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
# Both zero-batches. No broadcast
alpha = [1., 2]
counts = [1., 0]
- dist = tf.contrib.distributions.DirichletMultinomial(1, alpha)
+ dist = tf.contrib.distributions.DirichletMultinomial(1., alpha)
pmf = dist.pmf(counts)
self.assertAllClose(1 / 3., pmf.eval())
self.assertEqual((), pmf.get_shape())
@@ -102,7 +104,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
# Both zero-batches. No broadcast
alpha = [1., 2]
counts = [3., 2]
- dist = tf.contrib.distributions.DirichletMultinomial(5, alpha)
+ dist = tf.contrib.distributions.DirichletMultinomial(5., alpha)
pmf = dist.pmf(counts)
self.assertAllClose(1 / 7., pmf.eval())
self.assertEqual((), pmf.get_shape())
@@ -113,7 +115,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
with self.test_session():
alpha = [1., 2]
counts = [3., 2]
- n = np.full([4, 3], 5.)
+ n = np.full([4, 3], 5., dtype=np.float32)
dist = tf.contrib.distributions.DirichletMultinomial(n, alpha)
pmf = dist.pmf(counts)
self.assertAllClose([[1 / 7., 1 / 7., 1 / 7.]] * 4, pmf.eval())
@@ -125,7 +127,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
with self.test_session():
alpha = [[1., 2]]
counts = [[1., 0], [0., 1]]
- dist = tf.contrib.distributions.DirichletMultinomial([1], alpha)
+ dist = tf.contrib.distributions.DirichletMultinomial([1.], alpha)
pmf = dist.pmf(counts)
self.assertAllClose([1 / 3., 2 / 3.], pmf.eval())
self.assertEqual((2), pmf.get_shape())
@@ -231,12 +233,12 @@ class DirichletMultinomialTest(tf.test.TestCase):
def testVariance_n_alpha_broadcast(self):
alpha_v = [1., 2, 3]
- alpha_0 = np.sum(alpha_v)
+ alpha_0 = 6.
# Shape [4, 3]
- alpha = np.array(4 * [alpha_v])
+ alpha = np.array(4 * [alpha_v], dtype=np.float32)
# Shape [4, 1]
- ns = np.array([[2.], [3.], [4.], [5.]])
+ ns = np.array([[2.], [3.], [4.], [5.]], dtype=np.float32)
variance_entry = lambda a, a_sum: a / a_sum * (1 - a / a_sum)
covariance_entry = lambda a, b, a_sum: -a * b/ a_sum**2
@@ -250,7 +252,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
covariance_entry(alpha_v[1], alpha_v[2], alpha_0)],
[covariance_entry(alpha_v[2], alpha_v[0], alpha_0),
covariance_entry(alpha_v[2], alpha_v[1], alpha_0),
- variance_entry(alpha_v[2], alpha_0)]]])
+ variance_entry(alpha_v[2], alpha_0)]]], dtype=np.float32)
with self.test_session():
# ns is shape [4, 1], and alpha is shape [4, 3].
@@ -263,11 +265,11 @@ class DirichletMultinomialTest(tf.test.TestCase):
self.assertAllClose(expected_variance, variance.eval())
def testVariance_multidimensional(self):
- alpha = np.random.rand(3, 5, 4)
- alpha2 = np.random.rand(6, 3, 3)
- # Ensure n > 0.
- ns = np.random.geometric(p=0.8, size=[3, 5, 1]) + 1
- ns2 = np.random.geometric(p=0.8, size=[6, 1, 1]) + 1
+ alpha = np.random.rand(3, 5, 4).astype(np.float32)
+ alpha2 = np.random.rand(6, 3, 3).astype(np.float32)
+
+ ns = np.random.randint(low=1, high=11, size=[3, 5, 1]).astype(np.float32)
+ ns2 = np.random.randint(low=1, high=11, size=[6, 1, 1]).astype(np.float32)
with self.test_session():
dist = tf.contrib.distributions.DirichletMultinomial(ns, alpha)
@@ -297,7 +299,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
# One (three sided) coin flip. Prob[coin 3] = 0.8.
# Note that since it was one flip, value of tau didn't matter.
- counts = [0, 0, 1]
+ counts = [0., 0, 1]
with self.test_session():
dist = tf.contrib.distributions.DirichletMultinomial(1., alpha)
pmf = dist.pmf(counts)
@@ -305,9 +307,9 @@ class DirichletMultinomialTest(tf.test.TestCase):
self.assertEqual((), pmf.get_shape())
# Two (three sided) coin flips. Prob[coin 3] = 0.8.
- counts = [0, 0, 2]
+ counts = [0., 0, 2]
with self.test_session():
- dist = tf.contrib.distributions.DirichletMultinomial(2, alpha)
+ dist = tf.contrib.distributions.DirichletMultinomial(2., alpha)
pmf = dist.pmf(counts)
self.assertAllClose(0.8**2, pmf.eval(), atol=1e-2)
self.assertEqual((), pmf.get_shape())
@@ -315,7 +317,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
# Three (three sided) coin flips.
counts = [1., 0, 2]
with self.test_session():
- dist = tf.contrib.distributions.DirichletMultinomial(3, alpha)
+ dist = tf.contrib.distributions.DirichletMultinomial(3., alpha)
pmf = dist.pmf(counts)
self.assertAllClose(3 * 0.1 * 0.8 * 0.8, pmf.eval(), atol=1e-2)
self.assertEqual((), pmf.get_shape())
@@ -336,10 +338,10 @@ class DirichletMultinomialTest(tf.test.TestCase):
self.assertEqual((), pmf.get_shape())
# If there are two draws, it is much more likely that they are the same.
- counts_same = [2, 0]
+ counts_same = [2., 0]
counts_different = [1, 1.]
with self.test_session():
- dist = tf.contrib.distributions.DirichletMultinomial(2, alpha)
+ dist = tf.contrib.distributions.DirichletMultinomial(2., alpha)
pmf_same = dist.pmf(counts_same)
pmf_different = dist.pmf(counts_different)
self.assertLess(5 * pmf_different.eval(), pmf_same.eval())
diff --git a/tensorflow/contrib/distributions/python/ops/bernoulli.py b/tensorflow/contrib/distributions/python/ops/bernoulli.py
index b3259b2867..fe5826e491 100644
--- a/tensorflow/contrib/distributions/python/ops/bernoulli.py
+++ b/tensorflow/contrib/distributions/python/ops/bernoulli.py
@@ -20,12 +20,14 @@ from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.contrib.distributions.python.ops import kullback_leibler # pylint: disable=line-too-long
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
@@ -89,9 +91,11 @@ class Bernoulli(distribution.Distribution):
elif logits is None:
with ops.name_scope(name):
with ops.name_scope("p"):
- with ops.control_dependencies([check_op(p, 1.), check_op(0., p)] if
- validate_args else []):
- self._p = array_ops.identity(p)
+ p = array_ops.identity(p)
+ one = constant_op.constant(1., p.dtype)
+ zero = constant_op.constant(0., p.dtype)
+ self._p = control_flow_ops.with_dependencies(
+ [check_op(p, one), check_op(zero, p)] if validate_args else [], p)
with ops.name_scope("logits"):
self._logits = math_ops.log(self._p) - math_ops.log(1. - self._p)
with ops.name_scope(name):
diff --git a/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py b/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py
index 6982a73381..7c779fff06 100644
--- a/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py
+++ b/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py
@@ -132,25 +132,21 @@ class DirichletMultinomial(distribution.Distribution):
def __init__(self,
n,
alpha,
- allow_arbitrary_counts=False,
validate_args=True,
allow_nan_stats=False,
name='DirichletMultinomial'):
"""Initialize a batch of DirichletMultinomial distributions.
Args:
- n: Non-negative `float` or `double` tensor with shape
- broadcastable to `[N1,..., Nm]` with `m >= 0`. Defines this as a batch
- of `N1 x ... x Nm` different Dirichlet multinomial distributions. Its
- components should be equal to integral values.
- alpha: Positive `float` or `double` tensor with shape broadcastable to
- `[N1,..., Nm, k]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
- different `k` class Dirichlet multinomial distributions.
- allow_arbitrary_counts: Boolean. This represents whether the pmf/cdf
- allows for the `counts` tensor to be non-integral values.
- The pmf/cdf are functions that can be evaluated at non-integral values,
- but are only a distribution over non-negative integers. If
- `validate_args` is `False`, this assertion is turned off.
+ n: Non-negative `float` or `double` tensor, whose dtype is the same as
+ `alpha`. The shape is broadcastable to `[N1,..., Nm]` with `m >= 0`.
+ Defines this as a batch of `N1 x ... x Nm` different Dirichlet
+ multinomial distributions. Its components should be equal to integral
+ values.
+ alpha: Positive `float` or `double` tensor, whose dtype is the same as
+ `n` with shape broadcastable to `[N1,..., Nm, k]` `m >= 0`. Defines
+ this as a batch of `N1 x ... x Nm` different `k` class Dirichlet
+ multinomial distributions.
validate_args: Whether to assert valid values for parameters `alpha` and
`n`, and `x` in `prob` and `log_prob`. If False, correct behavior is
not guaranteed.
@@ -174,7 +170,6 @@ class DirichletMultinomial(distribution.Distribution):
self._allow_nan_stats = allow_nan_stats
self._validate_args = validate_args
self._name = name
- self._allow_arbitrary_counts = allow_arbitrary_counts
with ops.op_scope([n, alpha], name):
# Broadcasting works because:
# * The broadcasting convention is to prepend dimensions of size [1], and
@@ -186,8 +181,7 @@ class DirichletMultinomial(distribution.Distribution):
# * All calls involving `counts` eventually require a broadcast between
# `counts` and alpha.
self._alpha = self._check_alpha(alpha)
- n = self._check_n(n)
- self._n = math_ops.cast(n, self._alpha.dtype)
+ self._n = self._check_n(n)
self._alpha_sum = math_ops.reduce_sum(
self._alpha, reduction_indices=[-1], keep_dims=False)
@@ -346,12 +340,12 @@ class DirichletMultinomial(distribution.Distribution):
probability includes a combinatorial coefficient.
Args:
- counts: Non-negative `float` or `double` tensor whose shape can
- be broadcast with `self.alpha`. For fixed leading dimensions, the last
- dimension represents counts for the corresponding Dirichlet Multinomial
- distribution in `self.alpha`. `counts` is only legal if it sums up to
- `n` and its components are equal to integral values. The second
- condition is relaxed if `allow_arbitrary_counts` is set.
+ counts: Non-negative `float` or `double` tensor whose dtype is the same
+ `self` and whose shape can be broadcast with `self.alpha`. For fixed
+ leading dimensions, the last dimension represents counts for the
+ corresponding Dirichlet Multinomial distribution in `self.alpha`.
+ `counts` is only legal if it sums up to `n` and its components are
+ equal to integral values.
name: Name to give this Op, defaults to "log_prob".
Returns:
@@ -362,8 +356,6 @@ class DirichletMultinomial(distribution.Distribution):
with ops.name_scope(self.name):
with ops.op_scope([n, alpha, counts], name):
counts = self._check_counts(counts)
- # Use the same dtype as alpha for computations.
- counts = math_ops.cast(counts, self.dtype)
ordered_prob = (special_math_ops.lbeta(alpha + counts) -
special_math_ops.lbeta(alpha))
@@ -390,12 +382,12 @@ class DirichletMultinomial(distribution.Distribution):
probability includes a combinatorial coefficient.
Args:
- counts: Non-negative `float`, `double` tensor whose shape can
- be broadcast with `self.alpha`. For fixed leading dimensions, the last
- dimension represents counts for the corresponding Dirichlet Multinomial
- distribution in `self.alpha`. `counts` is only legal if it sums up to
- `n` and its components are equal to integral values. The second
- condition is relaxed if `allow_arbitrary_counts` is set.
+ counts: Non-negative `float` or `double` tensor whose dtype is the same
+ `self` and whose shape can be broadcast with `self.alpha`. For fixed
+ leading dimensions, the last dimension represents counts for the
+ corresponding Dirichlet Multinomial distribution in `self.alpha`.
+ `counts` is only legal if it sums up to `n` and its components are
+ equal to integral values.
name: Name to give this Op, defaults to "prob".
Returns:
@@ -409,14 +401,11 @@ class DirichletMultinomial(distribution.Distribution):
if not self.validate_args:
return counts
candidate_n = math_ops.reduce_sum(counts, reduction_indices=[-1])
- dependencies = [check_ops.assert_non_negative(counts),
- check_ops.assert_equal(self._n,
- math_ops.cast(candidate_n,
- self._n.dtype))]
- if not self._allow_arbitrary_counts:
- dependencies += [_assert_integer_form(counts)]
-
- return control_flow_ops.with_dependencies(dependencies, counts)
+
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_non_negative(counts),
+ check_ops.assert_equal(self._n, candidate_n),
+ _assert_integer_form(counts)], counts)
def _check_alpha(self, alpha):
alpha = ops.convert_to_tensor(alpha, name='alpha')
diff --git a/tensorflow/contrib/distributions/python/ops/exponential.py b/tensorflow/contrib/distributions/python/ops/exponential.py
index 13b26a11db..c49b3eeba8 100644
--- a/tensorflow/contrib/distributions/python/ops/exponential.py
+++ b/tensorflow/contrib/distributions/python/ops/exponential.py
@@ -102,6 +102,7 @@ class Exponential(gamma.Gamma):
shape, minval=np.nextafter(
self.dtype.as_numpy_dtype(0.), self.dtype.as_numpy_dtype(1.)),
maxval=constant_op.constant(1.0, dtype=self.dtype),
+ seed=seed,
dtype=self.dtype)
n_val = tensor_util.constant_value(n)