author    Joshua V. Dillon <jvdillon@google.com>            2017-10-03 11:35:23 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-10-03 11:38:44 -0700
commit    0e286d372b9c04e7db62fa88695282cc0a0d61d9
tree      233c14dbfb9780c56d40f916450a2398ed00d063 /tensorflow
parent    7020f17de9eba436425c7fb61a2a026bdf80ed4f
Bugfix: tf.random_gamma incorrectly handles non-batch, scalar draws.
PiperOrigin-RevId: 170887206
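
Context: before this change, the RandomGamma kernel returned early when the requested
sample shape had zero elements, so no output tensor was ever allocated. Consumers that
legitimately request zero samples then failed; Mixture.sample is one such consumer,
because its dynamic_partition-based sampling can route zero draws to a component. A
minimal sketch of the previously-failing pattern, paraphrasing the new test below
(assumes the TF 1.x tf.contrib.distributions API):

    import tensorflow as tf

    ds = tf.contrib.distributions

    # A scalar (non-batch) mixture of two Gammas. A single scalar draw is
    # routed to exactly one component, so the other component's random_gamma
    # call must return an empty Tensor -- the case this commit fixes.
    gm = ds.Mixture(
        cat=ds.Categorical(probs=[.3, .7]),
        components=[ds.Gamma(1., 2.), ds.Gamma(2., 1.)])

    with tf.Session() as sess:
      x = sess.run(gm.sample())  # a scalar sample; failed before this fix
      print(x.shape)             # ()
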
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py | 134
-rw-r--r--  tensorflow/core/kernels/random_op.cc                                  |   3
2 files changed, 76 insertions(+), 61 deletions(-)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
index 61c2185e86..1e514fe0ff 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
@@ -38,7 +38,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
-distributions_py = distributions
+ds = distributions
def _swap_first_last_axes(array):
@@ -74,7 +74,7 @@ def _test_capture_mvndiag_sample_outputs():
"""Use monkey-patching to capture the output of an MVNDiag _call_sample_n."""
data_container = []
true_mvndiag_call_sample_n = (
- distributions_py.MultivariateNormalDiag._call_sample_n)
+ ds.MultivariateNormalDiag._call_sample_n)
def _capturing_mvndiag_call_sample_n(
self, sample_shape, seed, name, **kwargs):
@@ -83,10 +83,10 @@ def _test_capture_mvndiag_sample_outputs():
data_container.append(samples)
return samples
- distributions_py.MultivariateNormalDiag._call_sample_n = (
+ ds.MultivariateNormalDiag._call_sample_n = (
_capturing_mvndiag_call_sample_n)
yield data_container
- distributions_py.MultivariateNormalDiag._call_sample_n = (
+ ds.MultivariateNormalDiag._call_sample_n = (
true_mvndiag_call_sample_n)
@@ -94,7 +94,7 @@ def _test_capture_mvndiag_sample_outputs():
def _test_capture_normal_sample_outputs():
"""Use monkey-patching to capture the output of a Normal _call_sample_n."""
data_container = []
- true_normal_call_sample_n = distributions_py.Normal._call_sample_n
+ true_normal_call_sample_n = ds.Normal._call_sample_n
def _capturing_normal_call_sample_n(self, sample_shape, seed, name, **kwargs):
samples = true_normal_call_sample_n(
@@ -102,9 +102,9 @@ def _test_capture_normal_sample_outputs():
data_container.append(samples)
return samples
- distributions_py.Normal._call_sample_n = _capturing_normal_call_sample_n
+ ds.Normal._call_sample_n = _capturing_normal_call_sample_n
yield data_container
- distributions_py.Normal._call_sample_n = true_normal_call_sample_n
+ ds.Normal._call_sample_n = true_normal_call_sample_n
def make_univariate_mixture(batch_shape, num_components):
@@ -113,13 +113,13 @@ def make_univariate_mixture(batch_shape, num_components):
array_ops.concat((batch_shape, [num_components]), axis=0),
-1, 1, dtype=dtypes.float32) - 50.
components = [
- distributions_py.Normal(
+ ds.Normal(
loc=random_ops.random_normal(batch_shape),
scale=10 * random_ops.random_uniform(batch_shape))
for _ in range(num_components)
]
- cat = distributions_py.Categorical(logits, dtype=dtypes.int32)
- return distributions_py.Mixture(cat, components)
+ cat = ds.Categorical(logits, dtype=dtypes.int32)
+ return ds.Mixture(cat, components)
def make_multivariate_mixture(batch_shape, num_components, event_shape,
@@ -141,11 +141,11 @@ def make_multivariate_mixture(batch_shape, num_components, event_shape,
scale_diag = 10 * random_ops.random_uniform(batch_and_event_shape)
loc.set_shape(static_batch_and_event_shape)
scale_diag.set_shape(static_batch_and_event_shape)
- return distributions_py.MultivariateNormalDiag(
+ return ds.MultivariateNormalDiag(
loc=loc, scale_diag=scale_diag)
components = [create_component() for _ in range(num_components)]
- cat = distributions_py.Categorical(logits, dtype=dtypes.int32)
- return distributions_py.Mixture(cat, components)
+ cat = ds.Categorical(logits, dtype=dtypes.int32)
+ return ds.Mixture(cat, components)
class MixtureTest(test.TestCase):
@@ -170,37 +170,37 @@ class MixtureTest(test.TestCase):
def testBrokenShapesStatic(self):
with self.assertRaisesWithPredicateMatch(ValueError,
r"cat.num_classes != len"):
- distributions_py.Mixture(
- distributions_py.Categorical([0.1, 0.5]), # 2 classes
- [distributions_py.Normal(loc=1.0, scale=2.0)])
+ ds.Mixture(
+ ds.Categorical([0.1, 0.5]), # 2 classes
+ [ds.Normal(loc=1.0, scale=2.0)])
with self.assertRaisesWithPredicateMatch(
ValueError, r"\(\) and \(2,\) are not compatible"):
# The value error is raised because the batch shapes of the
# Normals are not equal. One is a scalar, the other is a
# vector of size (2,).
- distributions_py.Mixture(
- distributions_py.Categorical([-0.5, 0.5]), # scalar batch
+ ds.Mixture(
+ ds.Categorical([-0.5, 0.5]), # scalar batch
[
- distributions_py.Normal(
+ ds.Normal(
loc=1.0, scale=2.0), # scalar dist
- distributions_py.Normal(
+ ds.Normal(
loc=[1.0, 1.0], scale=[2.0, 2.0])
])
with self.assertRaisesWithPredicateMatch(ValueError, r"Could not infer"):
cat_logits = array_ops.placeholder(shape=[1, None], dtype=dtypes.float32)
- distributions_py.Mixture(
- distributions_py.Categorical(cat_logits),
- [distributions_py.Normal(
+ ds.Mixture(
+ ds.Categorical(cat_logits),
+ [ds.Normal(
loc=[1.0], scale=[2.0])])
def testBrokenShapesDynamic(self):
with self.test_session():
d0_param = array_ops.placeholder(dtype=dtypes.float32)
d1_param = array_ops.placeholder(dtype=dtypes.float32)
- d = distributions_py.Mixture(
- distributions_py.Categorical([0.1, 0.2]), [
- distributions_py.Normal(
- loc=d0_param, scale=d0_param), distributions_py.Normal(
+ d = ds.Mixture(
+ ds.Categorical([0.1, 0.2]), [
+ ds.Normal(
+ loc=d0_param, scale=d0_param), ds.Normal(
loc=d1_param, scale=d1_param)
],
validate_args=True)
@@ -211,21 +211,21 @@ class MixtureTest(test.TestCase):
def testBrokenTypes(self):
with self.assertRaisesWithPredicateMatch(TypeError, "Categorical"):
- distributions_py.Mixture(None, [])
- cat = distributions_py.Categorical([0.3, 0.2])
+ ds.Mixture(None, [])
+ cat = ds.Categorical([0.3, 0.2])
# components must be a list of distributions
with self.assertRaisesWithPredicateMatch(
TypeError, "all .* must be Distribution instances"):
- distributions_py.Mixture(cat, [None])
+ ds.Mixture(cat, [None])
with self.assertRaisesWithPredicateMatch(TypeError, "same dtype"):
- distributions_py.Mixture(
+ ds.Mixture(
cat, [
- distributions_py.Normal(loc=[1.0], scale=[2.0]),
- distributions_py.Normal(loc=[np.float16(1.0)],
- scale=[np.float16(2.0)]),
+ ds.Normal(loc=[1.0], scale=[2.0]),
+ ds.Normal(loc=[np.float16(1.0)],
+ scale=[np.float16(2.0)]),
])
with self.assertRaisesWithPredicateMatch(ValueError, "non-empty list"):
- distributions_py.Mixture(distributions_py.Categorical([0.3, 0.2]), None)
+ ds.Mixture(ds.Categorical([0.3, 0.2]), None)
# TODO(ebrevdo): once distribution Domains have been added, add a
# test to ensure that the domains of the distributions in a
@@ -364,13 +364,13 @@ class MixtureTest(test.TestCase):
component_devs = np.array([0.05, 2.33])
ground_truth_stddev = 5.3120805
- mixture_dist = distributions_py.Mixture(
- cat=distributions_py.Categorical(probs=cat_probs),
+ mixture_dist = ds.Mixture(
+ cat=ds.Categorical(probs=cat_probs),
components=[
- distributions_py.Normal(loc=component_means[0],
- scale=component_devs[0]),
- distributions_py.Normal(loc=component_means[1],
- scale=component_devs[1]),
+ ds.Normal(loc=component_means[0],
+ scale=component_devs[0]),
+ ds.Normal(loc=component_means[1],
+ scale=component_devs[1]),
])
mix_dev = mixture_dist.stddev()
with self.test_session() as sess:
@@ -517,22 +517,22 @@ class MixtureTest(test.TestCase):
random_seed.set_random_seed(654321)
components = [
- distributions_py.Normal(
+ ds.Normal(
loc=mu, scale=sigma) for mu, sigma in zip(mus, sigmas)
]
- cat = distributions_py.Categorical(
+ cat = ds.Categorical(
logits, dtype=dtypes.int32, name="cat1")
- dist1 = distributions_py.Mixture(cat, components, name="mixture1")
+ dist1 = ds.Mixture(cat, components, name="mixture1")
samples1 = dist1.sample(n, seed=123456).eval()
random_seed.set_random_seed(654321)
components2 = [
- distributions_py.Normal(
+ ds.Normal(
loc=mu, scale=sigma) for mu, sigma in zip(mus, sigmas)
]
- cat2 = distributions_py.Categorical(
+ cat2 = ds.Categorical(
logits, dtype=dtypes.int32, name="cat2")
- dist2 = distributions_py.Mixture(cat2, components2, name="mixture2")
+ dist2 = ds.Mixture(cat2, components2, name="mixture2")
samples2 = dist2.sample(n, seed=123456).eval()
self.assertAllClose(samples1, samples2)
@@ -665,15 +665,15 @@ class MixtureTest(test.TestCase):
e_x = np.exp(x - np.max(x))
return e_x / e_x.sum()
- # Construct the distributions_py.Mixture object.
+ # Construct the ds.Mixture object.
mixture_weights = _scalar_univariate_softmax(mixture_weight_logits)
means = [np.random.uniform(low=-10, high=10, size=()).astype(np.float32)
for _ in range(n_components)]
sigmas = [np.ones(shape=(), dtype=np.float32) for _ in range(n_components)]
- cat_tf = distributions_py.Categorical(probs=mixture_weights)
- components_tf = [distributions_py.Normal(loc=mu, scale=sigma)
+ cat_tf = ds.Categorical(probs=mixture_weights)
+ components_tf = [ds.Normal(loc=mu, scale=sigma)
for (mu, sigma) in zip(means, sigmas)]
- mixture_tf = distributions_py.Mixture(cat=cat_tf, components=components_tf)
+ mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf)
x_tensor = array_ops.placeholder(shape=(), dtype=dtypes.float32)
@@ -718,10 +718,10 @@ class MixtureTest(test.TestCase):
for _ in range(n_components)]
sigmas = [np.ones(shape=psize, dtype=np.float32)
for _ in range(n_components)]
- cat_tf = distributions_py.Categorical(probs=mixture_weights)
- components_tf = [distributions_py.Normal(loc=mu, scale=sigma)
+ cat_tf = ds.Categorical(probs=mixture_weights)
+ components_tf = [ds.Normal(loc=mu, scale=sigma)
for (mu, sigma) in zip(means, sigmas)]
- mixture_tf = distributions_py.Mixture(cat=cat_tf, components=components_tf)
+ mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf)
x_tensor = array_ops.placeholder(shape=psize, dtype=dtypes.float32)
xs_to_check = [
@@ -750,6 +750,20 @@ class MixtureTest(test.TestCase):
self.assertAllClose(x_cdf_tf_result, scipy_cdf_result)
self.assertAllClose(np.exp(x_log_cdf_tf_result), scipy_cdf_result)
+ def testSampleBimixGamma(self):
+ """Tests a bug in the underlying random_gamma op.
+
+ Mixture's use of dynamic_partition requires that `random_gamma` correctly
+ return an empty `Tensor`.
+ """
+ with self.test_session():
+ gm = ds.Mixture(
+ cat=ds.Categorical(probs=[.3, .7]),
+ components=[ds.Gamma(1., 2.),
+ ds.Gamma(2., 1.)])
+ x_ = gm.sample().eval()
+ self.assertAllEqual([], x_.shape)
+
class MixtureBenchmark(test.Benchmark):
@@ -784,7 +798,7 @@ class MixtureBenchmark(test.Benchmark):
2, "mvn_diag\tuse_gpu\tcomponents\tbatch\tfeatures\tsample\twall_time")
def create_distribution(batch_size, num_components, num_features):
- cat = distributions_py.Categorical(
+ cat = ds.Categorical(
logits=np.random.randn(batch_size, num_components))
mus = [
variables.Variable(np.random.randn(batch_size, num_features))
@@ -795,9 +809,9 @@ class MixtureBenchmark(test.Benchmark):
for _ in range(num_components)
]
components = list(
- distributions_py.MultivariateNormalDiag(
+ ds.MultivariateNormalDiag(
loc=mu, scale_diag=sigma) for (mu, sigma) in zip(mus, sigmas))
- return distributions_py.Mixture(cat, components)
+ return ds.Mixture(cat, components)
for use_gpu in False, True:
if use_gpu and not test.is_gpu_available():
@@ -824,7 +838,7 @@ class MixtureBenchmark(test.Benchmark):
return np.stack([np.dot(np.transpose(z), z) for z in x])
def create_distribution(batch_size, num_components, num_features):
- cat = distributions_py.Categorical(
+ cat = ds.Categorical(
logits=np.random.randn(batch_size, num_components))
mus = [
variables.Variable(np.random.randn(batch_size, num_features))
@@ -836,10 +850,10 @@ class MixtureBenchmark(test.Benchmark):
for _ in range(num_components)
]
components = list(
- distributions_py.MultivariateNormalTriL(
+ ds.MultivariateNormalTriL(
loc=mu, scale_tril=linalg_ops.cholesky(sigma))
for (mu, sigma) in zip(mus, sigmas))
- return distributions_py.Mixture(cat, components)
+ return ds.Mixture(cat, components)
for use_gpu in False, True:
if use_gpu and not test.is_gpu_available():
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
index e78f8e2621..a37c757865 100644
--- a/tensorflow/core/kernels/random_op.cc
+++ b/tensorflow/core/kernels/random_op.cc
@@ -288,13 +288,14 @@ class RandomGammaOp : public OpKernel {
&samples_shape));
}
const int64 num_samples = samples_shape.num_elements();
- if (num_samples == 0) return;
samples_shape.AppendShape(alpha_t.shape());
// Allocate output samples.
Tensor* samples_t = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, samples_shape, &samples_t));
+ if (num_samples == 0) return;
+
using random::PhiloxRandom;
typedef random::NormalDistribution<PhiloxRandom, double> Normal;
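
Note: the kernel fix above moves the zero-sample early return to after output
allocation, so an empty request still yields an allocated, empty Tensor of shape
sample_shape + alpha.shape. A small sketch of the guaranteed behavior (TF 1.x
tf.random_gamma API):

    import tensorflow as tf

    # Requesting zero samples now yields an allocated, empty Tensor whose
    # shape is the sample shape concatenated with alpha's shape.
    samples = tf.random_gamma(shape=[0], alpha=[1.0, 2.0])

    with tf.Session() as sess:
      print(sess.run(samples).shape)  # (0, 2)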