From 1514d36258256ad535e88dc7cc7b9e5b136f4270 Mon Sep 17 00:00:00 2001 From: "Joshua V. Dillon" Date: Tue, 3 Jan 2017 11:51:14 -0800 Subject: Fix `sample` shape hints and remove `sample_n`. Change: 143469030 --- .../python/kernel_tests/monte_carlo_test.py | 6 +- tensorflow/contrib/bayesflow/python/ops/entropy.py | 6 +- .../contrib/bayesflow/python/ops/monte_carlo.py | 8 +- .../python/kernel_tests/distribution_test.py | 91 ++++++++++-- .../python/kernel_tests/mixture_test.py | 45 +++--- .../kernel_tests/quantized_distribution_test.py | 6 +- .../python/kernel_tests/student_t_test.py | 87 +++-------- .../distributions/python/ops/distribution.py | 161 ++++++++++----------- .../contrib/distributions/python/ops/mixture.py | 4 +- .../python/ops/quantized_distribution.py | 2 +- 10 files changed, 217 insertions(+), 199 deletions(-) diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py index 11528da9a3..89795af9e6 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py @@ -132,7 +132,7 @@ class ExpectationTest(test.TestCase): with self.test_session(): p = distributions.Normal(mu=[1.0, -1.0], sigma=[0.3, 0.5]) # Compute E_p[X] and E_p[X^2]. - z = p.sample_n(n=n) + z = p.sample(n, seed=42) e_x = monte_carlo.expectation(lambda x: x, p, z=z, seed=42) e_x2 = monte_carlo.expectation(math_ops.square, p, z=z, seed=0) var = e_x2 - math_ops.square(e_x) @@ -161,7 +161,7 @@ class GetSamplesTest(test.TestCase): def test_raises_if_both_z_and_n_are_not_none(self): with self.test_session(): dist = distributions.Normal(mu=0., sigma=1.) 
- z = dist.sample_n(n=1) + z = dist.sample(seed=42) n = 1 seed = None with self.assertRaisesRegexp(ValueError, 'exactly one'): @@ -179,7 +179,7 @@ class GetSamplesTest(test.TestCase): def test_returns_z_if_z_provided(self): with self.test_session(): dist = distributions.Normal(mu=0., sigma=1.) - z = dist.sample_n(n=10) + z = dist.sample(10, seed=42) n = None seed = None z = monte_carlo._get_samples(dist, z, n, seed) diff --git a/tensorflow/contrib/bayesflow/python/ops/entropy.py b/tensorflow/contrib/bayesflow/python/ops/entropy.py index 80b35c59d2..56490c390c 100644 --- a/tensorflow/contrib/bayesflow/python/ops/entropy.py +++ b/tensorflow/contrib/bayesflow/python/ops/entropy.py @@ -143,7 +143,7 @@ def elbo_ratio(log_p, shape broadcastable to `q.batch_shape`. For example, `log_p` works "just like" `q.log_prob`. q: `tf.contrib.distributions.Distribution`. - z: `Tensor` of samples from `q`, produced by `q.sample_n`. + z: `Tensor` of samples from `q`, produced by `q.sample(n)` for some `n`. n: Integer `Tensor`. Number of samples to generate if `z` is not provided. seed: Python integer to seed the random number generator. form: Either `ELBOForms.analytic_entropy` (use formula for entropy of `q`) @@ -193,7 +193,7 @@ def entropy_shannon(p, Args: p: `tf.contrib.distributions.Distribution` - z: `Tensor` of samples from `p`, produced by `p.sample_n(n)` for some `n`. + z: `Tensor` of samples from `p`, produced by `p.sample(n)` for some `n`. n: Integer `Tensor`. Number of samples to generate if `z` is not provided. seed: Python integer to seed the random number generator. form: Either `ELBOForms.analytic_entropy` (use formula for entropy of `q`) @@ -326,7 +326,7 @@ def renyi_ratio(log_p, q, alpha, z=None, n=None, seed=None, name='renyi_ratio'): `float64` `dtype` recommended. `log_p` and `q` should be supported on the same set. alpha: `Tensor` with shape `q.batch_shape` and values not equal to 1. - z: `Tensor` of samples from `q`, produced by `q.sample_n`. 
+ z: `Tensor` of samples from `q`, produced by `q.sample(n)` for some `n`. n: Integer `Tensor`. The number of samples to use if `z` is not provided. Note that this can be highly biased for small `n`, see docstring. seed: Python integer to seed the random number generator. diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py index 8a03c348fa..198d755dff 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py @@ -118,7 +118,7 @@ def expectation_importance_sampler(f, `tf.contrib.distributions.Distribution`. `float64` `dtype` recommended. `log_p` and `q` should be supported on the same set. - z: `Tensor` of samples from `q`, produced by `q.sample_n`. + z: `Tensor` of samples from `q`, produced by `q.sample(n)` for some `n`. n: Integer `Tensor`. Number of samples to generate if `z` is not provided. seed: Python integer to seed the random number generator. name: A name to give this `Op`. @@ -195,7 +195,7 @@ def expectation_importance_sampler_logspace( `tf.contrib.distributions.Distribution`. `float64` `dtype` recommended. `log_p` and `q` should be supported on the same set. - z: `Tensor` of samples from `q`, produced by `q.sample_n`. + z: `Tensor` of samples from `q`, produced by `q.sample(n)` for some `n`. n: Integer `Tensor`. Number of samples to generate if `z` is not provided. seed: Python integer to seed the random number generator. name: A name to give this `Op`. @@ -254,7 +254,7 @@ def expectation(f, p, z=None, n=None, seed=None, name='expectation'): Args: f: Callable mapping samples from `p` to `Tensors`. p: `tf.contrib.distributions.Distribution`. - z: `Tensor` of samples from `p`, produced by `p.sample_n`. + z: `Tensor` of samples from `p`, produced by `p.sample(n)` for some `n`. n: Integer `Tensor`. Number of samples to generate if `z` is not provided. seed: Python integer to seed the random number generator. 
name: A name to give this `Op`. @@ -314,6 +314,6 @@ def _get_samples(dist, z, n, seed): 'Must specify exactly one of arguments "n" and "z". Found: ' 'n = %s, z = %s' % (n, z)) if n is not None: - return dist.sample_n(n=n, seed=seed) + return dist.sample(n, seed=seed) else: return ops.convert_to_tensor(z, name='z') diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py index e5a1de4bbb..4f22529d6b 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py @@ -18,22 +18,30 @@ from __future__ import print_function from tensorflow.contrib import distributions from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test -dists = distributions +ds = distributions class DistributionTest(test.TestCase): def testParamShapesAndFromParams(self): classes = [ - dists.Normal, dists.Bernoulli, dists.Beta, dists.Chi2, - dists.Exponential, dists.Gamma, dists.InverseGamma, dists.Laplace, - dists.StudentT, dists.Uniform + ds.Normal, + ds.Bernoulli, + ds.Beta, + ds.Chi2, + ds.Exponential, + ds.Gamma, + ds.InverseGamma, + ds.Laplace, + ds.StudentT, + ds.Uniform, ] sample_shapes = [(), (10,), (10, 20, 30)] @@ -55,15 +63,15 @@ class DistributionTest(test.TestCase): with self.test_session(): # Note: we cannot easily test all distributions since each requires # different initialization arguments. We therefore spot test a few. 
- normal = dists.Normal(mu=1., sigma=2., validate_args=True) + normal = ds.Normal(mu=1., sigma=2., validate_args=True) self.assertEqual(normal.parameters, normal.copy().parameters) - wishart = dists.WishartFull( - df=2, scale=[[1., 2], [2, 5]], validate_args=True) + wishart = ds.WishartFull(df=2, scale=[[1., 2], [2, 5]], + validate_args=True) self.assertEqual(wishart.parameters, wishart.copy().parameters) def testCopyOverride(self): with self.test_session(): - normal = dists.Normal(mu=1., sigma=2., validate_args=True) + normal = ds.Normal(mu=1., sigma=2., validate_args=True) normal_copy = normal.copy(validate_args=False) base_params = normal.parameters.copy() copy_params = normal.copy(validate_args=False).parameters.copy() @@ -76,19 +84,19 @@ class DistributionTest(test.TestCase): mu = 1. sigma = 2. - normal = dists.Normal(mu, sigma, validate_args=True) + normal = ds.Normal(mu, sigma, validate_args=True) self.assertTrue(tensor_util.constant_value(normal.is_scalar_event)) self.assertTrue(tensor_util.constant_value(normal.is_scalar_batch)) - normal = dists.Normal([mu], [sigma], validate_args=True) + normal = ds.Normal([mu], [sigma], validate_args=True) self.assertTrue(tensor_util.constant_value(normal.is_scalar_event)) self.assertFalse(tensor_util.constant_value(normal.is_scalar_batch)) - mvn = dists.MultivariateNormalDiag([mu], [sigma], validate_args=True) + mvn = ds.MultivariateNormalDiag([mu], [sigma], validate_args=True) self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event)) self.assertTrue(tensor_util.constant_value(mvn.is_scalar_batch)) - mvn = dists.MultivariateNormalDiag([[mu]], [[sigma]], validate_args=True) + mvn = ds.MultivariateNormalDiag([[mu]], [[sigma]], validate_args=True) self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event)) self.assertFalse(tensor_util.constant_value(mvn.is_scalar_batch)) @@ -117,6 +125,65 @@ class DistributionTest(test.TestCase): self.assertTrue(is_scalar.eval(feed_dict={x: 1})) 
self.assertFalse(is_scalar.eval(feed_dict={x: [1]})) + def testSampleShapeHints(self): + class _FakeDistribution(ds.Distribution): + """Fake Distribution for testing _set_sample_static_shape.""" + + def __init__(self, batch_shape=None, event_shape=None): + self._static_batch_shape = tensor_shape.TensorShape(batch_shape) + self._static_event_shape = tensor_shape.TensorShape(event_shape) + super(_FakeDistribution, self).__init__( + dtype=dtypes.float32, + is_continuous=False, + is_reparameterized=False, + validate_args=True, + allow_nan_stats=True, + name="DummyDistribution") + + def _get_batch_shape(self): + return self._static_batch_shape + + def _get_event_shape(self): + return self._static_event_shape + + with self.test_session(): + # Make a new session since we're playing with static shapes. [And below.] + x = array_ops.placeholder(dtype=dtypes.float32) + dist = _FakeDistribution(batch_shape=[2, 3], event_shape=[5]) + sample_shape = ops.convert_to_tensor([6, 7], dtype=dtypes.int32) + y = dist._set_sample_static_shape(x, sample_shape) + # We use as_list since TensorShape comparison does not work correctly for + # unknown values, ie, Dimension(None). 
+ self.assertAllEqual([6, 7, 2, 3, 5], y.get_shape().as_list()) + + with self.test_session(): + x = array_ops.placeholder(dtype=dtypes.float32) + dist = _FakeDistribution(batch_shape=[None, 3], event_shape=[5]) + sample_shape = ops.convert_to_tensor([6, 7], dtype=dtypes.int32) + y = dist._set_sample_static_shape(x, sample_shape) + self.assertAllEqual([6, 7, None, 3, 5], y.get_shape().as_list()) + + with self.test_session(): + x = array_ops.placeholder(dtype=dtypes.float32) + dist = _FakeDistribution(batch_shape=[None, 3], event_shape=[None]) + sample_shape = ops.convert_to_tensor([6, 7], dtype=dtypes.int32) + y = dist._set_sample_static_shape(x, sample_shape) + self.assertAllEqual([6, 7, None, 3, None], y.get_shape().as_list()) + + with self.test_session(): + x = array_ops.placeholder(dtype=dtypes.float32) + dist = _FakeDistribution(batch_shape=None, event_shape=None) + sample_shape = ops.convert_to_tensor([6, 7], dtype=dtypes.int32) + y = dist._set_sample_static_shape(x, sample_shape) + self.assertTrue(y.get_shape().ndims is None) + + with self.test_session(): + x = array_ops.placeholder(dtype=dtypes.float32) + dist = _FakeDistribution(batch_shape=[None, 3], event_shape=None) + sample_shape = ops.convert_to_tensor([6, 7], dtype=dtypes.int32) + y = dist._set_sample_static_shape(x, sample_shape) + self.assertTrue(y.get_shape().ndims is None) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py index 6e72f1ca31..aff34f48be 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py @@ -44,34 +44,35 @@ def _swap_first_last_axes(array): @contextlib.contextmanager def _test_capture_mvndiag_sample_outputs(): - """Use monkey-patching to capture the output of an MVNDiag sample_n.""" + """Use monkey-patching to capture the output of an 
MVNDiag _sample_n.""" data_container = [] - true_mvndiag_sample = distributions_py.MultivariateNormalDiag.sample_n + true_mvndiag_sample_n = distributions_py.MultivariateNormalDiag._sample_n - def _capturing_mvndiag_sample(self, n, seed=None, name="sample_n"): - samples = true_mvndiag_sample(self, n=n, seed=seed, name=name) + def _capturing_mvndiag_sample_n(self, n, seed=None): + samples = true_mvndiag_sample_n(self, n=n, seed=seed) data_container.append(samples) return samples - distributions_py.MultivariateNormalDiag.sample_n = _capturing_mvndiag_sample + distributions_py.MultivariateNormalDiag._sample_n = ( + _capturing_mvndiag_sample_n) yield data_container - distributions_py.MultivariateNormalDiag.sample_n = true_mvndiag_sample + distributions_py.MultivariateNormalDiag._sample_n = true_mvndiag_sample_n @contextlib.contextmanager def _test_capture_normal_sample_outputs(): - """Use monkey-patching to capture the output of an Normal sample_n.""" + """Use monkey-patching to capture the output of a Normal _sample_n.""" data_container = [] - true_normal_sample = distributions_py.Normal.sample_n + true_normal_sample_n = distributions_py.Normal._sample_n - def _capturing_normal_sample(self, n, seed=None, name="sample_n"): - samples = true_normal_sample(self, n=n, seed=seed, name=name) + def _capturing_normal_sample_n(self, n, seed=None): + samples = true_normal_sample_n(self, n=n, seed=seed) data_container.append(samples) return samples - distributions_py.Normal.sample_n = _capturing_normal_sample + distributions_py.Normal._sample_n = _capturing_normal_sample_n yield data_container - distributions_py.Normal.sample_n = true_normal_sample + distributions_py.Normal._sample_n = true_normal_sample_n def make_univariate_mixture(batch_shape, num_components): @@ -346,10 +347,10 @@ class MixtureTest(test.TestCase): batch_shape=[], num_components=num_components) n = 4 with _test_capture_normal_sample_outputs() as component_samples: - samples = dist.sample_n(n, seed=123) + 
samples = dist.sample(n, seed=123) self.assertEqual(samples.dtype, dtypes.float32) self.assertEqual((4,), samples.get_shape()) - cat_samples = dist.cat.sample_n(n, seed=123) + cat_samples = dist.cat.sample(n, seed=123) sample_values, cat_sample_values, dist_sample_values = sess.run( [samples, cat_samples, component_samples]) self.assertEqual((4,), sample_values.shape) @@ -379,7 +380,7 @@ class MixtureTest(test.TestCase): cat = distributions_py.Categorical( logits, dtype=dtypes.int32, name="cat1") dist1 = distributions_py.Mixture(cat, components, name="mixture1") - samples1 = dist1.sample_n(n, seed=123456).eval() + samples1 = dist1.sample(n, seed=123456).eval() random_seed.set_random_seed(654321) components2 = [ @@ -389,7 +390,7 @@ class MixtureTest(test.TestCase): cat2 = distributions_py.Categorical( logits, dtype=dtypes.int32, name="cat2") dist2 = distributions_py.Mixture(cat2, components2, name="mixture2") - samples2 = dist2.sample_n(n, seed=123456).eval() + samples2 = dist2.sample(n, seed=123456).eval() self.assertAllClose(samples1, samples2) @@ -400,10 +401,10 @@ class MixtureTest(test.TestCase): batch_shape=[], num_components=num_components, event_shape=[2]) n = 4 with _test_capture_mvndiag_sample_outputs() as component_samples: - samples = dist.sample_n(n, seed=123) + samples = dist.sample(n, seed=123) self.assertEqual(samples.dtype, dtypes.float32) self.assertEqual((4, 2), samples.get_shape()) - cat_samples = dist.cat.sample_n(n, seed=123) + cat_samples = dist.cat.sample(n, seed=123) sample_values, cat_sample_values, dist_sample_values = sess.run( [samples, cat_samples, component_samples]) self.assertEqual((4, 2), sample_values.shape) @@ -421,10 +422,10 @@ class MixtureTest(test.TestCase): batch_shape=[2, 3], num_components=num_components) n = 4 with _test_capture_normal_sample_outputs() as component_samples: - samples = dist.sample_n(n, seed=123) + samples = dist.sample(n, seed=123) self.assertEqual(samples.dtype, dtypes.float32) self.assertEqual((4, 2, 3), 
samples.get_shape()) - cat_samples = dist.cat.sample_n(n, seed=123) + cat_samples = dist.cat.sample(n, seed=123) sample_values, cat_sample_values, dist_sample_values = sess.run( [samples, cat_samples, component_samples]) self.assertEqual((4, 2, 3), sample_values.shape) @@ -444,10 +445,10 @@ class MixtureTest(test.TestCase): batch_shape=[2, 3], num_components=num_components, event_shape=[4]) n = 5 with _test_capture_mvndiag_sample_outputs() as component_samples: - samples = dist.sample_n(n, seed=123) + samples = dist.sample(n, seed=123) self.assertEqual(samples.dtype, dtypes.float32) self.assertEqual((5, 2, 3, 4), samples.get_shape()) - cat_samples = dist.cat.sample_n(n, seed=123) + cat_samples = dist.cat.sample(n, seed=123) sample_values, cat_sample_values, dist_sample_values = sess.run( [samples, cat_samples, component_samples]) self.assertEqual((5, 2, 3, 4), sample_values.shape) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py index 06977350e7..1fd35d682f 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py @@ -179,7 +179,7 @@ class QuantizedDistributionTest(test.TestCase): qdist = distributions.QuantizedDistribution( distribution=normal, lower_cutoff=0., upper_cutoff=None) - samps = qdist.sample_n(n=5000, seed=42) + samps = qdist.sample(5000, seed=42) samps_v = samps.eval() # With lower_cutoff = 0, the interval j=0 is (-infty, 0], which holds 1/2 @@ -207,7 +207,7 @@ class QuantizedDistributionTest(test.TestCase): qdist = distributions.QuantizedDistribution( distribution=distributions.Exponential(lam=0.01)) # X ~ QuantizedExponential - x = qdist.sample_n(n=10000, seed=42) + x = qdist.sample(10000, seed=42) # Z = F(X), should be Uniform. z = qdist.cdf(x) # Compare the CDF of Z to that of a Uniform. 
@@ -419,7 +419,7 @@ class QuantizedDistributionTest(test.TestCase): self.assertEqual((), qdist.get_event_shape()) self.assertAllEqual((), qdist.event_shape().eval()) - samps = qdist.sample_n(n=10) + samps = qdist.sample(10, seed=42) self.assertEqual((10,) + batch_shape, samps.get_shape()) self.assertAllEqual((10,) + batch_shape, samps.eval().shape) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py b/tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py index 8596c3246b..73e01f9fef 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py @@ -245,30 +245,11 @@ class StudentTTest(test.TestCase): self.assertEqual(student.entropy().get_shape(), (3,)) self.assertEqual(student.log_pdf(2.).get_shape(), (3,)) self.assertEqual(student.pdf(2.).get_shape(), (3,)) - self.assertEqual( - student.sample( - 37, seed=123456).get_shape(), ( - 37, - 3,)) - - _check(ds.StudentT( - df=[ - 2., - 3., - 4., - ], mu=2., sigma=1.)) - _check(ds.StudentT( - df=7., mu=[ - 2., - 3., - 4., - ], sigma=1.)) - _check(ds.StudentT( - df=7., mu=3., sigma=[ - 2., - 3., - 4., - ])) + self.assertEqual(student.sample(37, seed=123456).get_shape(), (37, 3,)) + + _check(ds.StudentT(df=[2., 3., 4.,], mu=2., sigma=1.)) + _check(ds.StudentT(df=7., mu=[2., 3., 4.,], sigma=1.)) + _check(ds.StudentT(df=7., mu=3., sigma=[2., 3., 4.,])) def testBroadcastingPdfArgs(self): @@ -285,24 +266,9 @@ class StudentTTest(test.TestCase): xs = xs.T _assert_shape(student, xs, (3, 3)) - _check(ds.StudentT( - df=[ - 2., - 3., - 4., - ], mu=2., sigma=1.)) - _check(ds.StudentT( - df=7., mu=[ - 2., - 3., - 4., - ], sigma=1.)) - _check(ds.StudentT( - df=7., mu=3., sigma=[ - 2., - 3., - 4., - ])) + _check(ds.StudentT(df=[2., 3., 4.,], mu=2., sigma=1.)) + _check(ds.StudentT(df=7., mu=[2., 3., 4.,], sigma=1.)) + _check(ds.StudentT(df=7., mu=3., sigma=[2., 3., 4.,])) def 
_check2d(student): _assert_shape(student, 2., (1, 3)) @@ -313,24 +279,9 @@ class StudentTTest(test.TestCase): xs = xs.T _assert_shape(student, xs, (3, 3)) - _check2d(ds.StudentT( - df=[[ - 2., - 3., - 4., - ]], mu=2., sigma=1.)) - _check2d(ds.StudentT( - df=7., mu=[[ - 2., - 3., - 4., - ]], sigma=1.)) - _check2d(ds.StudentT( - df=7., mu=3., sigma=[[ - 2., - 3., - 4., - ]])) + _check2d(ds.StudentT(df=[[2., 3., 4.,]], mu=2., sigma=1.)) + _check2d(ds.StudentT(df=7., mu=[[2., 3., 4.,]], sigma=1.)) + _check2d(ds.StudentT(df=7., mu=3., sigma=[[2., 3., 4.,]])) def _check2d_rows(student): _assert_shape(student, 2., (3, 1)) @@ -355,8 +306,8 @@ class StudentTTest(test.TestCase): def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self): with self.test_session(): mu = [1., 3.3, 4.4] - student = ds.StudentT( - df=[0.5, 5., 7.], mu=mu, sigma=[3., 2., 1.], allow_nan_stats=False) + student = ds.StudentT(df=[0.5, 5., 7.], mu=mu, sigma=[3., 2., 1.], + allow_nan_stats=False) with self.assertRaisesOpError("x < y"): student.mean().eval() @@ -364,8 +315,8 @@ class StudentTTest(test.TestCase): with self.test_session(): mu = [-2, 0., 1., 3.3, 4.4] sigma = [5., 4., 3., 2., 1.] 
- student = ds.StudentT( - df=[0.5, 1., 3., 5., 7.], mu=mu, sigma=sigma, allow_nan_stats=True) + student = ds.StudentT(df=[0.5, 1., 3., 5., 7.], mu=mu, sigma=sigma, + allow_nan_stats=True) mean = student.mean().eval() self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean) @@ -503,15 +454,15 @@ class StudentTTest(test.TestCase): def testNegativeDofFails(self): with self.test_session(): - student = ds.StudentT( - df=[2, -5.], mu=0., sigma=1., validate_args=True, name="S") + student = ds.StudentT(df=[2, -5.], mu=0., sigma=1., + validate_args=True, name="S") with self.assertRaisesOpError(r"Condition x > 0 did not hold"): student.mean().eval() def testNegativeScaleFails(self): with self.test_session(): - student = ds.StudentT( - df=[5.], mu=0., sigma=[[3.], [-2.]], validate_args=True, name="S") + student = ds.StudentT(df=[5.], mu=0., sigma=[[3.], [-2.]], + validate_args=True, name="S") with self.assertRaisesOpError(r"Condition x > 0 did not hold"): student.mean().eval() diff --git a/tensorflow/contrib/distributions/python/ops/distribution.py b/tensorflow/contrib/distributions/python/ops/distribution.py index f0519ca79d..eb1c985290 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution.py +++ b/tensorflow/contrib/distributions/python/ops/distribution.py @@ -39,7 +39,7 @@ from tensorflow.python.ops import math_ops _DISTRIBUTION_PUBLIC_METHOD_WRAPPERS = [ "batch_shape", "get_batch_shape", "event_shape", "get_event_shape", - "sample_n", "log_prob", "prob", "log_cdf", "cdf", "log_survival_function", + "sample", "log_prob", "prob", "log_cdf", "cdf", "log_survival_function", "survival_function", "entropy", "mean", "variance", "std", "mode"] @@ -562,61 +562,14 @@ class Distribution(_BaseDistribution): with self._name_scope(name, values=[sample_shape]): sample_shape = ops.convert_to_tensor( sample_shape, dtype=dtypes.int32, name="sample_shape") - if sample_shape.get_shape().ndims == 0: - return self.sample_n(sample_shape, seed, **condition_kwargs) - 
sample_shape, total = self._expand_sample_shape(sample_shape) - samples = self.sample_n(total, seed, **condition_kwargs) - output_shape = array_ops.concat_v2( - [sample_shape, array_ops.slice(array_ops.shape(samples), [1], [-1])], - 0) - output = array_ops.reshape(samples, output_shape) - output.set_shape(tensor_util.constant_value_as_shape( - sample_shape).concatenate(samples.get_shape()[1:])) - return output - - def sample_n(self, n, seed=None, name="sample_n", **condition_kwargs): - """Generate `n` samples. - - Args: - n: `Scalar` `Tensor` of type `int32` or `int64`, the number of - observations to sample. - seed: Python integer seed for RNG - name: name to give to the op. - **condition_kwargs: Named arguments forwarded to subclass implementation. - - Returns: - samples: a `Tensor` with a prepended dimension (n,). - - Raises: - TypeError: if `n` is not an integer type. - """ - warnings.warn("Please use `sample` instead of `sample_n`. `sample_n` " - "will be deprecated in December 2016.", - PendingDeprecationWarning) - with self._name_scope(name, values=[n]): - n = ops.convert_to_tensor(n, name="n") - if not n.dtype.is_integer: - raise TypeError("n.dtype=%s is not an integer type" % n.dtype) - x = self._sample_n(n, seed, **condition_kwargs) - - # Set shape hints. 
- sample_shape = tensor_shape.TensorShape( - tensor_util.constant_value(n)) - batch_ndims = self.get_batch_shape().ndims - event_ndims = self.get_event_shape().ndims - if batch_ndims is not None and event_ndims is not None: - inferred_shape = sample_shape.concatenate( - self.get_batch_shape().concatenate( - self.get_event_shape())) - x.set_shape(inferred_shape) - elif x.get_shape().ndims is not None and x.get_shape().ndims > 0: - x.get_shape()[0].merge_with(sample_shape[0]) - if batch_ndims is not None and batch_ndims > 0: - x.get_shape()[1:1+batch_ndims].merge_with(self.get_batch_shape()) - if event_ndims is not None and event_ndims > 0: - x.get_shape()[-event_ndims:].merge_with(self.get_event_shape()) - - return x + sample_shape, n = self._expand_sample_shape_to_vector( + sample_shape, "sample_shape") + samples = self._sample_n(n, seed, **condition_kwargs) + batch_event_shape = array_ops.shape(samples)[1:] + final_shape = array_ops.concat_v2([sample_shape, batch_event_shape], 0) + samples = array_ops.reshape(samples, final_shape) + samples = self._set_sample_static_shape(samples, sample_shape) + return samples def _log_prob(self, value): raise NotImplementedError("log_prob is not implemented") @@ -938,33 +891,79 @@ class Distribution(_BaseDistribution): (values or []) + self._graph_parents)) as scope: yield scope - def _expand_sample_shape(self, sample_shape): - """Helper to `sample` which ensures sample_shape is 1D.""" - sample_shape_static_val = tensor_util.constant_value(sample_shape) - ndims = sample_shape.get_shape().ndims - if sample_shape_static_val is None: - if ndims is None or not sample_shape.get_shape().is_fully_defined(): - ndims = array_ops.rank(sample_shape) + def _expand_sample_shape_to_vector(self, x, name): + """Helper to `sample` which ensures input is 1D.""" + x_static_val = tensor_util.constant_value(x) + if x_static_val is None: + prod = math_ops.reduce_prod(x) + else: + prod = np.prod(x_static_val, dtype=x.dtype.as_numpy_dtype()) + + ndims 
= x.get_shape().ndims # != sample_ndims + if ndims is None: + # Maybe expand_dims. + ndims = array_ops.rank(x) expanded_shape = distribution_util.pick_vector( math_ops.equal(ndims, 0), - np.array((1,), dtype=dtypes.int32.as_numpy_dtype()), - array_ops.shape(sample_shape)) - sample_shape = array_ops.reshape(sample_shape, expanded_shape) - total = math_ops.reduce_prod(sample_shape) # reduce_prod([]) == 1 - else: - if ndims is None: - raise ValueError( - "Shouldn't be here; ndims cannot be none when we have a " - "tf.constant shape.") - if ndims == 0: - sample_shape_static_val = np.reshape(sample_shape_static_val, [1]) - sample_shape = ops.convert_to_tensor( - sample_shape_static_val, - dtype=dtypes.int32, - name="sample_shape") - total = np.prod(sample_shape_static_val, - dtype=dtypes.int32.as_numpy_dtype()) - return sample_shape, total + np.array([1], dtype=np.int32), + array_ops.shape(x)) + x = array_ops.reshape(x, expanded_shape) + elif ndims == 0: + # Definitely expand_dims. + if x_static_val is not None: + x = ops.convert_to_tensor( + np.array([x_static_val], dtype=x.dtype.as_numpy_dtype()), + name=name) + else: + x = array_ops.reshape(x, [1]) + elif ndims != 1: + raise ValueError("Input is neither scalar nor vector.") + + return x, prod + + def _set_sample_static_shape(self, x, sample_shape): + """Helper to `sample`; sets static shape info.""" + # Set shape hints. + sample_shape = tensor_shape.TensorShape( + tensor_util.constant_value(sample_shape)) + + ndims = x.get_shape().ndims + sample_ndims = sample_shape.ndims + batch_ndims = self.get_batch_shape().ndims + event_ndims = self.get_event_shape().ndims + + # Infer rank(x). + if (ndims is None and + sample_ndims is not None and + batch_ndims is not None and + event_ndims is not None): + ndims = sample_ndims + batch_ndims + event_ndims + x.set_shape([None] * ndims) + + # Infer sample shape. 
+ if ndims is not None and sample_ndims is not None: + shape = sample_shape.concatenate([None]*(ndims - sample_ndims)) + x.set_shape(x.get_shape().merge_with(shape)) + + # Infer event shape. + if ndims is not None and event_ndims is not None: + shape = tensor_shape.TensorShape( + [None]*(ndims - event_ndims)).concatenate(self.get_event_shape()) + x.set_shape(x.get_shape().merge_with(shape)) + + # Infer batch shape. + if batch_ndims is not None: + if ndims is not None: + if sample_ndims is None and event_ndims is not None: + sample_ndims = ndims - batch_ndims - event_ndims + elif event_ndims is None and sample_ndims is not None: + event_ndims = ndims - batch_ndims - sample_ndims + if sample_ndims is not None and event_ndims is not None: + shape = tensor_shape.TensorShape([None]*sample_ndims).concatenate( + self.get_batch_shape()).concatenate([None]*event_ndims) + x.set_shape(x.get_shape().merge_with(shape)) + + return x def _is_scalar_helper(self, static_shape_fn, dynamic_shape_fn): """Implementation for `is_scalar_batch` and `is_scalar_event`.""" diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py index 95ad9fe06a..0e98e9e3b0 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture.py +++ b/tensorflow/contrib/distributions/python/ops/mixture.py @@ -244,7 +244,7 @@ class Mixture(distribution.Distribution): n = ops.convert_to_tensor(n, name="n") static_n = tensor_util.constant_value(n) n = int(static_n) if static_n is not None else n - cat_samples = self.cat.sample_n(n, seed=seed) + cat_samples = self.cat.sample(n, seed=seed) static_samples_shape = cat_samples.get_shape() if static_samples_shape.is_fully_defined(): @@ -308,7 +308,7 @@ class Mixture(distribution.Distribution): for c in range(self.num_components): n_class = array_ops.size(partitioned_samples_indices[c]) seed = distribution_util.gen_new_seed(seed, "mixture") - samples_class_c = self.components[c].sample_n(n_class, 
seed=seed) + samples_class_c = self.components[c].sample(n_class, seed=seed) # Pull out the correct batch entries from each index. # To do this, we may have to flatten the batch shape. diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py index fd3ec553c0..713f09ccf8 100644 --- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py @@ -277,7 +277,7 @@ class QuantizedDistribution(distributions.Distribution): upper_cutoff = self._upper_cutoff with ops.name_scope("transform"): n = ops.convert_to_tensor(n, name="n") - x_samps = self.distribution.sample_n(n=n, seed=seed) + x_samps = self.distribution.sample(n, seed=seed) ones = array_ops.ones_like(x_samps) # Snap values to the intervals (j - 1, j]. -- cgit v1.2.3