author     2018-01-30 12:00:04 -0800
committer  2018-01-30 13:06:24 -0800
commit     fcbfead019e064ac796a19b9d8b05325d26b3115
tree       913a45eccacc32e3e869d61ec832038f7df94e2e
parent     8a8fdb667bc77120e6ae47a88ab14e52cdd2cf07
Add static_sample flag to Mixture, permitting calls to `sample` to avoid dynamic tensor indexing. This enables some static graph compilation optimizations, at the expense of sampling all underlying distributions in the mixture.
PiperOrigin-RevId: 183869189
5 files changed, 229 insertions(+), 72 deletions(-)
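As a quick illustration of the new flag, here is a minimal sketch (ours, not part of the commit; it assumes the TF 1.x `tf.contrib.distributions` endpoint that this diff modifies, and the `tfd` alias is our shorthand):

    import tensorflow as tf

    tfd = tf.contrib.distributions

    gm = tfd.Mixture(
        cat=tfd.Categorical(probs=[0.3, 0.7]),
        components=[
            tfd.Normal(loc=-1.0, scale=1.0),
            tfd.Normal(loc=1.0, scale=0.5),
        ],
        use_static_graph=True)  # sample all components, select via a one-hot mask

    samples = gm.sample(5, seed=123)  # no dynamic gather/indexing in the graph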
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py
index a255d4fc89..31d24aa9ea 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py
@@ -23,10 +23,15 @@ import itertools
 import numpy as np
 
 from tensorflow.contrib.distributions.python.ops import distribution_util
+from tensorflow.contrib.distributions.python.ops import mixture
+from tensorflow.contrib.distributions.python.ops import mixture_same_family
+from tensorflow.contrib.distributions.python.ops import mvn_diag
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops.distributions import categorical
+from tensorflow.python.ops.distributions import normal
 from tensorflow.python.ops.linalg import linear_operator_diag
 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import test
@@ -395,6 +400,41 @@ class MixtureStddevTest(test.TestCase):
     self.assertAllClose(actual_devs, expected_devs)
 
 
+class PadMixtureDimensionsTest(test.TestCase):
+
+  def test_pad_mixture_dimensions_mixture(self):
+    with self.test_session() as sess:
+      gm = mixture.Mixture(
+          cat=categorical.Categorical(probs=[[0.3, 0.7]]),
+          components=[
+              normal.Normal(loc=[-1.0], scale=[1.0]),
+              normal.Normal(loc=[1.0], scale=[0.5])
+          ])
+
+      x = array_ops.constant([[1.0, 2.0], [3.0, 4.0]])
+      x_pad = distribution_util.pad_mixture_dimensions(
+          x, gm, gm.cat, gm.event_shape.ndims)
+      x_out, x_pad_out = sess.run([x, x_pad])
+
+      self.assertAllEqual(x_pad_out.shape, [2, 2])
+      self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1]))
+
+  def test_pad_mixture_dimensions_mixture_same_family(self):
+    with self.test_session() as sess:
+      gm = mixture_same_family.MixtureSameFamily(
+          mixture_distribution=categorical.Categorical(probs=[0.3, 0.7]),
+          components_distribution=mvn_diag.MultivariateNormalDiag(
+              loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1.0, 0.5]))
+
+      x = array_ops.constant([[1.0, 2.0], [3.0, 4.0]])
+      x_pad = distribution_util.pad_mixture_dimensions(
+          x, gm, gm.mixture_distribution, gm.event_shape.ndims)
+      x_out, x_pad_out = sess.run([x, x_pad])
+
+      self.assertAllEqual(x_pad_out.shape, [2, 2, 1])
+      self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1]))
+
+
 class _PadTest(object):
 
   def testNegAxisCorrectness(self):
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
index 1e514fe0ff..0206489175 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
@@ -107,7 +107,7 @@ def _test_capture_normal_sample_outputs():
   ds.Normal._call_sample_n = true_normal_call_sample_n
 
 
-def make_univariate_mixture(batch_shape, num_components):
+def make_univariate_mixture(batch_shape, num_components, use_static_graph):
   batch_shape = ops.convert_to_tensor(batch_shape, dtypes.int32)
   logits = random_ops.random_uniform(
       array_ops.concat((batch_shape, [num_components]), axis=0),
@@ -119,11 +119,11 @@
       for _ in range(num_components)
   ]
   cat = ds.Categorical(logits, dtype=dtypes.int32)
-  return ds.Mixture(cat, components)
+  return ds.Mixture(cat, components, use_static_graph=use_static_graph)
 
 
 def make_multivariate_mixture(batch_shape, num_components, event_shape,
-                              batch_shape_tensor=None):
+                              use_static_graph, batch_shape_tensor=None):
   if batch_shape_tensor is None:
     batch_shape_tensor = batch_shape
   batch_shape_tensor = ops.convert_to_tensor(batch_shape_tensor, dtypes.int32)
@@ -145,15 +145,17 @@ def make_multivariate_mixture(batch_shape, num_components, event_shape,
         loc=loc, scale_diag=scale_diag)
   components = [create_component() for _ in range(num_components)]
   cat = ds.Categorical(logits, dtype=dtypes.int32)
-  return ds.Mixture(cat, components)
+  return ds.Mixture(cat, components, use_static_graph=use_static_graph)
 
 
 class MixtureTest(test.TestCase):
+  use_static_graph = False
 
   def testShapes(self):
     with self.test_session():
       for batch_shape in ([], [1], [2, 3, 4]):
-        dist = make_univariate_mixture(batch_shape, num_components=10)
+        dist = make_univariate_mixture(batch_shape, num_components=10,
+                                       use_static_graph=self.use_static_graph)
         self.assertAllEqual(batch_shape, dist.batch_shape)
         self.assertAllEqual(batch_shape, dist.batch_shape_tensor().eval())
         self.assertAllEqual([], dist.event_shape)
@@ -161,7 +163,8 @@ class MixtureTest(test.TestCase):
 
       for event_shape in ([1], [2]):
         dist = make_multivariate_mixture(
-            batch_shape, num_components=10, event_shape=event_shape)
+            batch_shape, num_components=10, event_shape=event_shape,
+            use_static_graph=self.use_static_graph)
         self.assertAllEqual(batch_shape, dist.batch_shape)
         self.assertAllEqual(batch_shape, dist.batch_shape_tensor().eval())
         self.assertAllEqual(event_shape, dist.event_shape)
@@ -172,7 +175,8 @@ class MixtureTest(test.TestCase):
         r"cat.num_classes != len"):
       ds.Mixture(
           ds.Categorical([0.1, 0.5]),  # 2 classes
-          [ds.Normal(loc=1.0, scale=2.0)])
+          [ds.Normal(loc=1.0, scale=2.0)],
+          use_static_graph=self.use_static_graph)
     with self.assertRaisesWithPredicateMatch(
         ValueError, r"\(\) and \(2,\) are not compatible"):
       # The value error is raised because the batch shapes of the
@@ -185,13 +189,15 @@ class MixtureTest(test.TestCase):
               loc=1.0, scale=2.0),  # scalar dist
           ds.Normal(
               loc=[1.0, 1.0], scale=[2.0, 2.0])
-      ])
+      ],
+      use_static_graph=self.use_static_graph)
     with self.assertRaisesWithPredicateMatch(ValueError, r"Could not infer"):
       cat_logits = array_ops.placeholder(shape=[1, None], dtype=dtypes.float32)
       ds.Mixture(
           ds.Categorical(cat_logits),
           [ds.Normal(
-              loc=[1.0], scale=[2.0])])
+              loc=[1.0], scale=[2.0])],
+          use_static_graph=self.use_static_graph)
 
   def testBrokenShapesDynamic(self):
     with self.test_session():
@@ -203,29 +209,37 @@ class MixtureTest(test.TestCase):
               loc=d0_param, scale=d0_param),
           ds.Normal(
               loc=d1_param, scale=d1_param)
       ],
-          validate_args=True)
-      with self.assertRaisesOpError(r"batch shape must match"):
+          validate_args=True,
+          use_static_graph=self.use_static_graph)
+
+      if self.use_static_graph:
+        error_string = r"Shapes of all inputs must match"
+      else:
+        error_string = r"batch shape must match"
+
+      with self.assertRaisesOpError(error_string):
         d.sample().eval(feed_dict={d0_param: [2.0, 3.0], d1_param: [1.0]})
-      with self.assertRaisesOpError(r"batch shape must match"):
+      with self.assertRaisesOpError(error_string):
         d.sample().eval(feed_dict={d0_param: [2.0, 3.0], d1_param: 1.0})
 
   def testBrokenTypes(self):
     with self.assertRaisesWithPredicateMatch(TypeError, "Categorical"):
-      ds.Mixture(None, [])
+      ds.Mixture(None, [], use_static_graph=self.use_static_graph)
     cat = ds.Categorical([0.3, 0.2])
     # components must be a list of distributions
     with self.assertRaisesWithPredicateMatch(
        TypeError, "all .* must be Distribution instances"):
-      ds.Mixture(cat, [None])
+      ds.Mixture(cat, [None], use_static_graph=self.use_static_graph)
     with self.assertRaisesWithPredicateMatch(TypeError, "same dtype"):
       ds.Mixture(
           cat, [
               ds.Normal(loc=[1.0], scale=[2.0]),
               ds.Normal(loc=[np.float16(1.0)], scale=[np.float16(2.0)]),
-          ])
+          ], use_static_graph=self.use_static_graph)
     with self.assertRaisesWithPredicateMatch(ValueError, "non-empty list"):
-      ds.Mixture(ds.Categorical([0.3, 0.2]), None)
+      ds.Mixture(ds.Categorical([0.3, 0.2]), None,
+                 use_static_graph=self.use_static_graph)
 
   # TODO(ebrevdo): once distribution Domains have been added, add a
   # test to ensure that the domains of the distributions in a
@@ -235,7 +249,8 @@ class MixtureTest(test.TestCase):
     with self.test_session() as sess:
       for batch_shape in ((), (2,), (2, 3)):
         dist = make_univariate_mixture(
-            batch_shape=batch_shape, num_components=2)
+            batch_shape=batch_shape, num_components=2,
+            use_static_graph=self.use_static_graph)
         mean = dist.mean()
         self.assertEqual(batch_shape, mean.get_shape())
@@ -256,7 +271,8 @@ class MixtureTest(test.TestCase):
     with self.test_session() as sess:
      for batch_shape in ((), (2,), (2, 3)):
         dist = make_multivariate_mixture(
-            batch_shape=batch_shape, num_components=2, event_shape=(4,))
+            batch_shape=batch_shape, num_components=2, event_shape=(4,),
+            use_static_graph=self.use_static_graph)
         mean = dist.mean()
         self.assertEqual(batch_shape + (4,), mean.get_shape())
@@ -283,7 +299,8 @@ class MixtureTest(test.TestCase):
     with self.test_session() as sess:
       for batch_shape in ((), (2,), (2, 3)):
         dist = make_univariate_mixture(
-            batch_shape=batch_shape, num_components=num_components)
+            batch_shape=batch_shape, num_components=num_components,
+            use_static_graph=self.use_static_graph)
         dev = dist.stddev()
         self.assertEqual(batch_shape, dev.get_shape())
@@ -325,7 +342,8 @@ class MixtureTest(test.TestCase):
         dist = make_multivariate_mixture(
             batch_shape=batch_shape,
             num_components=num_components,
-            event_shape=(4,))
+            event_shape=(4,),
+            use_static_graph=self.use_static_graph)
         dev = dist.stddev()
         self.assertEqual(batch_shape + (4,), dev.get_shape())
@@ -371,7 +389,8 @@ class MixtureTest(test.TestCase):
             scale=component_devs[0]),
         ds.Normal(loc=component_means[1],
                   scale=component_devs[1]),
-    ])
+    ],
+        use_static_graph=self.use_static_graph)
     mix_dev = mixture_dist.stddev()
     with self.test_session() as sess:
       actual_stddev = sess.run(mix_dev)
@@ -379,7 +398,8 @@ class MixtureTest(test.TestCase):
 
   def testProbScalarUnivariate(self):
     with self.test_session() as sess:
-      dist = make_univariate_mixture(batch_shape=[], num_components=2)
+      dist = make_univariate_mixture(batch_shape=[], num_components=2,
+                                     use_static_graph=self.use_static_graph)
       for x in [
           np.array(
              [1.0, 2.0], dtype=np.float32), np.array(
@@ -405,7 +425,8 @@ class MixtureTest(test.TestCase):
   def testProbScalarMultivariate(self):
     with self.test_session() as sess:
       dist = make_multivariate_mixture(
-          batch_shape=[], num_components=2, event_shape=[3])
+          batch_shape=[], num_components=2, event_shape=[3],
+          use_static_graph=self.use_static_graph)
       for x in [
           np.array(
              [[-1.0, 0.0, 1.0], [0.5, 1.0, -0.3]], dtype=np.float32), np.array(
@@ -432,7 +453,8 @@ class MixtureTest(test.TestCase):
 
   def testProbBatchUnivariate(self):
     with self.test_session() as sess:
-      dist = make_univariate_mixture(batch_shape=[2, 3], num_components=2)
+      dist = make_univariate_mixture(batch_shape=[2, 3], num_components=2,
+                                     use_static_graph=self.use_static_graph)
       for x in [
           np.random.randn(2, 3).astype(np.float32),
@@ -459,7 +481,8 @@ class MixtureTest(test.TestCase):
   def testProbBatchMultivariate(self):
     with self.test_session() as sess:
       dist = make_multivariate_mixture(
-          batch_shape=[2, 3], num_components=2, event_shape=[4])
+          batch_shape=[2, 3], num_components=2, event_shape=[4],
+          use_static_graph=self.use_static_graph)
       for x in [
           np.random.randn(2, 3, 4).astype(np.float32),
@@ -487,7 +510,8 @@ class MixtureTest(test.TestCase):
       num_components = 3
       batch_shape = []
       dist = make_univariate_mixture(
-          batch_shape=batch_shape, num_components=num_components)
+          batch_shape=batch_shape, num_components=num_components,
+          use_static_graph=self.use_static_graph)
       n = 4
       with _test_capture_normal_sample_outputs() as component_samples:
         samples = dist.sample(n, seed=123)
@@ -502,7 +526,10 @@ class MixtureTest(test.TestCase):
         which_c = np.where(cat_sample_values == c)[0]
         size_c = which_c.size
         # Scalar Batch univariate case: batch_size == 1, rank 1
-        which_dist_samples = dist_sample_values[c][:size_c]
+        if self.use_static_graph:
+          which_dist_samples = dist_sample_values[c][which_c]
+        else:
+          which_dist_samples = dist_sample_values[c][:size_c]
         self.assertAllClose(which_dist_samples, sample_values[which_c])
 
   # Test that sampling with the same seed twice gives the same results.
@@ -522,7 +549,8 @@ class MixtureTest(test.TestCase):
       ]
       cat = ds.Categorical(
           logits, dtype=dtypes.int32, name="cat1")
-      dist1 = ds.Mixture(cat, components, name="mixture1")
+      dist1 = ds.Mixture(cat, components, name="mixture1",
+                         use_static_graph=self.use_static_graph)
       samples1 = dist1.sample(n, seed=123456).eval()
 
       random_seed.set_random_seed(654321)
@@ -532,7 +560,8 @@ class MixtureTest(test.TestCase):
       ]
       cat2 = ds.Categorical(
           logits, dtype=dtypes.int32, name="cat2")
-      dist2 = ds.Mixture(cat2, components2, name="mixture2")
+      dist2 = ds.Mixture(cat2, components2, name="mixture2",
+                         use_static_graph=self.use_static_graph)
       samples2 = dist2.sample(n, seed=123456).eval()
 
       self.assertAllClose(samples1, samples2)
@@ -541,7 +570,8 @@ class MixtureTest(test.TestCase):
     with self.test_session() as sess:
       num_components = 3
       dist = make_multivariate_mixture(
-          batch_shape=[], num_components=num_components, event_shape=[2])
+          batch_shape=[], num_components=num_components, event_shape=[2],
+          use_static_graph=self.use_static_graph)
       n = 4
       with _test_capture_mvndiag_sample_outputs() as component_samples:
         samples = dist.sample(n, seed=123)
@@ -555,14 +585,18 @@ class MixtureTest(test.TestCase):
         which_c = np.where(cat_sample_values == c)[0]
         size_c = which_c.size
         # Scalar Batch multivariate case: batch_size == 1, rank 2
-        which_dist_samples = dist_sample_values[c][:size_c, :]
+        if self.use_static_graph:
+          which_dist_samples = dist_sample_values[c][which_c, :]
+        else:
+          which_dist_samples = dist_sample_values[c][:size_c, :]
         self.assertAllClose(which_dist_samples, sample_values[which_c, :])
 
   def testSampleBatchUnivariate(self):
     with self.test_session() as sess:
       num_components = 3
       dist = make_univariate_mixture(
-          batch_shape=[2, 3], num_components=num_components)
+          batch_shape=[2, 3], num_components=num_components,
+          use_static_graph=self.use_static_graph)
       n = 4
       with _test_capture_normal_sample_outputs() as component_samples:
         samples = dist.sample(n, seed=123)
@@ -576,8 +610,12 @@ class MixtureTest(test.TestCase):
         which_c_s, which_c_b0, which_c_b1 = np.where(cat_sample_values == c)
         size_c = which_c_s.size
         # Batch univariate case: batch_size == [2, 3], rank 3
-        which_dist_samples = dist_sample_values[c][range(size_c), which_c_b0,
-                                                   which_c_b1]
+        if self.use_static_graph:
+          which_dist_samples = dist_sample_values[c][which_c_s, which_c_b0,
+                                                     which_c_b1]
+        else:
+          which_dist_samples = dist_sample_values[c][range(size_c), which_c_b0,
+                                                     which_c_b1]
         self.assertAllClose(which_dist_samples,
                             sample_values[which_c_s, which_c_b0, which_c_b1])
 
@@ -594,7 +632,8 @@ class MixtureTest(test.TestCase):
       dist = make_multivariate_mixture(
           batch_shape=batch_shape,
           num_components=num_components, event_shape=[4],
-          batch_shape_tensor=batch_shape_tensor)
+          batch_shape_tensor=batch_shape_tensor,
+          use_static_graph=self.use_static_graph)
       n = 5
       with _test_capture_mvndiag_sample_outputs() as component_samples:
         samples = dist.sample(n, seed=123)
@@ -617,8 +656,12 @@ class MixtureTest(test.TestCase):
         which_c_s, which_c_b0, which_c_b1 = np.where(cat_sample_values == c)
         size_c = which_c_s.size
         # Batch univariate case: batch_size == [2, 3], rank 4 (multivariate)
-        which_dist_samples = dist_sample_values[c][range(size_c), which_c_b0,
-                                                   which_c_b1, :]
+        if self.use_static_graph:
+          which_dist_samples = dist_sample_values[c][which_c_s, which_c_b0,
+                                                     which_c_b1, :]
+        else:
+          which_dist_samples = dist_sample_values[c][range(size_c), which_c_b0,
+                                                     which_c_b1, :]
         self.assertAllClose(which_dist_samples,
                             sample_values[which_c_s, which_c_b0,
                                           which_c_b1, :])
@@ -632,7 +675,8 @@ class MixtureTest(test.TestCase):
     with self.test_session() as sess:
       for batch_shape in ((), (2,), (2, 3)):
         dist = make_multivariate_mixture(
-            batch_shape=batch_shape, num_components=2, event_shape=(4,))
+            batch_shape=batch_shape, num_components=2, event_shape=(4,),
+            use_static_graph=self.use_static_graph)
         entropy_lower_bound = dist.entropy_lower_bound()
         self.assertEqual(batch_shape, entropy_lower_bound.get_shape())
@@ -673,7 +717,8 @@ class MixtureTest(test.TestCase):
       cat_tf = ds.Categorical(probs=mixture_weights)
       components_tf = [ds.Normal(loc=mu, scale=sigma)
                        for (mu, sigma) in zip(means, sigmas)]
-      mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf)
+      mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf,
+                              use_static_graph=self.use_static_graph)
 
       x_tensor = array_ops.placeholder(shape=(), dtype=dtypes.float32)
@@ -721,7 +766,8 @@ class MixtureTest(test.TestCase):
       cat_tf = ds.Categorical(probs=mixture_weights)
       components_tf = [ds.Normal(loc=mu, scale=sigma)
                        for (mu, sigma) in zip(means, sigmas)]
-      mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf)
+      mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf,
+                              use_static_graph=self.use_static_graph)
       x_tensor = array_ops.placeholder(shape=psize, dtype=dtypes.float32)
 
       xs_to_check = [
@@ -760,12 +806,18 @@ class MixtureTest(test.TestCase):
       gm = ds.Mixture(
           cat=ds.Categorical(probs=[.3, .7]),
          components=[ds.Gamma(1., 2.),
-                      ds.Gamma(2., 1.)])
+                      ds.Gamma(2., 1.)],
+          use_static_graph=self.use_static_graph)
       x_ = gm.sample().eval()
       self.assertAllEqual([], x_.shape)
 
 
+class MixtureStaticSampleTest(MixtureTest):
+  use_static_graph = True
+
+
 class MixtureBenchmark(test.Benchmark):
+  use_static_graph = False
 
   def _runSamplingBenchmark(self, name, create_distribution, use_gpu,
                             num_components, batch_size, num_features,
@@ -811,7 +863,7 @@ class MixtureBenchmark(test.Benchmark):
       components = list(
           ds.MultivariateNormalDiag(
              loc=mu, scale_diag=sigma) for (mu, sigma) in zip(mus, sigmas))
-      return ds.Mixture(cat, components)
+      return ds.Mixture(cat, components, use_static_graph=self.use_static_graph)
 
     for use_gpu in False, True:
       if use_gpu and not test.is_gpu_available():
@@ -853,7 +905,7 @@ class MixtureBenchmark(test.Benchmark):
           ds.MultivariateNormalTriL(
              loc=mu, scale_tril=linalg_ops.cholesky(sigma))
          for (mu, sigma) in zip(mus, sigmas))
-      return ds.Mixture(cat, components)
+      return ds.Mixture(cat, components, use_static_graph=self.use_static_graph)
 
     for use_gpu in False, True:
       if use_gpu and not test.is_gpu_available():
@@ -872,5 +924,9 @@ class MixtureBenchmark(test.Benchmark):
           sample_size=sample_size)
 
 
+class MixtureStaticSampleBenchmark(MixtureBenchmark):
+  use_static_graph = True
+
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py
index a4d249d41e..289e1d50e1 100644
--- a/tensorflow/contrib/distributions/python/ops/distribution_util.py
+++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.contrib import linalg
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
@@ -442,6 +443,44 @@ def maybe_check_scalar_distribution(
   return assertions
 
 
+def pad_mixture_dimensions(x, mixture_distribution, categorical_distribution,
+                           event_ndims):
+  """Pad dimensions of event tensors for mixture distributions.
+
+  See `Mixture._sample_n` and `MixtureSameFamily._sample_n` for usage examples.
+
+  Args:
+    x: event tensor to pad.
+    mixture_distribution: Base distribution of the mixture.
+    categorical_distribution: `Categorical` distribution that mixes the base
+      distribution.
+    event_ndims: Integer specifying the number of event dimensions in the event
+      tensor.
+
+  Returns:
+    A padded version of `x` that can broadcast with `categorical_distribution`.
+  """
+  with ops.name_scope("pad_mix_dims", values=[x]):
+    def _get_ndims(d):
+      if d.batch_shape.ndims is not None:
+        return d.batch_shape.ndims
+      return array_ops.shape(d.batch_shape_tensor())[0]
+    dist_batch_ndims = _get_ndims(mixture_distribution)
+    cat_batch_ndims = _get_ndims(categorical_distribution)
+    pad_ndims = array_ops.where(
+        categorical_distribution.is_scalar_batch(),
+        dist_batch_ndims,
+        dist_batch_ndims - cat_batch_ndims)
+    s = array_ops.shape(x)
+    x = array_ops.reshape(x, shape=array_ops.concat([
+        s[:-1],
+        array_ops.ones([pad_ndims], dtype=dtypes.int32),
+        s[-1:],
+        array_ops.ones([event_ndims], dtype=dtypes.int32),
+    ], axis=0))
+    return x
+
+
 def static_value(x):
   """Returns the static value of a `Tensor` or `None`."""
   return tensor_util.constant_value(ops.convert_to_tensor(x))
diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py
index f2d492f548..cef6a143fc 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture.py
@@ -71,6 +71,7 @@ class Mixture(distribution.Distribution):
                components,
                validate_args=False,
                allow_nan_stats=True,
+               use_static_graph=False,
                name="Mixture"):
     """Initialize a Mixture distribution.
 
@@ -96,6 +97,11 @@ class Mixture(distribution.Distribution):
         exception if a statistic (e.g. mean/mode/etc...) is undefined for any
         batch member. If `True`, batch members with valid parameters leading to
         undefined statistics will return NaN for this statistic.
+      use_static_graph: Calls to `sample` will not rely on dynamic tensor
+        indexing, allowing for some static graph compilation optimizations, but
+        at the expense of sampling all underlying distributions in the mixture.
+        (Possibly useful when running on TPUs).
+        Default value: `False` (i.e., use dynamic indexing).
       name: A name for this distribution (optional).
 
     Raises:
@@ -178,6 +184,10 @@ class Mixture(distribution.Distribution):
     self._static_event_shape = static_event_shape
     self._static_batch_shape = static_batch_shape
 
+    self._use_static_graph = use_static_graph
+    if use_static_graph and static_num_components is None:
+      raise ValueError("Number of categories must be known statically when "
+                       "`static_sample=True`.")
     # We let the Mixture distribution access _graph_parents since its arguably
     # more like a baseclass.
     graph_parents = self._cat._graph_parents  # pylint: disable=protected-access
@@ -292,6 +302,31 @@ class Mixture(distribution.Distribution):
     return mixture_log_cdf
 
   def _sample_n(self, n, seed=None):
+    if self._use_static_graph:
+      # This sampling approach is almost the same as the approach used by
+      # `MixtureSameFamily`. The differences are due to having a list of
+      # `Distribution` objects rather than a single object, and maintaining
+      # random seed management that is consistent with the non-static code
+      # path.
+      samples = []
+      cat_samples = self.cat.sample(n, seed=seed)
+      for c in range(self.num_components):
+        seed = distribution_util.gen_new_seed(seed, "mixture")
+        samples.append(self.components[c].sample(n, seed=seed))
+      x = array_ops.stack(
+          samples, -self._static_event_shape.ndims - 1)  # [n, B, k, E]
+      npdt = x.dtype.as_numpy_dtype
+      mask = array_ops.one_hot(
+          indices=cat_samples,  # [n, B]
+          depth=self._num_components,  # == k
+          on_value=np.ones([], dtype=npdt),
+          off_value=np.zeros([], dtype=npdt))  # [n, B, k]
+      mask = distribution_utils.pad_mixture_dimensions(
+          mask, self, self._cat,
+          self._static_event_shape.ndims)  # [n, B, k, [1]*e]
+      return math_ops.reduce_sum(
+          x * mask, axis=-1 - self._static_event_shape.ndims)  # [n, B, E]
+
     with ops.control_dependencies(self._assertions):
       n = ops.convert_to_tensor(n, name="n")
       static_n = tensor_util.constant_value(n)
diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
index 49afbea7f0..b93bdc5ab4 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
@@ -20,7 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.python.framework import dtypes
+from tensorflow.contrib.distributions.python.ops import distribution_util as distribution_utils
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -239,7 +239,9 @@ class MixtureSameFamily(distribution.Distribution):
           depth=self._num_components,  # == k
          on_value=np.ones([], dtype=npdt),
          off_value=np.zeros([], dtype=npdt))  # [n, B, k]
-      mask = self._pad_mix_dims(mask)  # [n, B, k, [1]*e]
+      mask = distribution_utils.pad_mixture_dimensions(
+          mask, self, self.mixture_distribution,
+          self._event_shape().ndims)  # [n, B, k, [1]*e]
       return math_ops.reduce_sum(
           x * mask, axis=-1 - self._event_ndims)  # [n, B, E]
 
@@ -254,8 +256,9 @@ class MixtureSameFamily(distribution.Distribution):
 
   def _mean(self):
     with ops.control_dependencies(self._runtime_assertions):
-      probs = self._pad_mix_dims(
-          self.mixture_distribution.probs)  # [B, k, [1]*e]
+      probs = distribution_utils.pad_mixture_dimensions(
+          self.mixture_distribution.probs, self, self.mixture_distribution,
+          self._event_shape().ndims)  # [B, k, [1]*e]
       return math_ops.reduce_sum(
           probs * self.components_distribution.mean(),
           axis=-1 - self._event_ndims)  # [B, E]
@@ -271,8 +274,9 @@ class MixtureSameFamily(distribution.Distribution):
   def _variance(self):
     with ops.control_dependencies(self._runtime_assertions):
       # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
-      probs = self._pad_mix_dims(
-          self.mixture_distribution.probs)  # [B, k, [1]*e]
+      probs = distribution_utils.pad_mixture_dimensions(
+          self.mixture_distribution.probs, self, self.mixture_distribution,
+          self._event_shape().ndims)  # [B, k, [1]*e]
       mean_cond_var = math_ops.reduce_sum(
           probs * self.components_distribution.variance(),
           axis=-1 - self._event_ndims)  # [B, E]
@@ -291,8 +295,12 @@ class MixtureSameFamily(distribution.Distribution):
 
     with ops.control_dependencies(self._runtime_assertions):
       # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
-      probs = self._pad_mix_dims(self._pad_mix_dims(
-          self.mixture_distribution.probs))  # [B, k, 1, 1]
+      probs = distribution_utils.pad_mixture_dimensions(
+          distribution_utils.pad_mixture_dimensions(
+              self.mixture_distribution.probs, self, self.mixture_distribution,
+              self._event_shape().ndims),
+          self, self.mixture_distribution,
+          self._event_shape().ndims)  # [B, k, 1, 1]
       mean_cond_var = math_ops.reduce_sum(
           probs * self.components_distribution.covariance(),
           axis=-3)  # [B, e, e]
@@ -312,27 +320,6 @@ class MixtureSameFamily(distribution.Distribution):
           shape[:d], [1], shape[d:]], axis=0))
     return x
 
-  def _pad_mix_dims(self, x):
-    with ops.name_scope("pad_mix_dims", values=[x]):
-      def _get_ndims(d):
-        if d.batch_shape.ndims is not None:
-          return d.batch_shape.ndims
-        return array_ops.shape(d.batch_shape_tensor())[0]
-      dist_batch_ndims = _get_ndims(self)
-      cat_batch_ndims = _get_ndims(self.mixture_distribution)
-      pad_ndims = array_ops.where(
-          self.mixture_distribution.is_scalar_batch(),
-          dist_batch_ndims,
-          dist_batch_ndims - cat_batch_ndims)
-      s = array_ops.shape(x)
-      x = array_ops.reshape(x, shape=array_ops.concat([
-          s[:-1],
-          array_ops.ones([pad_ndims], dtype=dtypes.int32),
-          s[-1:],
-          array_ops.ones([self._event_ndims], dtype=dtypes.int32),
-      ], axis=0))
-    return x
-
 
 def _outer_squared_difference(x, y):
   """Convenience function analogous to tf.squared_difference."""
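For intuition, the static-graph sampling path added above amounts to the following NumPy sketch (ours, not part of the commit): draw categorical indices, sample every component, one-hot encode the indices, and select values with an elementwise multiply and reduce-sum instead of dynamic indexing.

    import numpy as np

    np.random.seed(0)
    n, k = 5, 3                        # number of draws, mixture components
    probs = np.array([0.2, 0.5, 0.3])  # mixing weights

    cat = np.random.choice(k, size=n, p=probs)       # [n] component indices
    x = np.stack([np.random.normal(loc=c, size=n)    # sample ALL k components
                  for c in range(k)], axis=-1)       # [n, k]
    mask = np.eye(k)[cat]                            # [n, k] one-hot mask
    samples = np.sum(x * mask, axis=-1)              # [n] selected samples

Every component is sampled for every draw (extra work), but the graph contains only fixed-shape ops, which is what enables the static compilation optimizations described in the commit message.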