author    A. Unique TensorFlower <gardener@tensorflow.org>    2018-01-30 12:00:04 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-01-30 13:06:24 -0800
commit    fcbfead019e064ac796a19b9d8b05325d26b3115 (patch)
tree      913a45eccacc32e3e869d61ec832038f7df94e2e
parent    8a8fdb667bc77120e6ae47a88ab14e52cdd2cf07 (diff)
Add a `use_static_graph` flag to Mixture, permitting calls to `sample` to avoid dynamic tensor indexing. This allows some static graph compilation optimizations, at the expense of sampling all underlying distributions in the mixture.
PiperOrigin-RevId: 183869189
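
For orientation, a minimal usage sketch of the new flag (not part of the commit), using the TF 1.x contrib API exercised by the tests below; the component parameters are made up for illustration:

    import tensorflow as tf
    ds = tf.contrib.distributions

    gm = ds.Mixture(
        cat=ds.Categorical(probs=[0.3, 0.7]),
        components=[ds.Normal(loc=-1.0, scale=1.0),
                    ds.Normal(loc=1.0, scale=0.5)],
        use_static_graph=True)  # sample() avoids dynamic indexing; every component is sampled

    with tf.Session() as sess:
        print(sess.run(gm.sample(5, seed=123)))
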
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py |  40
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py           | 142
-rw-r--r--  tensorflow/contrib/distributions/python/ops/distribution_util.py               |  39
-rw-r--r--  tensorflow/contrib/distributions/python/ops/mixture.py                         |  35
-rw-r--r--  tensorflow/contrib/distributions/python/ops/mixture_same_family.py             |  45
5 files changed, 229 insertions(+), 72 deletions(-)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py
index a255d4fc89..31d24aa9ea 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py
@@ -23,10 +23,15 @@ import itertools
import numpy as np
from tensorflow.contrib.distributions.python.ops import distribution_util
+from tensorflow.contrib.distributions.python.ops import mixture
+from tensorflow.contrib.distributions.python.ops import mixture_same_family
+from tensorflow.contrib.distributions.python.ops import mvn_diag
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops.distributions import categorical
+from tensorflow.python.ops.distributions import normal
from tensorflow.python.ops.linalg import linear_operator_diag
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
@@ -395,6 +400,41 @@ class MixtureStddevTest(test.TestCase):
self.assertAllClose(actual_devs, expected_devs)
+class PadMixtureDimensionsTest(test.TestCase):
+
+ def test_pad_mixture_dimensions_mixture(self):
+ with self.test_session() as sess:
+ gm = mixture.Mixture(
+ cat=categorical.Categorical(probs=[[0.3, 0.7]]),
+ components=[
+ normal.Normal(loc=[-1.0], scale=[1.0]),
+ normal.Normal(loc=[1.0], scale=[0.5])
+ ])
+
+ x = array_ops.constant([[1.0, 2.0], [3.0, 4.0]])
+ x_pad = distribution_util.pad_mixture_dimensions(
+ x, gm, gm.cat, gm.event_shape.ndims)
+ x_out, x_pad_out = sess.run([x, x_pad])
+
+ self.assertAllEqual(x_pad_out.shape, [2, 2])
+ self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1]))
+
+ def test_pad_mixture_dimensions_mixture_same_family(self):
+ with self.test_session() as sess:
+ gm = mixture_same_family.MixtureSameFamily(
+ mixture_distribution=categorical.Categorical(probs=[0.3, 0.7]),
+ components_distribution=mvn_diag.MultivariateNormalDiag(
+ loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1.0, 0.5]))
+
+ x = array_ops.constant([[1.0, 2.0], [3.0, 4.0]])
+ x_pad = distribution_util.pad_mixture_dimensions(
+ x, gm, gm.mixture_distribution, gm.event_shape.ndims)
+ x_out, x_pad_out = sess.run([x, x_pad])
+
+ self.assertAllEqual(x_pad_out.shape, [2, 2, 1])
+ self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1]))
+
+
class _PadTest(object):
def testNegAxisCorrectness(self):
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
index 1e514fe0ff..0206489175 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
@@ -107,7 +107,7 @@ def _test_capture_normal_sample_outputs():
ds.Normal._call_sample_n = true_normal_call_sample_n
-def make_univariate_mixture(batch_shape, num_components):
+def make_univariate_mixture(batch_shape, num_components, use_static_graph):
batch_shape = ops.convert_to_tensor(batch_shape, dtypes.int32)
logits = random_ops.random_uniform(
array_ops.concat((batch_shape, [num_components]), axis=0),
@@ -119,11 +119,11 @@ def make_univariate_mixture(batch_shape, num_components):
for _ in range(num_components)
]
cat = ds.Categorical(logits, dtype=dtypes.int32)
- return ds.Mixture(cat, components)
+ return ds.Mixture(cat, components, use_static_graph=use_static_graph)
def make_multivariate_mixture(batch_shape, num_components, event_shape,
- batch_shape_tensor=None):
+ use_static_graph, batch_shape_tensor=None):
if batch_shape_tensor is None:
batch_shape_tensor = batch_shape
batch_shape_tensor = ops.convert_to_tensor(batch_shape_tensor, dtypes.int32)
@@ -145,15 +145,17 @@ def make_multivariate_mixture(batch_shape, num_components, event_shape,
loc=loc, scale_diag=scale_diag)
components = [create_component() for _ in range(num_components)]
cat = ds.Categorical(logits, dtype=dtypes.int32)
- return ds.Mixture(cat, components)
+ return ds.Mixture(cat, components, use_static_graph=use_static_graph)
class MixtureTest(test.TestCase):
+ use_static_graph = False
def testShapes(self):
with self.test_session():
for batch_shape in ([], [1], [2, 3, 4]):
- dist = make_univariate_mixture(batch_shape, num_components=10)
+ dist = make_univariate_mixture(batch_shape, num_components=10,
+ use_static_graph=self.use_static_graph)
self.assertAllEqual(batch_shape, dist.batch_shape)
self.assertAllEqual(batch_shape, dist.batch_shape_tensor().eval())
self.assertAllEqual([], dist.event_shape)
@@ -161,7 +163,8 @@ class MixtureTest(test.TestCase):
for event_shape in ([1], [2]):
dist = make_multivariate_mixture(
- batch_shape, num_components=10, event_shape=event_shape)
+ batch_shape, num_components=10, event_shape=event_shape,
+ use_static_graph=self.use_static_graph)
self.assertAllEqual(batch_shape, dist.batch_shape)
self.assertAllEqual(batch_shape, dist.batch_shape_tensor().eval())
self.assertAllEqual(event_shape, dist.event_shape)
@@ -172,7 +175,8 @@ class MixtureTest(test.TestCase):
r"cat.num_classes != len"):
ds.Mixture(
ds.Categorical([0.1, 0.5]), # 2 classes
- [ds.Normal(loc=1.0, scale=2.0)])
+ [ds.Normal(loc=1.0, scale=2.0)],
+ use_static_graph=self.use_static_graph)
with self.assertRaisesWithPredicateMatch(
ValueError, r"\(\) and \(2,\) are not compatible"):
# The value error is raised because the batch shapes of the
@@ -185,13 +189,15 @@ class MixtureTest(test.TestCase):
loc=1.0, scale=2.0), # scalar dist
ds.Normal(
loc=[1.0, 1.0], scale=[2.0, 2.0])
- ])
+ ],
+ use_static_graph=self.use_static_graph)
with self.assertRaisesWithPredicateMatch(ValueError, r"Could not infer"):
cat_logits = array_ops.placeholder(shape=[1, None], dtype=dtypes.float32)
ds.Mixture(
ds.Categorical(cat_logits),
[ds.Normal(
- loc=[1.0], scale=[2.0])])
+ loc=[1.0], scale=[2.0])],
+ use_static_graph=self.use_static_graph)
def testBrokenShapesDynamic(self):
with self.test_session():
@@ -203,29 +209,37 @@ class MixtureTest(test.TestCase):
loc=d0_param, scale=d0_param), ds.Normal(
loc=d1_param, scale=d1_param)
],
- validate_args=True)
- with self.assertRaisesOpError(r"batch shape must match"):
+ validate_args=True,
+ use_static_graph=self.use_static_graph)
+
+ if self.use_static_graph:
+ error_string = r"Shapes of all inputs must match"
+ else:
+ error_string = r"batch shape must match"
+
+ with self.assertRaisesOpError(error_string):
d.sample().eval(feed_dict={d0_param: [2.0, 3.0], d1_param: [1.0]})
- with self.assertRaisesOpError(r"batch shape must match"):
+ with self.assertRaisesOpError(error_string):
d.sample().eval(feed_dict={d0_param: [2.0, 3.0], d1_param: 1.0})
def testBrokenTypes(self):
with self.assertRaisesWithPredicateMatch(TypeError, "Categorical"):
- ds.Mixture(None, [])
+ ds.Mixture(None, [], use_static_graph=self.use_static_graph)
cat = ds.Categorical([0.3, 0.2])
# components must be a list of distributions
with self.assertRaisesWithPredicateMatch(
TypeError, "all .* must be Distribution instances"):
- ds.Mixture(cat, [None])
+ ds.Mixture(cat, [None], use_static_graph=self.use_static_graph)
with self.assertRaisesWithPredicateMatch(TypeError, "same dtype"):
ds.Mixture(
cat, [
ds.Normal(loc=[1.0], scale=[2.0]),
ds.Normal(loc=[np.float16(1.0)],
scale=[np.float16(2.0)]),
- ])
+ ], use_static_graph=self.use_static_graph)
with self.assertRaisesWithPredicateMatch(ValueError, "non-empty list"):
- ds.Mixture(ds.Categorical([0.3, 0.2]), None)
+ ds.Mixture(ds.Categorical([0.3, 0.2]), None,
+ use_static_graph=self.use_static_graph)
# TODO(ebrevdo): once distribution Domains have been added, add a
# test to ensure that the domains of the distributions in a
@@ -235,7 +249,8 @@ class MixtureTest(test.TestCase):
with self.test_session() as sess:
for batch_shape in ((), (2,), (2, 3)):
dist = make_univariate_mixture(
- batch_shape=batch_shape, num_components=2)
+ batch_shape=batch_shape, num_components=2,
+ use_static_graph=self.use_static_graph)
mean = dist.mean()
self.assertEqual(batch_shape, mean.get_shape())
@@ -256,7 +271,8 @@ class MixtureTest(test.TestCase):
with self.test_session() as sess:
for batch_shape in ((), (2,), (2, 3)):
dist = make_multivariate_mixture(
- batch_shape=batch_shape, num_components=2, event_shape=(4,))
+ batch_shape=batch_shape, num_components=2, event_shape=(4,),
+ use_static_graph=self.use_static_graph)
mean = dist.mean()
self.assertEqual(batch_shape + (4,), mean.get_shape())
@@ -283,7 +299,8 @@ class MixtureTest(test.TestCase):
with self.test_session() as sess:
for batch_shape in ((), (2,), (2, 3)):
dist = make_univariate_mixture(
- batch_shape=batch_shape, num_components=num_components)
+ batch_shape=batch_shape, num_components=num_components,
+ use_static_graph=self.use_static_graph)
dev = dist.stddev()
self.assertEqual(batch_shape, dev.get_shape())
@@ -325,7 +342,8 @@ class MixtureTest(test.TestCase):
dist = make_multivariate_mixture(
batch_shape=batch_shape,
num_components=num_components,
- event_shape=(4,))
+ event_shape=(4,),
+ use_static_graph=self.use_static_graph)
dev = dist.stddev()
self.assertEqual(batch_shape + (4,), dev.get_shape())
@@ -371,7 +389,8 @@ class MixtureTest(test.TestCase):
scale=component_devs[0]),
ds.Normal(loc=component_means[1],
scale=component_devs[1]),
- ])
+ ],
+ use_static_graph=self.use_static_graph)
mix_dev = mixture_dist.stddev()
with self.test_session() as sess:
actual_stddev = sess.run(mix_dev)
@@ -379,7 +398,8 @@ class MixtureTest(test.TestCase):
def testProbScalarUnivariate(self):
with self.test_session() as sess:
- dist = make_univariate_mixture(batch_shape=[], num_components=2)
+ dist = make_univariate_mixture(batch_shape=[], num_components=2,
+ use_static_graph=self.use_static_graph)
for x in [
np.array(
[1.0, 2.0], dtype=np.float32), np.array(
@@ -405,7 +425,8 @@ class MixtureTest(test.TestCase):
def testProbScalarMultivariate(self):
with self.test_session() as sess:
dist = make_multivariate_mixture(
- batch_shape=[], num_components=2, event_shape=[3])
+ batch_shape=[], num_components=2, event_shape=[3],
+ use_static_graph=self.use_static_graph)
for x in [
np.array(
[[-1.0, 0.0, 1.0], [0.5, 1.0, -0.3]], dtype=np.float32), np.array(
@@ -432,7 +453,8 @@ class MixtureTest(test.TestCase):
def testProbBatchUnivariate(self):
with self.test_session() as sess:
- dist = make_univariate_mixture(batch_shape=[2, 3], num_components=2)
+ dist = make_univariate_mixture(batch_shape=[2, 3], num_components=2,
+ use_static_graph=self.use_static_graph)
for x in [
np.random.randn(2, 3).astype(np.float32),
@@ -459,7 +481,8 @@ class MixtureTest(test.TestCase):
def testProbBatchMultivariate(self):
with self.test_session() as sess:
dist = make_multivariate_mixture(
- batch_shape=[2, 3], num_components=2, event_shape=[4])
+ batch_shape=[2, 3], num_components=2, event_shape=[4],
+ use_static_graph=self.use_static_graph)
for x in [
np.random.randn(2, 3, 4).astype(np.float32),
@@ -487,7 +510,8 @@ class MixtureTest(test.TestCase):
num_components = 3
batch_shape = []
dist = make_univariate_mixture(
- batch_shape=batch_shape, num_components=num_components)
+ batch_shape=batch_shape, num_components=num_components,
+ use_static_graph=self.use_static_graph)
n = 4
with _test_capture_normal_sample_outputs() as component_samples:
samples = dist.sample(n, seed=123)
@@ -502,7 +526,10 @@ class MixtureTest(test.TestCase):
which_c = np.where(cat_sample_values == c)[0]
size_c = which_c.size
# Scalar Batch univariate case: batch_size == 1, rank 1
- which_dist_samples = dist_sample_values[c][:size_c]
+ if self.use_static_graph:
+ which_dist_samples = dist_sample_values[c][which_c]
+ else:
+ which_dist_samples = dist_sample_values[c][:size_c]
self.assertAllClose(which_dist_samples, sample_values[which_c])
# Test that sampling with the same seed twice gives the same results.
@@ -522,7 +549,8 @@ class MixtureTest(test.TestCase):
]
cat = ds.Categorical(
logits, dtype=dtypes.int32, name="cat1")
- dist1 = ds.Mixture(cat, components, name="mixture1")
+ dist1 = ds.Mixture(cat, components, name="mixture1",
+ use_static_graph=self.use_static_graph)
samples1 = dist1.sample(n, seed=123456).eval()
random_seed.set_random_seed(654321)
@@ -532,7 +560,8 @@ class MixtureTest(test.TestCase):
]
cat2 = ds.Categorical(
logits, dtype=dtypes.int32, name="cat2")
- dist2 = ds.Mixture(cat2, components2, name="mixture2")
+ dist2 = ds.Mixture(cat2, components2, name="mixture2",
+ use_static_graph=self.use_static_graph)
samples2 = dist2.sample(n, seed=123456).eval()
self.assertAllClose(samples1, samples2)
@@ -541,7 +570,8 @@ class MixtureTest(test.TestCase):
with self.test_session() as sess:
num_components = 3
dist = make_multivariate_mixture(
- batch_shape=[], num_components=num_components, event_shape=[2])
+ batch_shape=[], num_components=num_components, event_shape=[2],
+ use_static_graph=self.use_static_graph)
n = 4
with _test_capture_mvndiag_sample_outputs() as component_samples:
samples = dist.sample(n, seed=123)
@@ -555,14 +585,18 @@ class MixtureTest(test.TestCase):
which_c = np.where(cat_sample_values == c)[0]
size_c = which_c.size
# Scalar Batch multivariate case: batch_size == 1, rank 2
- which_dist_samples = dist_sample_values[c][:size_c, :]
+ if self.use_static_graph:
+ which_dist_samples = dist_sample_values[c][which_c, :]
+ else:
+ which_dist_samples = dist_sample_values[c][:size_c, :]
self.assertAllClose(which_dist_samples, sample_values[which_c, :])
def testSampleBatchUnivariate(self):
with self.test_session() as sess:
num_components = 3
dist = make_univariate_mixture(
- batch_shape=[2, 3], num_components=num_components)
+ batch_shape=[2, 3], num_components=num_components,
+ use_static_graph=self.use_static_graph)
n = 4
with _test_capture_normal_sample_outputs() as component_samples:
samples = dist.sample(n, seed=123)
@@ -576,8 +610,12 @@ class MixtureTest(test.TestCase):
which_c_s, which_c_b0, which_c_b1 = np.where(cat_sample_values == c)
size_c = which_c_s.size
# Batch univariate case: batch_size == [2, 3], rank 3
- which_dist_samples = dist_sample_values[c][range(size_c), which_c_b0,
- which_c_b1]
+ if self.use_static_graph:
+ which_dist_samples = dist_sample_values[c][which_c_s, which_c_b0,
+ which_c_b1]
+ else:
+ which_dist_samples = dist_sample_values[c][range(size_c), which_c_b0,
+ which_c_b1]
self.assertAllClose(which_dist_samples,
sample_values[which_c_s, which_c_b0, which_c_b1])
@@ -594,7 +632,8 @@ class MixtureTest(test.TestCase):
dist = make_multivariate_mixture(
batch_shape=batch_shape,
num_components=num_components, event_shape=[4],
- batch_shape_tensor=batch_shape_tensor)
+ batch_shape_tensor=batch_shape_tensor,
+ use_static_graph=self.use_static_graph)
n = 5
with _test_capture_mvndiag_sample_outputs() as component_samples:
samples = dist.sample(n, seed=123)
@@ -617,8 +656,12 @@ class MixtureTest(test.TestCase):
which_c_s, which_c_b0, which_c_b1 = np.where(cat_sample_values == c)
size_c = which_c_s.size
# Batch univariate case: batch_size == [2, 3], rank 4 (multivariate)
- which_dist_samples = dist_sample_values[c][range(size_c), which_c_b0,
- which_c_b1, :]
+ if self.use_static_graph:
+ which_dist_samples = dist_sample_values[c][which_c_s, which_c_b0,
+ which_c_b1, :]
+ else:
+ which_dist_samples = dist_sample_values[c][range(size_c), which_c_b0,
+ which_c_b1, :]
self.assertAllClose(which_dist_samples,
sample_values[which_c_s, which_c_b0, which_c_b1, :])
@@ -632,7 +675,8 @@ class MixtureTest(test.TestCase):
with self.test_session() as sess:
for batch_shape in ((), (2,), (2, 3)):
dist = make_multivariate_mixture(
- batch_shape=batch_shape, num_components=2, event_shape=(4,))
+ batch_shape=batch_shape, num_components=2, event_shape=(4,),
+ use_static_graph=self.use_static_graph)
entropy_lower_bound = dist.entropy_lower_bound()
self.assertEqual(batch_shape, entropy_lower_bound.get_shape())
@@ -673,7 +717,8 @@ class MixtureTest(test.TestCase):
cat_tf = ds.Categorical(probs=mixture_weights)
components_tf = [ds.Normal(loc=mu, scale=sigma)
for (mu, sigma) in zip(means, sigmas)]
- mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf)
+ mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf,
+ use_static_graph=self.use_static_graph)
x_tensor = array_ops.placeholder(shape=(), dtype=dtypes.float32)
@@ -721,7 +766,8 @@ class MixtureTest(test.TestCase):
cat_tf = ds.Categorical(probs=mixture_weights)
components_tf = [ds.Normal(loc=mu, scale=sigma)
for (mu, sigma) in zip(means, sigmas)]
- mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf)
+ mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf,
+ use_static_graph=self.use_static_graph)
x_tensor = array_ops.placeholder(shape=psize, dtype=dtypes.float32)
xs_to_check = [
@@ -760,12 +806,18 @@ class MixtureTest(test.TestCase):
gm = ds.Mixture(
cat=ds.Categorical(probs=[.3, .7]),
components=[ds.Gamma(1., 2.),
- ds.Gamma(2., 1.)])
+ ds.Gamma(2., 1.)],
+ use_static_graph=self.use_static_graph)
x_ = gm.sample().eval()
self.assertAllEqual([], x_.shape)
+class MixtureStaticSampleTest(MixtureTest):
+ use_static_graph = True
+
+
class MixtureBenchmark(test.Benchmark):
+ use_static_graph = False
def _runSamplingBenchmark(self, name, create_distribution, use_gpu,
num_components, batch_size, num_features,
@@ -811,7 +863,7 @@ class MixtureBenchmark(test.Benchmark):
components = list(
ds.MultivariateNormalDiag(
loc=mu, scale_diag=sigma) for (mu, sigma) in zip(mus, sigmas))
- return ds.Mixture(cat, components)
+ return ds.Mixture(cat, components, use_static_graph=self.use_static_graph)
for use_gpu in False, True:
if use_gpu and not test.is_gpu_available():
@@ -853,7 +905,7 @@ class MixtureBenchmark(test.Benchmark):
ds.MultivariateNormalTriL(
loc=mu, scale_tril=linalg_ops.cholesky(sigma))
for (mu, sigma) in zip(mus, sigmas))
- return ds.Mixture(cat, components)
+ return ds.Mixture(cat, components, use_static_graph=self.use_static_graph)
for use_gpu in False, True:
if use_gpu and not test.is_gpu_available():
@@ -872,5 +924,9 @@ class MixtureBenchmark(test.Benchmark):
sample_size=sample_size)
+class MixtureStaticSampleBenchmark(MixtureBenchmark):
+ use_static_graph = True
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py
index a4d249d41e..289e1d50e1 100644
--- a/tensorflow/contrib/distributions/python/ops/distribution_util.py
+++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib import linalg
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
@@ -442,6 +443,44 @@ def maybe_check_scalar_distribution(
return assertions
+def pad_mixture_dimensions(x, mixture_distribution, categorical_distribution,
+ event_ndims):
+ """Pad dimensions of event tensors for mixture distributions.
+
+ See `Mixture._sample_n` and `MixtureSameFamily._sample_n` for usage examples.
+
+ Args:
+ x: event tensor to pad.
+ mixture_distribution: Base distribution of the mixture.
+ categorical_distribution: `Categorical` distribution that mixes the base
+ distribution.
+ event_ndims: Integer specifying the number of event dimensions in the event
+ tensor.
+
+ Returns:
+ A padded version of `x` that can broadcast with `categorical_distribution`.
+ """
+ with ops.name_scope("pad_mix_dims", values=[x]):
+ def _get_ndims(d):
+ if d.batch_shape.ndims is not None:
+ return d.batch_shape.ndims
+ return array_ops.shape(d.batch_shape_tensor())[0]
+ dist_batch_ndims = _get_ndims(mixture_distribution)
+ cat_batch_ndims = _get_ndims(categorical_distribution)
+ pad_ndims = array_ops.where(
+ categorical_distribution.is_scalar_batch(),
+ dist_batch_ndims,
+ dist_batch_ndims - cat_batch_ndims)
+ s = array_ops.shape(x)
+ x = array_ops.reshape(x, shape=array_ops.concat([
+ s[:-1],
+ array_ops.ones([pad_ndims], dtype=dtypes.int32),
+ s[-1:],
+ array_ops.ones([event_ndims], dtype=dtypes.int32),
+ ], axis=0))
+ return x
+
+
def static_value(x):
"""Returns the static value of a `Tensor` or `None`."""
return tensor_util.constant_value(ops.convert_to_tensor(x))
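
To make the reshape performed by `pad_mixture_dimensions` concrete, here is a rough NumPy analogue (hypothetical shapes, not from the commit): a mask of shape [n, B, k] gains one trailing size-1 axis per event dimension so it broadcasts against per-component samples stacked as [n, B, k, E]. This is the `pad_ndims == 0`, `event_ndims == 1` case used by `Mixture._sample_n`.

    import numpy as np

    n, B, k, E = 4, 2, 3, 5
    mask = np.ones([n, B, k])                # [n, B, k], as produced by one_hot
    padded = mask.reshape([n, B, k, 1])      # [n, B, k, 1]: one size-1 axis per event dim
    samples = np.zeros([n, B, k, E])         # stacked per-component samples
    mixed = (padded * samples).sum(axis=-2)  # [n, B, E]: reduce over the component axis
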
diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py
index f2d492f548..cef6a143fc 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture.py
@@ -71,6 +71,7 @@ class Mixture(distribution.Distribution):
components,
validate_args=False,
allow_nan_stats=True,
+ use_static_graph=False,
name="Mixture"):
"""Initialize a Mixture distribution.
@@ -96,6 +97,11 @@ class Mixture(distribution.Distribution):
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
batch member. If `True`, batch members with valid parameters leading to
undefined statistics will return NaN for this statistic.
+ use_static_graph: Calls to `sample` will not rely on dynamic tensor
+ indexing, allowing for some static graph compilation optimizations, but
+ at the expense of sampling all underlying distributions in the mixture.
+ (Possibly useful when running on TPUs).
+ Default value: `False` (i.e., use dynamic indexing).
name: A name for this distribution (optional).
Raises:
@@ -178,6 +184,10 @@ class Mixture(distribution.Distribution):
self._static_event_shape = static_event_shape
self._static_batch_shape = static_batch_shape
+ self._use_static_graph = use_static_graph
+ if use_static_graph and static_num_components is None:
+ raise ValueError("Number of categories must be known statically when "
+ "`use_static_graph=True`.")
# We let the Mixture distribution access _graph_parents since its arguably
# more like a baseclass.
graph_parents = self._cat._graph_parents # pylint: disable=protected-access
@@ -292,6 +302,31 @@ class Mixture(distribution.Distribution):
return mixture_log_cdf
def _sample_n(self, n, seed=None):
+ if self._use_static_graph:
+ # This sampling approach is almost the same as the approach used by
+ # `MixtureSameFamily`. The differences are due to having a list of
+ # `Distribution` objects rather than a single object, and maintaining
+ # random seed management that is consistent with the non-static code path.
+ samples = []
+ cat_samples = self.cat.sample(n, seed=seed)
+ for c in range(self.num_components):
+ seed = distribution_util.gen_new_seed(seed, "mixture")
+ samples.append(self.components[c].sample(n, seed=seed))
+ x = array_ops.stack(
+ samples, -self._static_event_shape.ndims - 1) # [n, B, k, E]
+ npdt = x.dtype.as_numpy_dtype
+ mask = array_ops.one_hot(
+ indices=cat_samples, # [n, B]
+ depth=self._num_components, # == k
+ on_value=np.ones([], dtype=npdt),
+ off_value=np.zeros([], dtype=npdt)) # [n, B, k]
+ mask = distribution_utils.pad_mixture_dimensions(
+ mask, self, self._cat,
+ self._static_event_shape.ndims) # [n, B, k, [1]*e]
+ return math_ops.reduce_sum(
+ x * mask,
+ axis=-1 - self._static_event_shape.ndims) # [n, B, E]
+
with ops.control_dependencies(self._assertions):
n = ops.convert_to_tensor(n, name="n")
static_n = tensor_util.constant_value(n)
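
The static path in `_sample_n` above replaces dynamic partitioning with a mask-and-sum over all components. A hedged NumPy analogue of the idea (not from the commit; shapes and helpers are made up), for a scalar batch and scalar events:

    import numpy as np

    n, k = 5, 3
    rng = np.random.RandomState(123)
    cat_samples = rng.randint(0, k, size=n)              # which component each draw uses
    x = np.stack([rng.randn(n) for _ in range(k)], -1)   # sample every component: [n, k]
    mask = np.eye(k)[cat_samples]                        # one-hot over components: [n, k]
    static_sample = (x * mask).sum(axis=-1)              # [n]; equals x[np.arange(n), cat_samples]
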
diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
index 49afbea7f0..b93bdc5ab4 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.python.framework import dtypes
+from tensorflow.contrib.distributions.python.ops import distribution_util as distribution_utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -239,7 +239,9 @@ class MixtureSameFamily(distribution.Distribution):
depth=self._num_components, # == k
on_value=np.ones([], dtype=npdt),
off_value=np.zeros([], dtype=npdt)) # [n, B, k]
- mask = self._pad_mix_dims(mask) # [n, B, k, [1]*e]
+ mask = distribution_utils.pad_mixture_dimensions(
+ mask, self, self.mixture_distribution,
+ self._event_shape().ndims) # [n, B, k, [1]*e]
return math_ops.reduce_sum(
x * mask, axis=-1 - self._event_ndims) # [n, B, E]
@@ -254,8 +256,9 @@ class MixtureSameFamily(distribution.Distribution):
def _mean(self):
with ops.control_dependencies(self._runtime_assertions):
- probs = self._pad_mix_dims(
- self.mixture_distribution.probs) # [B, k, [1]*e]
+ probs = distribution_utils.pad_mixture_dimensions(
+ self.mixture_distribution.probs, self, self.mixture_distribution,
+ self._event_shape().ndims) # [B, k, [1]*e]
return math_ops.reduce_sum(
probs * self.components_distribution.mean(),
axis=-1 - self._event_ndims) # [B, E]
@@ -271,8 +274,9 @@ class MixtureSameFamily(distribution.Distribution):
def _variance(self):
with ops.control_dependencies(self._runtime_assertions):
# Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
- probs = self._pad_mix_dims(
- self.mixture_distribution.probs) # [B, k, [1]*e]
+ probs = distribution_utils.pad_mixture_dimensions(
+ self.mixture_distribution.probs, self, self.mixture_distribution,
+ self._event_shape().ndims) # [B, k, [1]*e]
mean_cond_var = math_ops.reduce_sum(
probs * self.components_distribution.variance(),
axis=-1 - self._event_ndims) # [B, E]
@@ -291,8 +295,12 @@ class MixtureSameFamily(distribution.Distribution):
with ops.control_dependencies(self._runtime_assertions):
# Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
- probs = self._pad_mix_dims(self._pad_mix_dims(
- self.mixture_distribution.probs)) # [B, k, 1, 1]
+ probs = distribution_utils.pad_mixture_dimensions(
+ distribution_utils.pad_mixture_dimensions(
+ self.mixture_distribution.probs, self, self.mixture_distribution,
+ self._event_shape().ndims),
+ self, self.mixture_distribution,
+ self._event_shape().ndims) # [B, k, 1, 1]
mean_cond_var = math_ops.reduce_sum(
probs * self.components_distribution.covariance(),
axis=-3) # [B, e, e]
@@ -312,27 +320,6 @@ class MixtureSameFamily(distribution.Distribution):
shape[:d], [1], shape[d:]], axis=0))
return x
- def _pad_mix_dims(self, x):
- with ops.name_scope("pad_mix_dims", values=[x]):
- def _get_ndims(d):
- if d.batch_shape.ndims is not None:
- return d.batch_shape.ndims
- return array_ops.shape(d.batch_shape_tensor())[0]
- dist_batch_ndims = _get_ndims(self)
- cat_batch_ndims = _get_ndims(self.mixture_distribution)
- pad_ndims = array_ops.where(
- self.mixture_distribution.is_scalar_batch(),
- dist_batch_ndims,
- dist_batch_ndims - cat_batch_ndims)
- s = array_ops.shape(x)
- x = array_ops.reshape(x, shape=array_ops.concat([
- s[:-1],
- array_ops.ones([pad_ndims], dtype=dtypes.int32),
- s[-1:],
- array_ops.ones([self._event_ndims], dtype=dtypes.int32),
- ], axis=0))
- return x
-
def _outer_squared_difference(x, y):
"""Convenience function analogous to tf.squared_difference."""