aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Joshua V. Dillon <jvdillon@google.com>2017-01-03 11:51:14 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-03 12:05:45 -0800
commit1514d36258256ad535e88dc7cc7b9e5b136f4270 (patch)
tree408fdb73a8683a260ee900f577897d9560843527
parent6a04c105a2cf5ecc6981040b3943668edd0dc2e4 (diff)
Fix `sample` shape hints and remove `sample_n`.
Change: 143469030
-rw-r--r--tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py6
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/entropy.py6
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/monte_carlo.py8
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py91
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py45
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py87
-rw-r--r--tensorflow/contrib/distributions/python/ops/distribution.py161
-rw-r--r--tensorflow/contrib/distributions/python/ops/mixture.py4
-rw-r--r--tensorflow/contrib/distributions/python/ops/quantized_distribution.py2
10 files changed, 217 insertions, 199 deletions
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
index 11528da9a3..89795af9e6 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
@@ -132,7 +132,7 @@ class ExpectationTest(test.TestCase):
with self.test_session():
p = distributions.Normal(mu=[1.0, -1.0], sigma=[0.3, 0.5])
# Compute E_p[X] and E_p[X^2].
- z = p.sample_n(n=n)
+ z = p.sample(n, seed=42)
e_x = monte_carlo.expectation(lambda x: x, p, z=z, seed=42)
e_x2 = monte_carlo.expectation(math_ops.square, p, z=z, seed=0)
var = e_x2 - math_ops.square(e_x)
@@ -161,7 +161,7 @@ class GetSamplesTest(test.TestCase):
def test_raises_if_both_z_and_n_are_not_none(self):
with self.test_session():
dist = distributions.Normal(mu=0., sigma=1.)
- z = dist.sample_n(n=1)
+ z = dist.sample(seed=42)
n = 1
seed = None
with self.assertRaisesRegexp(ValueError, 'exactly one'):
@@ -179,7 +179,7 @@ class GetSamplesTest(test.TestCase):
def test_returns_z_if_z_provided(self):
with self.test_session():
dist = distributions.Normal(mu=0., sigma=1.)
- z = dist.sample_n(n=10)
+ z = dist.sample(10, seed=42)
n = None
seed = None
z = monte_carlo._get_samples(dist, z, n, seed)
diff --git a/tensorflow/contrib/bayesflow/python/ops/entropy.py b/tensorflow/contrib/bayesflow/python/ops/entropy.py
index 80b35c59d2..56490c390c 100644
--- a/tensorflow/contrib/bayesflow/python/ops/entropy.py
+++ b/tensorflow/contrib/bayesflow/python/ops/entropy.py
@@ -143,7 +143,7 @@ def elbo_ratio(log_p,
shape broadcastable to `q.batch_shape`.
For example, `log_p` works "just like" `q.log_prob`.
q: `tf.contrib.distributions.Distribution`.
- z: `Tensor` of samples from `q`, produced by `q.sample_n`.
+ z: `Tensor` of samples from `q`, produced by `q.sample(n)` for some `n`.
n: Integer `Tensor`. Number of samples to generate if `z` is not provided.
seed: Python integer to seed the random number generator.
form: Either `ELBOForms.analytic_entropy` (use formula for entropy of `q`)
@@ -193,7 +193,7 @@ def entropy_shannon(p,
Args:
p: `tf.contrib.distributions.Distribution`
- z: `Tensor` of samples from `p`, produced by `p.sample_n(n)` for some `n`.
+ z: `Tensor` of samples from `p`, produced by `p.sample(n)` for some `n`.
n: Integer `Tensor`. Number of samples to generate if `z` is not provided.
seed: Python integer to seed the random number generator.
form: Either `ELBOForms.analytic_entropy` (use formula for entropy of `q`)
@@ -326,7 +326,7 @@ def renyi_ratio(log_p, q, alpha, z=None, n=None, seed=None, name='renyi_ratio'):
`float64` `dtype` recommended.
`log_p` and `q` should be supported on the same set.
alpha: `Tensor` with shape `q.batch_shape` and values not equal to 1.
- z: `Tensor` of samples from `q`, produced by `q.sample_n`.
+ z: `Tensor` of samples from `q`, produced by `q.sample` for some `n`.
n: Integer `Tensor`. The number of samples to use if `z` is not provided.
Note that this can be highly biased for small `n`, see docstring.
seed: Python integer to seed the random number generator.
diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py
index 8a03c348fa..198d755dff 100644
--- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py
+++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py
@@ -118,7 +118,7 @@ def expectation_importance_sampler(f,
`tf.contrib.distributions.Distribution`.
`float64` `dtype` recommended.
`log_p` and `q` should be supported on the same set.
- z: `Tensor` of samples from `q`, produced by `q.sample_n`.
+ z: `Tensor` of samples from `q`, produced by `q.sample` for some `n`.
n: Integer `Tensor`. Number of samples to generate if `z` is not provided.
seed: Python integer to seed the random number generator.
name: A name to give this `Op`.
@@ -195,7 +195,7 @@ def expectation_importance_sampler_logspace(
`tf.contrib.distributions.Distribution`.
`float64` `dtype` recommended.
`log_p` and `q` should be supported on the same set.
- z: `Tensor` of samples from `q`, produced by `q.sample_n`.
+ z: `Tensor` of samples from `q`, produced by `q.sample` for some `n`.
n: Integer `Tensor`. Number of samples to generate if `z` is not provided.
seed: Python integer to seed the random number generator.
name: A name to give this `Op`.
@@ -254,7 +254,7 @@ def expectation(f, p, z=None, n=None, seed=None, name='expectation'):
Args:
f: Callable mapping samples from `p` to `Tensors`.
p: `tf.contrib.distributions.Distribution`.
- z: `Tensor` of samples from `p`, produced by `p.sample_n`.
+ z: `Tensor` of samples from `p`, produced by `p.sample` for some `n`.
n: Integer `Tensor`. Number of samples to generate if `z` is not provided.
seed: Python integer to seed the random number generator.
name: A name to give this `Op`.
@@ -314,6 +314,6 @@ def _get_samples(dist, z, n, seed):
'Must specify exactly one of arguments "n" and "z". Found: '
'n = %s, z = %s' % (n, z))
if n is not None:
- return dist.sample_n(n=n, seed=seed)
+ return dist.sample(n, seed=seed)
else:
return ops.convert_to_tensor(z, name='z')
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
index e5a1de4bbb..4f22529d6b 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
@@ -18,22 +18,30 @@ from __future__ import print_function
from tensorflow.contrib import distributions
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
-dists = distributions
+ds = distributions
class DistributionTest(test.TestCase):
def testParamShapesAndFromParams(self):
classes = [
- dists.Normal, dists.Bernoulli, dists.Beta, dists.Chi2,
- dists.Exponential, dists.Gamma, dists.InverseGamma, dists.Laplace,
- dists.StudentT, dists.Uniform
+ ds.Normal,
+ ds.Bernoulli,
+ ds.Beta,
+ ds.Chi2,
+ ds.Exponential,
+ ds.Gamma,
+ ds.InverseGamma,
+ ds.Laplace,
+ ds.StudentT,
+ ds.Uniform,
]
sample_shapes = [(), (10,), (10, 20, 30)]
@@ -55,15 +63,15 @@ class DistributionTest(test.TestCase):
with self.test_session():
# Note: we cannot easily test all distributions since each requires
# different initialization arguments. We therefore spot test a few.
- normal = dists.Normal(mu=1., sigma=2., validate_args=True)
+ normal = ds.Normal(mu=1., sigma=2., validate_args=True)
self.assertEqual(normal.parameters, normal.copy().parameters)
- wishart = dists.WishartFull(
- df=2, scale=[[1., 2], [2, 5]], validate_args=True)
+ wishart = ds.WishartFull(df=2, scale=[[1., 2], [2, 5]],
+ validate_args=True)
self.assertEqual(wishart.parameters, wishart.copy().parameters)
def testCopyOverride(self):
with self.test_session():
- normal = dists.Normal(mu=1., sigma=2., validate_args=True)
+ normal = ds.Normal(mu=1., sigma=2., validate_args=True)
normal_copy = normal.copy(validate_args=False)
base_params = normal.parameters.copy()
copy_params = normal.copy(validate_args=False).parameters.copy()
@@ -76,19 +84,19 @@ class DistributionTest(test.TestCase):
mu = 1.
sigma = 2.
- normal = dists.Normal(mu, sigma, validate_args=True)
+ normal = ds.Normal(mu, sigma, validate_args=True)
self.assertTrue(tensor_util.constant_value(normal.is_scalar_event))
self.assertTrue(tensor_util.constant_value(normal.is_scalar_batch))
- normal = dists.Normal([mu], [sigma], validate_args=True)
+ normal = ds.Normal([mu], [sigma], validate_args=True)
self.assertTrue(tensor_util.constant_value(normal.is_scalar_event))
self.assertFalse(tensor_util.constant_value(normal.is_scalar_batch))
- mvn = dists.MultivariateNormalDiag([mu], [sigma], validate_args=True)
+ mvn = ds.MultivariateNormalDiag([mu], [sigma], validate_args=True)
self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event))
self.assertTrue(tensor_util.constant_value(mvn.is_scalar_batch))
- mvn = dists.MultivariateNormalDiag([[mu]], [[sigma]], validate_args=True)
+ mvn = ds.MultivariateNormalDiag([[mu]], [[sigma]], validate_args=True)
self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event))
self.assertFalse(tensor_util.constant_value(mvn.is_scalar_batch))
@@ -117,6 +125,65 @@ class DistributionTest(test.TestCase):
self.assertTrue(is_scalar.eval(feed_dict={x: 1}))
self.assertFalse(is_scalar.eval(feed_dict={x: [1]}))
+ def testSampleShapeHints(self):
+ class _FakeDistribution(ds.Distribution):
+ """Fake Distribution for testing _set_sample_static_shape."""
+
+ def __init__(self, batch_shape=None, event_shape=None):
+ self._static_batch_shape = tensor_shape.TensorShape(batch_shape)
+ self._static_event_shape = tensor_shape.TensorShape(event_shape)
+ super(_FakeDistribution, self).__init__(
+ dtype=dtypes.float32,
+ is_continuous=False,
+ is_reparameterized=False,
+ validate_args=True,
+ allow_nan_stats=True,
+ name="DummyDistribution")
+
+ def _get_batch_shape(self):
+ return self._static_batch_shape
+
+ def _get_event_shape(self):
+ return self._static_event_shape
+
+ with self.test_session():
+ # Make a new session since we're playing with static shapes. [And below.]
+ x = array_ops.placeholder(dtype=dtypes.float32)
+ dist = _FakeDistribution(batch_shape=[2, 3], event_shape=[5])
+ sample_shape = ops.convert_to_tensor([6, 7], dtype=dtypes.int32)
+ y = dist._set_sample_static_shape(x, sample_shape)
+ # We use as_list since TensorShape comparison does not work correctly for
+ # unknown values, ie, Dimension(None).
+ self.assertAllEqual([6, 7, 2, 3, 5], y.get_shape().as_list())
+
+ with self.test_session():
+ x = array_ops.placeholder(dtype=dtypes.float32)
+ dist = _FakeDistribution(batch_shape=[None, 3], event_shape=[5])
+ sample_shape = ops.convert_to_tensor([6, 7], dtype=dtypes.int32)
+ y = dist._set_sample_static_shape(x, sample_shape)
+ self.assertAllEqual([6, 7, None, 3, 5], y.get_shape().as_list())
+
+ with self.test_session():
+ x = array_ops.placeholder(dtype=dtypes.float32)
+ dist = _FakeDistribution(batch_shape=[None, 3], event_shape=[None])
+ sample_shape = ops.convert_to_tensor([6, 7], dtype=dtypes.int32)
+ y = dist._set_sample_static_shape(x, sample_shape)
+ self.assertAllEqual([6, 7, None, 3, None], y.get_shape().as_list())
+
+ with self.test_session():
+ x = array_ops.placeholder(dtype=dtypes.float32)
+ dist = _FakeDistribution(batch_shape=None, event_shape=None)
+ sample_shape = ops.convert_to_tensor([6, 7], dtype=dtypes.int32)
+ y = dist._set_sample_static_shape(x, sample_shape)
+ self.assertTrue(y.get_shape().ndims is None)
+
+ with self.test_session():
+ x = array_ops.placeholder(dtype=dtypes.float32)
+ dist = _FakeDistribution(batch_shape=[None, 3], event_shape=None)
+ sample_shape = ops.convert_to_tensor([6, 7], dtype=dtypes.int32)
+ y = dist._set_sample_static_shape(x, sample_shape)
+ self.assertTrue(y.get_shape().ndims is None)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
index 6e72f1ca31..aff34f48be 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
@@ -44,34 +44,35 @@ def _swap_first_last_axes(array):
@contextlib.contextmanager
def _test_capture_mvndiag_sample_outputs():
- """Use monkey-patching to capture the output of an MVNDiag sample_n."""
+ """Use monkey-patching to capture the output of an MVNDiag _sample_n."""
data_container = []
- true_mvndiag_sample = distributions_py.MultivariateNormalDiag.sample_n
+ true_mvndiag_sample_n = distributions_py.MultivariateNormalDiag._sample_n
- def _capturing_mvndiag_sample(self, n, seed=None, name="sample_n"):
- samples = true_mvndiag_sample(self, n=n, seed=seed, name=name)
+ def _capturing_mvndiag_sample_n(self, n, seed=None):
+ samples = true_mvndiag_sample_n(self, n=n, seed=seed)
data_container.append(samples)
return samples
- distributions_py.MultivariateNormalDiag.sample_n = _capturing_mvndiag_sample
+ distributions_py.MultivariateNormalDiag._sample_n = (
+ _capturing_mvndiag_sample_n)
yield data_container
- distributions_py.MultivariateNormalDiag.sample_n = true_mvndiag_sample
+ distributions_py.MultivariateNormalDiag._sample_n = true_mvndiag_sample_n
@contextlib.contextmanager
def _test_capture_normal_sample_outputs():
- """Use monkey-patching to capture the output of an Normal sample_n."""
+ """Use monkey-patching to capture the output of an Normal _sample_n."""
data_container = []
- true_normal_sample = distributions_py.Normal.sample_n
+ true_normal_sample_n = distributions_py.Normal._sample_n
- def _capturing_normal_sample(self, n, seed=None, name="sample_n"):
- samples = true_normal_sample(self, n=n, seed=seed, name=name)
+ def _capturing_normal_sample_n(self, n, seed=None):
+ samples = true_normal_sample_n(self, n=n, seed=seed)
data_container.append(samples)
return samples
- distributions_py.Normal.sample_n = _capturing_normal_sample
+ distributions_py.Normal._sample_n = _capturing_normal_sample_n
yield data_container
- distributions_py.Normal.sample_n = true_normal_sample
+ distributions_py.Normal._sample_n = true_normal_sample_n
def make_univariate_mixture(batch_shape, num_components):
@@ -346,10 +347,10 @@ class MixtureTest(test.TestCase):
batch_shape=[], num_components=num_components)
n = 4
with _test_capture_normal_sample_outputs() as component_samples:
- samples = dist.sample_n(n, seed=123)
+ samples = dist.sample(n, seed=123)
self.assertEqual(samples.dtype, dtypes.float32)
self.assertEqual((4,), samples.get_shape())
- cat_samples = dist.cat.sample_n(n, seed=123)
+ cat_samples = dist.cat.sample(n, seed=123)
sample_values, cat_sample_values, dist_sample_values = sess.run(
[samples, cat_samples, component_samples])
self.assertEqual((4,), sample_values.shape)
@@ -379,7 +380,7 @@ class MixtureTest(test.TestCase):
cat = distributions_py.Categorical(
logits, dtype=dtypes.int32, name="cat1")
dist1 = distributions_py.Mixture(cat, components, name="mixture1")
- samples1 = dist1.sample_n(n, seed=123456).eval()
+ samples1 = dist1.sample(n, seed=123456).eval()
random_seed.set_random_seed(654321)
components2 = [
@@ -389,7 +390,7 @@ class MixtureTest(test.TestCase):
cat2 = distributions_py.Categorical(
logits, dtype=dtypes.int32, name="cat2")
dist2 = distributions_py.Mixture(cat2, components2, name="mixture2")
- samples2 = dist2.sample_n(n, seed=123456).eval()
+ samples2 = dist2.sample(n, seed=123456).eval()
self.assertAllClose(samples1, samples2)
@@ -400,10 +401,10 @@ class MixtureTest(test.TestCase):
batch_shape=[], num_components=num_components, event_shape=[2])
n = 4
with _test_capture_mvndiag_sample_outputs() as component_samples:
- samples = dist.sample_n(n, seed=123)
+ samples = dist.sample(n, seed=123)
self.assertEqual(samples.dtype, dtypes.float32)
self.assertEqual((4, 2), samples.get_shape())
- cat_samples = dist.cat.sample_n(n, seed=123)
+ cat_samples = dist.cat.sample(n, seed=123)
sample_values, cat_sample_values, dist_sample_values = sess.run(
[samples, cat_samples, component_samples])
self.assertEqual((4, 2), sample_values.shape)
@@ -421,10 +422,10 @@ class MixtureTest(test.TestCase):
batch_shape=[2, 3], num_components=num_components)
n = 4
with _test_capture_normal_sample_outputs() as component_samples:
- samples = dist.sample_n(n, seed=123)
+ samples = dist.sample(n, seed=123)
self.assertEqual(samples.dtype, dtypes.float32)
self.assertEqual((4, 2, 3), samples.get_shape())
- cat_samples = dist.cat.sample_n(n, seed=123)
+ cat_samples = dist.cat.sample(n, seed=123)
sample_values, cat_sample_values, dist_sample_values = sess.run(
[samples, cat_samples, component_samples])
self.assertEqual((4, 2, 3), sample_values.shape)
@@ -444,10 +445,10 @@ class MixtureTest(test.TestCase):
batch_shape=[2, 3], num_components=num_components, event_shape=[4])
n = 5
with _test_capture_mvndiag_sample_outputs() as component_samples:
- samples = dist.sample_n(n, seed=123)
+ samples = dist.sample(n, seed=123)
self.assertEqual(samples.dtype, dtypes.float32)
self.assertEqual((5, 2, 3, 4), samples.get_shape())
- cat_samples = dist.cat.sample_n(n, seed=123)
+ cat_samples = dist.cat.sample(n, seed=123)
sample_values, cat_sample_values, dist_sample_values = sess.run(
[samples, cat_samples, component_samples])
self.assertEqual((5, 2, 3, 4), sample_values.shape)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py
index 06977350e7..1fd35d682f 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py
@@ -179,7 +179,7 @@ class QuantizedDistributionTest(test.TestCase):
qdist = distributions.QuantizedDistribution(
distribution=normal, lower_cutoff=0., upper_cutoff=None)
- samps = qdist.sample_n(n=5000, seed=42)
+ samps = qdist.sample(5000, seed=42)
samps_v = samps.eval()
# With lower_cutoff = 0, the interval j=0 is (-infty, 0], which holds 1/2
@@ -207,7 +207,7 @@ class QuantizedDistributionTest(test.TestCase):
qdist = distributions.QuantizedDistribution(
distribution=distributions.Exponential(lam=0.01))
# X ~ QuantizedExponential
- x = qdist.sample_n(n=10000, seed=42)
+ x = qdist.sample(10000, seed=42)
# Z = F(X), should be Uniform.
z = qdist.cdf(x)
# Compare the CDF of Z to that of a Uniform.
@@ -419,7 +419,7 @@ class QuantizedDistributionTest(test.TestCase):
self.assertEqual((), qdist.get_event_shape())
self.assertAllEqual((), qdist.event_shape().eval())
- samps = qdist.sample_n(n=10)
+ samps = qdist.sample(10, seed=42)
self.assertEqual((10,) + batch_shape, samps.get_shape())
self.assertAllEqual((10,) + batch_shape, samps.eval().shape)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py b/tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py
index 8596c3246b..73e01f9fef 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py
@@ -245,30 +245,11 @@ class StudentTTest(test.TestCase):
self.assertEqual(student.entropy().get_shape(), (3,))
self.assertEqual(student.log_pdf(2.).get_shape(), (3,))
self.assertEqual(student.pdf(2.).get_shape(), (3,))
- self.assertEqual(
- student.sample(
- 37, seed=123456).get_shape(), (
- 37,
- 3,))
-
- _check(ds.StudentT(
- df=[
- 2.,
- 3.,
- 4.,
- ], mu=2., sigma=1.))
- _check(ds.StudentT(
- df=7., mu=[
- 2.,
- 3.,
- 4.,
- ], sigma=1.))
- _check(ds.StudentT(
- df=7., mu=3., sigma=[
- 2.,
- 3.,
- 4.,
- ]))
+ self.assertEqual(student.sample(37, seed=123456).get_shape(), (37, 3,))
+
+ _check(ds.StudentT(df=[2., 3., 4.,], mu=2., sigma=1.))
+ _check(ds.StudentT(df=7., mu=[2., 3., 4.,], sigma=1.))
+ _check(ds.StudentT(df=7., mu=3., sigma=[2., 3., 4.,]))
def testBroadcastingPdfArgs(self):
@@ -285,24 +266,9 @@ class StudentTTest(test.TestCase):
xs = xs.T
_assert_shape(student, xs, (3, 3))
- _check(ds.StudentT(
- df=[
- 2.,
- 3.,
- 4.,
- ], mu=2., sigma=1.))
- _check(ds.StudentT(
- df=7., mu=[
- 2.,
- 3.,
- 4.,
- ], sigma=1.))
- _check(ds.StudentT(
- df=7., mu=3., sigma=[
- 2.,
- 3.,
- 4.,
- ]))
+ _check(ds.StudentT(df=[2., 3., 4.,], mu=2., sigma=1.))
+ _check(ds.StudentT(df=7., mu=[2., 3., 4.,], sigma=1.))
+ _check(ds.StudentT(df=7., mu=3., sigma=[2., 3., 4.,]))
def _check2d(student):
_assert_shape(student, 2., (1, 3))
@@ -313,24 +279,9 @@ class StudentTTest(test.TestCase):
xs = xs.T
_assert_shape(student, xs, (3, 3))
- _check2d(ds.StudentT(
- df=[[
- 2.,
- 3.,
- 4.,
- ]], mu=2., sigma=1.))
- _check2d(ds.StudentT(
- df=7., mu=[[
- 2.,
- 3.,
- 4.,
- ]], sigma=1.))
- _check2d(ds.StudentT(
- df=7., mu=3., sigma=[[
- 2.,
- 3.,
- 4.,
- ]]))
+ _check2d(ds.StudentT(df=[[2., 3., 4.,]], mu=2., sigma=1.))
+ _check2d(ds.StudentT(df=7., mu=[[2., 3., 4.,]], sigma=1.))
+ _check2d(ds.StudentT(df=7., mu=3., sigma=[[2., 3., 4.,]]))
def _check2d_rows(student):
_assert_shape(student, 2., (3, 1))
@@ -355,8 +306,8 @@ class StudentTTest(test.TestCase):
def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self):
with self.test_session():
mu = [1., 3.3, 4.4]
- student = ds.StudentT(
- df=[0.5, 5., 7.], mu=mu, sigma=[3., 2., 1.], allow_nan_stats=False)
+ student = ds.StudentT(df=[0.5, 5., 7.], mu=mu, sigma=[3., 2., 1.],
+ allow_nan_stats=False)
with self.assertRaisesOpError("x < y"):
student.mean().eval()
@@ -364,8 +315,8 @@ class StudentTTest(test.TestCase):
with self.test_session():
mu = [-2, 0., 1., 3.3, 4.4]
sigma = [5., 4., 3., 2., 1.]
- student = ds.StudentT(
- df=[0.5, 1., 3., 5., 7.], mu=mu, sigma=sigma, allow_nan_stats=True)
+ student = ds.StudentT(df=[0.5, 1., 3., 5., 7.], mu=mu, sigma=sigma,
+ allow_nan_stats=True)
mean = student.mean().eval()
self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean)
@@ -503,15 +454,15 @@ class StudentTTest(test.TestCase):
def testNegativeDofFails(self):
with self.test_session():
- student = ds.StudentT(
- df=[2, -5.], mu=0., sigma=1., validate_args=True, name="S")
+ student = ds.StudentT(df=[2, -5.], mu=0., sigma=1.,
+ validate_args=True, name="S")
with self.assertRaisesOpError(r"Condition x > 0 did not hold"):
student.mean().eval()
def testNegativeScaleFails(self):
with self.test_session():
- student = ds.StudentT(
- df=[5.], mu=0., sigma=[[3.], [-2.]], validate_args=True, name="S")
+ student = ds.StudentT(df=[5.], mu=0., sigma=[[3.], [-2.]],
+ validate_args=True, name="S")
with self.assertRaisesOpError(r"Condition x > 0 did not hold"):
student.mean().eval()
diff --git a/tensorflow/contrib/distributions/python/ops/distribution.py b/tensorflow/contrib/distributions/python/ops/distribution.py
index f0519ca79d..eb1c985290 100644
--- a/tensorflow/contrib/distributions/python/ops/distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/distribution.py
@@ -39,7 +39,7 @@ from tensorflow.python.ops import math_ops
_DISTRIBUTION_PUBLIC_METHOD_WRAPPERS = [
"batch_shape", "get_batch_shape", "event_shape", "get_event_shape",
- "sample_n", "log_prob", "prob", "log_cdf", "cdf", "log_survival_function",
+ "sample", "log_prob", "prob", "log_cdf", "cdf", "log_survival_function",
"survival_function", "entropy", "mean", "variance", "std", "mode"]
@@ -562,61 +562,14 @@ class Distribution(_BaseDistribution):
with self._name_scope(name, values=[sample_shape]):
sample_shape = ops.convert_to_tensor(
sample_shape, dtype=dtypes.int32, name="sample_shape")
- if sample_shape.get_shape().ndims == 0:
- return self.sample_n(sample_shape, seed, **condition_kwargs)
- sample_shape, total = self._expand_sample_shape(sample_shape)
- samples = self.sample_n(total, seed, **condition_kwargs)
- output_shape = array_ops.concat_v2(
- [sample_shape, array_ops.slice(array_ops.shape(samples), [1], [-1])],
- 0)
- output = array_ops.reshape(samples, output_shape)
- output.set_shape(tensor_util.constant_value_as_shape(
- sample_shape).concatenate(samples.get_shape()[1:]))
- return output
-
- def sample_n(self, n, seed=None, name="sample_n", **condition_kwargs):
- """Generate `n` samples.
-
- Args:
- n: `Scalar` `Tensor` of type `int32` or `int64`, the number of
- observations to sample.
- seed: Python integer seed for RNG
- name: name to give to the op.
- **condition_kwargs: Named arguments forwarded to subclass implementation.
-
- Returns:
- samples: a `Tensor` with a prepended dimension (n,).
-
- Raises:
- TypeError: if `n` is not an integer type.
- """
- warnings.warn("Please use `sample` instead of `sample_n`. `sample_n` "
- "will be deprecated in December 2016.",
- PendingDeprecationWarning)
- with self._name_scope(name, values=[n]):
- n = ops.convert_to_tensor(n, name="n")
- if not n.dtype.is_integer:
- raise TypeError("n.dtype=%s is not an integer type" % n.dtype)
- x = self._sample_n(n, seed, **condition_kwargs)
-
- # Set shape hints.
- sample_shape = tensor_shape.TensorShape(
- tensor_util.constant_value(n))
- batch_ndims = self.get_batch_shape().ndims
- event_ndims = self.get_event_shape().ndims
- if batch_ndims is not None and event_ndims is not None:
- inferred_shape = sample_shape.concatenate(
- self.get_batch_shape().concatenate(
- self.get_event_shape()))
- x.set_shape(inferred_shape)
- elif x.get_shape().ndims is not None and x.get_shape().ndims > 0:
- x.get_shape()[0].merge_with(sample_shape[0])
- if batch_ndims is not None and batch_ndims > 0:
- x.get_shape()[1:1+batch_ndims].merge_with(self.get_batch_shape())
- if event_ndims is not None and event_ndims > 0:
- x.get_shape()[-event_ndims:].merge_with(self.get_event_shape())
-
- return x
+ sample_shape, n = self._expand_sample_shape_to_vector(
+ sample_shape, "sample_shape")
+ samples = self._sample_n(n, seed, **condition_kwargs)
+ batch_event_shape = array_ops.shape(samples)[1:]
+ final_shape = array_ops.concat_v2([sample_shape, batch_event_shape], 0)
+ samples = array_ops.reshape(samples, final_shape)
+ samples = self._set_sample_static_shape(samples, sample_shape)
+ return samples
def _log_prob(self, value):
raise NotImplementedError("log_prob is not implemented")
@@ -938,33 +891,79 @@ class Distribution(_BaseDistribution):
(values or []) + self._graph_parents)) as scope:
yield scope
- def _expand_sample_shape(self, sample_shape):
- """Helper to `sample` which ensures sample_shape is 1D."""
- sample_shape_static_val = tensor_util.constant_value(sample_shape)
- ndims = sample_shape.get_shape().ndims
- if sample_shape_static_val is None:
- if ndims is None or not sample_shape.get_shape().is_fully_defined():
- ndims = array_ops.rank(sample_shape)
+ def _expand_sample_shape_to_vector(self, x, name):
+ """Helper to `sample` which ensures input is 1D."""
+ x_static_val = tensor_util.constant_value(x)
+ if x_static_val is None:
+ prod = math_ops.reduce_prod(x)
+ else:
+ prod = np.prod(x_static_val, dtype=x.dtype.as_numpy_dtype())
+
+ ndims = x.get_shape().ndims # != sample_ndims
+ if ndims is None:
+ # Maybe expand_dims.
+ ndims = array_ops.rank(x)
expanded_shape = distribution_util.pick_vector(
math_ops.equal(ndims, 0),
- np.array((1,), dtype=dtypes.int32.as_numpy_dtype()),
- array_ops.shape(sample_shape))
- sample_shape = array_ops.reshape(sample_shape, expanded_shape)
- total = math_ops.reduce_prod(sample_shape) # reduce_prod([]) == 1
- else:
- if ndims is None:
- raise ValueError(
- "Shouldn't be here; ndims cannot be none when we have a "
- "tf.constant shape.")
- if ndims == 0:
- sample_shape_static_val = np.reshape(sample_shape_static_val, [1])
- sample_shape = ops.convert_to_tensor(
- sample_shape_static_val,
- dtype=dtypes.int32,
- name="sample_shape")
- total = np.prod(sample_shape_static_val,
- dtype=dtypes.int32.as_numpy_dtype())
- return sample_shape, total
+ np.array([1], dtype=np.int32),
+ array_ops.shape(x))
+ x = array_ops.reshape(x, expanded_shape)
+ elif ndims == 0:
+ # Definitely expand_dims.
+ if x_static_val is not None:
+ x = ops.convert_to_tensor(
+ np.array([x_static_val], dtype=x.dtype.as_numpy_dtype()),
+ name=name)
+ else:
+ x = array_ops.reshape(x, [1])
+ elif ndims != 1:
+ raise ValueError("Input is neither scalar nor vector.")
+
+ return x, prod
+
+ def _set_sample_static_shape(self, x, sample_shape):
+ """Helper to `sample`; sets static shape info."""
+ # Set shape hints.
+ sample_shape = tensor_shape.TensorShape(
+ tensor_util.constant_value(sample_shape))
+
+ ndims = x.get_shape().ndims
+ sample_ndims = sample_shape.ndims
+ batch_ndims = self.get_batch_shape().ndims
+ event_ndims = self.get_event_shape().ndims
+
+ # Infer rank(x).
+ if (ndims is None and
+ sample_ndims is not None and
+ batch_ndims is not None and
+ event_ndims is not None):
+ ndims = sample_ndims + batch_ndims + event_ndims
+ x.set_shape([None] * ndims)
+
+ # Infer sample shape.
+ if ndims is not None and sample_ndims is not None:
+ shape = sample_shape.concatenate([None]*(ndims - sample_ndims))
+ x.set_shape(x.get_shape().merge_with(shape))
+
+ # Infer event shape.
+ if ndims is not None and event_ndims is not None:
+ shape = tensor_shape.TensorShape(
+ [None]*(ndims - event_ndims)).concatenate(self.get_event_shape())
+ x.set_shape(x.get_shape().merge_with(shape))
+
+ # Infer batch shape.
+ if batch_ndims is not None:
+ if ndims is not None:
+ if sample_ndims is None and event_ndims is not None:
+ sample_ndims = ndims - batch_ndims - event_ndims
+ elif event_ndims is None and sample_ndims is not None:
+ event_ndims = ndims - batch_ndims - sample_ndims
+ if sample_ndims is not None and event_ndims is not None:
+ shape = tensor_shape.TensorShape([None]*sample_ndims).concatenate(
+ self.get_batch_shape()).concatenate([None]*event_ndims)
+ x.set_shape(x.get_shape().merge_with(shape))
+
+ return x
def _is_scalar_helper(self, static_shape_fn, dynamic_shape_fn):
"""Implementation for `is_scalar_batch` and `is_scalar_event`."""
diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py
index 95ad9fe06a..0e98e9e3b0 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture.py
@@ -244,7 +244,7 @@ class Mixture(distribution.Distribution):
n = ops.convert_to_tensor(n, name="n")
static_n = tensor_util.constant_value(n)
n = int(static_n) if static_n is not None else n
- cat_samples = self.cat.sample_n(n, seed=seed)
+ cat_samples = self.cat.sample(n, seed=seed)
static_samples_shape = cat_samples.get_shape()
if static_samples_shape.is_fully_defined():
@@ -308,7 +308,7 @@ class Mixture(distribution.Distribution):
for c in range(self.num_components):
n_class = array_ops.size(partitioned_samples_indices[c])
seed = distribution_util.gen_new_seed(seed, "mixture")
- samples_class_c = self.components[c].sample_n(n_class, seed=seed)
+ samples_class_c = self.components[c].sample(n_class, seed=seed)
# Pull out the correct batch entries from each index.
# To do this, we may have to flatten the batch shape.
diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
index fd3ec553c0..713f09ccf8 100644
--- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
@@ -277,7 +277,7 @@ class QuantizedDistribution(distributions.Distribution):
upper_cutoff = self._upper_cutoff
with ops.name_scope("transform"):
n = ops.convert_to_tensor(n, name="n")
- x_samps = self.distribution.sample_n(n=n, seed=seed)
+ x_samps = self.distribution.sample(n, seed=seed)
ones = array_ops.ones_like(x_samps)
# Snap values to the intervals (j - 1, j].