author:    2017-10-03 11:35:23 -0700
committer: 2017-10-03 11:38:44 -0700
commit:    0e286d372b9c04e7db62fa88695282cc0a0d61d9 (patch)
tree:      233c14dbfb9780c56d40f916450a2398ed00d063 /tensorflow
parent:    7020f17de9eba436425c7fb61a2a026bdf80ed4f (diff)
Bugfix: tf.random_gamma incorrectly handles non-batch, scalar draws.
PiperOrigin-RevId: 170887206
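For context, a minimal reproduction of the failure mode (a sketch against the TF 1.x Python API; the `shape=[0]` trigger is illustrative, not taken from this commit): a zero-element gamma draw used to hit the kernel's early return before the output tensor was allocated, leaving the op with no output at all. After this fix it simply yields an empty tensor.

    import tensorflow as tf

    with tf.Session() as sess:
      # Zero-element sample shape: shape=[0] requests no draws at all.
      # Before this fix, RandomGammaOp returned before allocating its
      # output; after it, the op returns an empty Tensor of shape (0,).
      draws = tf.random_gamma(shape=[0], alpha=1.0)
      print(sess.run(draws).shape)  # (0,)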
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py | 134
-rw-r--r--  tensorflow/core/kernels/random_op.cc                                  |   3
2 files changed, 76 insertions, 61 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
index 61c2185e86..1e514fe0ff 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
@@ -38,7 +38,7 @@
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging as logging

-distributions_py = distributions
+ds = distributions


 def _swap_first_last_axes(array):
@@ -74,7 +74,7 @@
   """Use monkey-patching to capture the output of an MVNDiag _call_sample_n."""
   data_container = []
   true_mvndiag_call_sample_n = (
-      distributions_py.MultivariateNormalDiag._call_sample_n)
+      ds.MultivariateNormalDiag._call_sample_n)

   def _capturing_mvndiag_call_sample_n(
       self, sample_shape, seed, name, **kwargs):
@@ -83,10 +83,10 @@ def _test_capture_mvndiag_sample_outputs():
     data_container.append(samples)
     return samples

-  distributions_py.MultivariateNormalDiag._call_sample_n = (
+  ds.MultivariateNormalDiag._call_sample_n = (
       _capturing_mvndiag_call_sample_n)
   yield data_container
-  distributions_py.MultivariateNormalDiag._call_sample_n = (
+  ds.MultivariateNormalDiag._call_sample_n = (
       true_mvndiag_call_sample_n)
@@ -94,7 +94,7 @@ def _test_capture_mvndiag_sample_outputs():
 def _test_capture_normal_sample_outputs():
   """Use monkey-patching to capture the output of an Normal _call_sample_n."""
   data_container = []
-  true_normal_call_sample_n = distributions_py.Normal._call_sample_n
+  true_normal_call_sample_n = ds.Normal._call_sample_n

   def _capturing_normal_call_sample_n(self, sample_shape, seed, name, **kwargs):
     samples = true_normal_call_sample_n(
@@ -102,9 +102,9 @@
     data_container.append(samples)
     return samples

-  distributions_py.Normal._call_sample_n = _capturing_normal_call_sample_n
+  ds.Normal._call_sample_n = _capturing_normal_call_sample_n
   yield data_container
-  distributions_py.Normal._call_sample_n = true_normal_call_sample_n
+  ds.Normal._call_sample_n = true_normal_call_sample_n


 def make_univariate_mixture(batch_shape, num_components):
@@ -113,13 +113,13 @@
       array_ops.concat((batch_shape, [num_components]), axis=0),
       -1, 1, dtype=dtypes.float32) - 50.
   components = [
-      distributions_py.Normal(
+      ds.Normal(
           loc=random_ops.random_normal(batch_shape),
           scale=10 * random_ops.random_uniform(batch_shape))
       for _ in range(num_components)
   ]
-  cat = distributions_py.Categorical(logits, dtype=dtypes.int32)
-  return distributions_py.Mixture(cat, components)
+  cat = ds.Categorical(logits, dtype=dtypes.int32)
+  return ds.Mixture(cat, components)


 def make_multivariate_mixture(batch_shape, num_components, event_shape,
@@ -141,11 +141,11 @@
     scale_diag = 10 * random_ops.random_uniform(batch_and_event_shape)
     loc.set_shape(static_batch_and_event_shape)
     scale_diag.set_shape(static_batch_and_event_shape)
-    return distributions_py.MultivariateNormalDiag(
+    return ds.MultivariateNormalDiag(
         loc=loc, scale_diag=scale_diag)
   components = [create_component() for _ in range(num_components)]
-  cat = distributions_py.Categorical(logits, dtype=dtypes.int32)
-  return distributions_py.Mixture(cat, components)
+  cat = ds.Categorical(logits, dtype=dtypes.int32)
+  return ds.Mixture(cat, components)


 class MixtureTest(test.TestCase):
@@ -170,37 +170,37 @@
   def testBrokenShapesStatic(self):
     with self.assertRaisesWithPredicateMatch(ValueError,
                                              r"cat.num_classes != len"):
-      distributions_py.Mixture(
-          distributions_py.Categorical([0.1, 0.5]),  # 2 classes
-          [distributions_py.Normal(loc=1.0, scale=2.0)])
+      ds.Mixture(
+          ds.Categorical([0.1, 0.5]),  # 2 classes
+          [ds.Normal(loc=1.0, scale=2.0)])
     with self.assertRaisesWithPredicateMatch(
         ValueError, r"\(\) and \(2,\) are not compatible"):
       # The value error is raised because the batch shapes of the
       # Normals are not equal. One is a scalar, the other is a
       # vector of size (2,).
-      distributions_py.Mixture(
-          distributions_py.Categorical([-0.5, 0.5]),  # scalar batch
+      ds.Mixture(
+          ds.Categorical([-0.5, 0.5]),  # scalar batch
           [
-              distributions_py.Normal(
+              ds.Normal(
                   loc=1.0, scale=2.0),  # scalar dist
-              distributions_py.Normal(
+              ds.Normal(
                   loc=[1.0, 1.0], scale=[2.0, 2.0])
           ])
     with self.assertRaisesWithPredicateMatch(ValueError, r"Could not infer"):
       cat_logits = array_ops.placeholder(shape=[1, None], dtype=dtypes.float32)
-      distributions_py.Mixture(
-          distributions_py.Categorical(cat_logits),
-          [distributions_py.Normal(
+      ds.Mixture(
+          ds.Categorical(cat_logits),
+          [ds.Normal(
               loc=[1.0], scale=[2.0])])

   def testBrokenShapesDynamic(self):
     with self.test_session():
       d0_param = array_ops.placeholder(dtype=dtypes.float32)
       d1_param = array_ops.placeholder(dtype=dtypes.float32)
-      d = distributions_py.Mixture(
-          distributions_py.Categorical([0.1, 0.2]), [
-              distributions_py.Normal(
-                  loc=d0_param, scale=d0_param), distributions_py.Normal(
+      d = ds.Mixture(
+          ds.Categorical([0.1, 0.2]), [
+              ds.Normal(
+                  loc=d0_param, scale=d0_param), ds.Normal(
                   loc=d1_param, scale=d1_param)
           ],
           validate_args=True)
@@ -211,21 +211,21 @@
   def testBrokenTypes(self):
     with self.assertRaisesWithPredicateMatch(TypeError, "Categorical"):
-      distributions_py.Mixture(None, [])
-    cat = distributions_py.Categorical([0.3, 0.2])
+      ds.Mixture(None, [])
+    cat = ds.Categorical([0.3, 0.2])
     # components must be a list of distributions
     with self.assertRaisesWithPredicateMatch(
         TypeError, "all .* must be Distribution instances"):
-      distributions_py.Mixture(cat, [None])
+      ds.Mixture(cat, [None])
     with self.assertRaisesWithPredicateMatch(TypeError, "same dtype"):
-      distributions_py.Mixture(
+      ds.Mixture(
           cat, [
-              distributions_py.Normal(loc=[1.0], scale=[2.0]),
-              distributions_py.Normal(loc=[np.float16(1.0)],
-                                      scale=[np.float16(2.0)]),
+              ds.Normal(loc=[1.0], scale=[2.0]),
+              ds.Normal(loc=[np.float16(1.0)],
+                        scale=[np.float16(2.0)]),
          ])
     with self.assertRaisesWithPredicateMatch(ValueError, "non-empty list"):
-      distributions_py.Mixture(distributions_py.Categorical([0.3, 0.2]), None)
+      ds.Mixture(ds.Categorical([0.3, 0.2]), None)

     # TODO(ebrevdo): once distribution Domains have been added, add a
     # test to ensure that the domains of the distributions in a
@@ -364,13 +364,13 @@
     component_devs = np.array([0.05, 2.33])
     ground_truth_stddev = 5.3120805

-    mixture_dist = distributions_py.Mixture(
-        cat=distributions_py.Categorical(probs=cat_probs),
+    mixture_dist = ds.Mixture(
+        cat=ds.Categorical(probs=cat_probs),
         components=[
-            distributions_py.Normal(loc=component_means[0],
-                                    scale=component_devs[0]),
-            distributions_py.Normal(loc=component_means[1],
-                                    scale=component_devs[1]),
+            ds.Normal(loc=component_means[0],
+                      scale=component_devs[0]),
+            ds.Normal(loc=component_means[1],
+                      scale=component_devs[1]),
         ])
     mix_dev = mixture_dist.stddev()
     with self.test_session() as sess:
@@ -517,22 +517,22 @@
       random_seed.set_random_seed(654321)
       components = [
-          distributions_py.Normal(
+          ds.Normal(
              loc=mu, scale=sigma) for mu, sigma in zip(mus, sigmas)
       ]
-      cat = distributions_py.Categorical(
+      cat = ds.Categorical(
           logits, dtype=dtypes.int32, name="cat1")
-      dist1 = distributions_py.Mixture(cat, components, name="mixture1")
+      dist1 = ds.Mixture(cat, components, name="mixture1")
       samples1 = dist1.sample(n, seed=123456).eval()

       random_seed.set_random_seed(654321)
       components2 = [
-          distributions_py.Normal(
+          ds.Normal(
              loc=mu, scale=sigma) for mu, sigma in zip(mus, sigmas)
       ]
-      cat2 = distributions_py.Categorical(
+      cat2 = ds.Categorical(
           logits, dtype=dtypes.int32, name="cat2")
-      dist2 = distributions_py.Mixture(cat2, components2, name="mixture2")
+      dist2 = ds.Mixture(cat2, components2, name="mixture2")
       samples2 = dist2.sample(n, seed=123456).eval()

       self.assertAllClose(samples1, samples2)
@@ -665,15 +665,15 @@
       e_x = np.exp(x - np.max(x))
       return e_x / e_x.sum()

-    # Construct the distributions_py.Mixture object.
+    # Construct the ds.Mixture object.
     mixture_weights = _scalar_univariate_softmax(mixture_weight_logits)
     means = [np.random.uniform(low=-10, high=10, size=()).astype(np.float32)
             for _ in range(n_components)]
     sigmas = [np.ones(shape=(), dtype=np.float32) for _ in range(n_components)]
-    cat_tf = distributions_py.Categorical(probs=mixture_weights)
-    components_tf = [distributions_py.Normal(loc=mu, scale=sigma)
+    cat_tf = ds.Categorical(probs=mixture_weights)
+    components_tf = [ds.Normal(loc=mu, scale=sigma)
                      for (mu, sigma) in zip(means, sigmas)]
-    mixture_tf = distributions_py.Mixture(cat=cat_tf, components=components_tf)
+    mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf)

     x_tensor = array_ops.placeholder(shape=(), dtype=dtypes.float32)
@@ -718,10 +718,10 @@
             for _ in range(n_components)]
     sigmas = [np.ones(shape=psize, dtype=np.float32)
              for _ in range(n_components)]
-    cat_tf = distributions_py.Categorical(probs=mixture_weights)
-    components_tf = [distributions_py.Normal(loc=mu, scale=sigma)
+    cat_tf = ds.Categorical(probs=mixture_weights)
+    components_tf = [ds.Normal(loc=mu, scale=sigma)
                      for (mu, sigma) in zip(means, sigmas)]
-    mixture_tf = distributions_py.Mixture(cat=cat_tf, components=components_tf)
+    mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf)

     x_tensor = array_ops.placeholder(shape=psize, dtype=dtypes.float32)

     xs_to_check = [
@@ -750,6 +750,20 @@
       self.assertAllClose(x_cdf_tf_result, scipy_cdf_result)
       self.assertAllClose(np.exp(x_log_cdf_tf_result), scipy_cdf_result)

+  def testSampleBimixGamma(self):
+    """Tests a bug in the underlying tf.Gamma op.
+
+    Mixture's use of dynamic partition requires `random_gamma` correctly returns
+    an empty `Tensor`.
+    """
+    with self.test_session():
+      gm = ds.Mixture(
+          cat=ds.Categorical(probs=[.3, .7]),
+          components=[ds.Gamma(1., 2.),
+                      ds.Gamma(2., 1.)])
+      x_ = gm.sample().eval()
+      self.assertAllEqual([], x_.shape)
+

 class MixtureBenchmark(test.Benchmark):
@@ -784,7 +798,7 @@
         2, "mvn_diag\tuse_gpu\tcomponents\tbatch\tfeatures\tsample\twall_time")

     def create_distribution(batch_size, num_components, num_features):
-      cat = distributions_py.Categorical(
+      cat = ds.Categorical(
           logits=np.random.randn(batch_size, num_components))
       mus = [
           variables.Variable(np.random.randn(batch_size, num_features))
           for _ in range(num_components)
       ]
@@ -795,9 +809,9 @@
           for _ in range(num_components)
       ]
       components = list(
-          distributions_py.MultivariateNormalDiag(
+          ds.MultivariateNormalDiag(
              loc=mu, scale_diag=sigma) for (mu, sigma) in zip(mus, sigmas))
-      return distributions_py.Mixture(cat, components)
+      return ds.Mixture(cat, components)

     for use_gpu in False, True:
       if use_gpu and not test.is_gpu_available():
@@ -824,7 +838,7 @@
       return np.stack([np.dot(np.transpose(z), z) for z in x])

     def create_distribution(batch_size, num_components, num_features):
-      cat = distributions_py.Categorical(
+      cat = ds.Categorical(
           logits=np.random.randn(batch_size, num_components))
       mus = [
           variables.Variable(np.random.randn(batch_size, num_features))
           for _ in range(num_components)
       ]
@@ -836,10 +850,10 @@
           for _ in range(num_components)
       ]
       components = list(
-          distributions_py.MultivariateNormalTriL(
+          ds.MultivariateNormalTriL(
              loc=mu, scale_tril=linalg_ops.cholesky(sigma))
           for (mu, sigma) in zip(mus, sigmas))
-      return distributions_py.Mixture(cat, components)
+      return ds.Mixture(cat, components)

     for use_gpu in False, True:
       if use_gpu and not test.is_gpu_available():
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
index e78f8e2621..a37c757865 100644
--- a/tensorflow/core/kernels/random_op.cc
+++ b/tensorflow/core/kernels/random_op.cc
@@ -288,13 +288,14 @@ class RandomGammaOp : public OpKernel {
                                                       &samples_shape));
     }
     const int64 num_samples = samples_shape.num_elements();
-    if (num_samples == 0) return;

     samples_shape.AppendShape(alpha_t.shape());

     // Allocate output samples.
     Tensor* samples_t = nullptr;
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, samples_shape, &samples_t));

+    if (num_samples == 0) return;
+
     using random::PhiloxRandom;
     typedef random::NormalDistribution<PhiloxRandom, double> Normal;
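The kernel change moves the `num_samples == 0` early return below the `allocate_output` call, so a zero-element request still produces a valid, empty output tensor. The new test explains why Mixture cares: it splits its draws across components with dynamic partition, so a component can legitimately be asked for zero samples. A rough sketch of that pattern (hand-written here for illustration against the TF 1.x API, not the actual code in `Mixture._sample_n`):

    import tensorflow as tf

    with tf.Session() as sess:
      # One draw, two components: the categorical draw picks component 1,
      # so component 0's partition is empty.
      cat_draws = tf.constant([1])
      parts = tf.dynamic_partition(tf.range(1), cat_draws, num_partitions=2)
      n0 = tf.size(parts[0])  # zero samples requested from component 0
      # The gamma draw for component 0 must come back as an empty Tensor
      # instead of leaving the op's output unset.
      draws0 = tf.random_gamma(shape=[n0], alpha=1.0)
      print(sess.run(draws0).shape)  # (0,)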