diff options
Diffstat (limited to 'tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py')
-rw-r--r-- | tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py | 244 |
1 files changed, 142 insertions, 102 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py index f4e63d79cd..6e72f1ca31 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py @@ -21,9 +21,19 @@ from __future__ import print_function import contextlib import numpy as np -import tensorflow as tf +from tensorflow.contrib import distributions +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test -distributions_py = tf.contrib.distributions +distributions_py = distributions def _swap_first_last_axes(array): @@ -65,33 +75,32 @@ def _test_capture_normal_sample_outputs(): def make_univariate_mixture(batch_shape, num_components): - logits = tf.random_uniform( - list(batch_shape) + [num_components], -1, 1, dtype=tf.float32) - 50. + logits = random_ops.random_uniform( + list(batch_shape) + [num_components], -1, 1, dtype=dtypes.float32) - 50. components = [ distributions_py.Normal( mu=np.float32(np.random.randn(*list(batch_shape))), sigma=np.float32(10 * np.random.rand(*list(batch_shape)))) for _ in range(num_components) ] - cat = distributions_py.Categorical(logits, dtype=tf.int32) + cat = distributions_py.Categorical(logits, dtype=dtypes.int32) return distributions_py.Mixture(cat, components) def make_multivariate_mixture(batch_shape, num_components, event_shape): - logits = tf.random_uniform( - list(batch_shape) + [num_components], -1, 1, dtype=tf.float32) - 50. + logits = random_ops.random_uniform( + list(batch_shape) + [num_components], -1, 1, dtype=dtypes.float32) - 50. components = [ distributions_py.MultivariateNormalDiag( mu=np.float32(np.random.randn(*list(batch_shape + event_shape))), diag_stdev=np.float32(10 * np.random.rand( - *list(batch_shape + event_shape)))) - for _ in range(num_components) + *list(batch_shape + event_shape)))) for _ in range(num_components) ] - cat = distributions_py.Categorical(logits, dtype=tf.int32) + cat = distributions_py.Categorical(logits, dtype=dtypes.int32) return distributions_py.Mixture(cat, components) -class MixtureTest(tf.test.TestCase): +class MixtureTest(test.TestCase): def testShapes(self): with self.test_session(): @@ -115,7 +124,8 @@ class MixtureTest(tf.test.TestCase): r"cat.num_classes != len"): distributions_py.Mixture( distributions_py.Categorical([0.1, 0.5]), # 2 classes - [distributions_py.Normal(mu=1.0, sigma=2.0)]) + [distributions_py.Normal( + mu=1.0, sigma=2.0)]) with self.assertRaisesWithPredicateMatch( ValueError, r"\(\) and \(2,\) are not compatible"): # The value error is raised because the batch shapes of the @@ -123,22 +133,29 @@ class MixtureTest(tf.test.TestCase): # vector of size (2,). distributions_py.Mixture( distributions_py.Categorical([-0.5, 0.5]), # scalar batch - [distributions_py.Normal(mu=1.0, sigma=2.0), # scalar dist - distributions_py.Normal(mu=[1.0, 1.0], sigma=[2.0, 2.0])]) + [ + distributions_py.Normal( + mu=1.0, sigma=2.0), # scalar dist + distributions_py.Normal( + mu=[1.0, 1.0], sigma=[2.0, 2.0]) + ]) with self.assertRaisesWithPredicateMatch(ValueError, r"Could not infer"): - cat_logits = tf.placeholder(shape=[1, None], dtype=tf.float32) + cat_logits = array_ops.placeholder(shape=[1, None], dtype=dtypes.float32) distributions_py.Mixture( distributions_py.Categorical(cat_logits), - [distributions_py.Normal(mu=[1.0], sigma=[2.0])]) + [distributions_py.Normal( + mu=[1.0], sigma=[2.0])]) def testBrokenShapesDynamic(self): with self.test_session(): - d0_param = tf.placeholder(dtype=tf.float32) - d1_param = tf.placeholder(dtype=tf.float32) + d0_param = array_ops.placeholder(dtype=dtypes.float32) + d1_param = array_ops.placeholder(dtype=dtypes.float32) d = distributions_py.Mixture( - distributions_py.Categorical([0.1, 0.2]), - [distributions_py.Normal(mu=d0_param, sigma=d0_param), - distributions_py.Normal(mu=d1_param, sigma=d1_param)], + distributions_py.Categorical([0.1, 0.2]), [ + distributions_py.Normal( + mu=d0_param, sigma=d0_param), distributions_py.Normal( + mu=d1_param, sigma=d1_param) + ], validate_args=True) with self.assertRaisesOpError(r"batch shape must match"): d.sample().eval(feed_dict={d0_param: [2.0, 3.0], d1_param: [1.0]}) @@ -155,18 +172,21 @@ class MixtureTest(tf.test.TestCase): distributions_py.Mixture(cat, [None]) with self.assertRaisesWithPredicateMatch(TypeError, "same dtype"): distributions_py.Mixture( - cat, - [distributions_py.Normal(mu=[1.0], sigma=[2.0]), - distributions_py.Normal(mu=[np.float16(1.0)], - sigma=[np.float16(2.0)])]) + cat, [ + distributions_py.Normal( + mu=[1.0], sigma=[2.0]), distributions_py.Normal( + mu=[np.float16(1.0)], sigma=[np.float16(2.0)]) + ]) with self.assertRaisesWithPredicateMatch(ValueError, "non-empty list"): distributions_py.Mixture(distributions_py.Categorical([0.3, 0.2]), None) with self.assertRaisesWithPredicateMatch(TypeError, "either be continuous or not"): distributions_py.Mixture( - cat, - [distributions_py.Normal(mu=[1.0], sigma=[2.0]), - distributions_py.Bernoulli(dtype=tf.float32, logits=[1.0])]) + cat, [ + distributions_py.Normal( + mu=[1.0], sigma=[2.0]), distributions_py.Bernoulli( + dtype=dtypes.float32, logits=[1.0]) + ]) def testMeanUnivariate(self): with self.test_session() as sess: @@ -176,7 +196,7 @@ class MixtureTest(tf.test.TestCase): mean = dist.mean() self.assertEqual(batch_shape, mean.get_shape()) - cat_probs = tf.nn.softmax(dist.cat.logits) + cat_probs = nn_ops.softmax(dist.cat.logits) dist_means = [d.mean() for d in dist.components] mean_value, cat_probs_value, dist_means_value = sess.run( @@ -197,7 +217,7 @@ class MixtureTest(tf.test.TestCase): mean = dist.mean() self.assertEqual(batch_shape + (4,), mean.get_shape()) - cat_probs = tf.nn.softmax(dist.cat.logits) + cat_probs = nn_ops.softmax(dist.cat.logits) dist_means = [d.mean() for d in dist.components] mean_value, cat_probs_value, dist_means_value = sess.run( @@ -217,23 +237,25 @@ class MixtureTest(tf.test.TestCase): def testProbScalarUnivariate(self): with self.test_session() as sess: dist = make_univariate_mixture(batch_shape=[], num_components=2) - for x in [np.array( - [1.0, 2.0], dtype=np.float32), np.array( - 1.0, dtype=np.float32), np.random.randn(3, 4).astype(np.float32)]: + for x in [ + np.array( + [1.0, 2.0], dtype=np.float32), np.array( + 1.0, dtype=np.float32), + np.random.randn(3, 4).astype(np.float32) + ]: p_x = dist.prob(x) self.assertEqual(x.shape, p_x.get_shape()) - cat_probs = tf.nn.softmax([dist.cat.logits])[0] + cat_probs = nn_ops.softmax([dist.cat.logits])[0] dist_probs = [d.prob(x) for d in dist.components] p_x_value, cat_probs_value, dist_probs_value = sess.run( [p_x, cat_probs, dist_probs]) self.assertEqual(x.shape, p_x_value.shape) - total_prob = sum( - c_p_value * d_p_value - for (c_p_value, d_p_value) - in zip(cat_probs_value, dist_probs_value)) + total_prob = sum(c_p_value * d_p_value + for (c_p_value, d_p_value + ) in zip(cat_probs_value, dist_probs_value)) self.assertAllClose(total_prob, p_x_value) @@ -241,15 +263,17 @@ class MixtureTest(tf.test.TestCase): with self.test_session() as sess: dist = make_multivariate_mixture( batch_shape=[], num_components=2, event_shape=[3]) - for x in [np.array( - [[-1.0, 0.0, 1.0], [0.5, 1.0, -0.3]], dtype=np.float32), np.array( - [-1.0, 0.0, 1.0], dtype=np.float32), - np.random.randn(2, 2, 3).astype(np.float32)]: + for x in [ + np.array( + [[-1.0, 0.0, 1.0], [0.5, 1.0, -0.3]], dtype=np.float32), np.array( + [-1.0, 0.0, 1.0], dtype=np.float32), + np.random.randn(2, 2, 3).astype(np.float32) + ]: p_x = dist.prob(x) self.assertEqual(x.shape[:-1], p_x.get_shape()) - cat_probs = tf.nn.softmax([dist.cat.logits])[0] + cat_probs = nn_ops.softmax([dist.cat.logits])[0] dist_probs = [d.prob(x) for d in dist.components] p_x_value, cat_probs_value, dist_probs_value = sess.run( @@ -267,12 +291,14 @@ class MixtureTest(tf.test.TestCase): with self.test_session() as sess: dist = make_univariate_mixture(batch_shape=[2, 3], num_components=2) - for x in [np.random.randn(2, 3).astype(np.float32), - np.random.randn(4, 2, 3).astype(np.float32)]: + for x in [ + np.random.randn(2, 3).astype(np.float32), + np.random.randn(4, 2, 3).astype(np.float32) + ]: p_x = dist.prob(x) self.assertEqual(x.shape, p_x.get_shape()) - cat_probs = tf.nn.softmax(dist.cat.logits) + cat_probs = nn_ops.softmax(dist.cat.logits) dist_probs = [d.prob(x) for d in dist.components] p_x_value, cat_probs_value, dist_probs_value = sess.run( @@ -281,10 +307,9 @@ class MixtureTest(tf.test.TestCase): cat_probs_value = _swap_first_last_axes(cat_probs_value) - total_prob = sum( - c_p_value * d_p_value - for (c_p_value, d_p_value) - in zip(cat_probs_value, dist_probs_value)) + total_prob = sum(c_p_value * d_p_value + for (c_p_value, d_p_value + ) in zip(cat_probs_value, dist_probs_value)) self.assertAllClose(total_prob, p_x_value) @@ -293,12 +318,14 @@ class MixtureTest(tf.test.TestCase): dist = make_multivariate_mixture( batch_shape=[2, 3], num_components=2, event_shape=[4]) - for x in [np.random.randn(2, 3, 4).astype(np.float32), - np.random.randn(4, 2, 3, 4).astype(np.float32)]: + for x in [ + np.random.randn(2, 3, 4).astype(np.float32), + np.random.randn(4, 2, 3, 4).astype(np.float32) + ]: p_x = dist.prob(x) self.assertEqual(x.shape[:-1], p_x.get_shape()) - cat_probs = tf.nn.softmax(dist.cat.logits) + cat_probs = nn_ops.softmax(dist.cat.logits) dist_probs = [d.prob(x) for d in dist.components] p_x_value, cat_probs_value, dist_probs_value = sess.run( @@ -306,9 +333,9 @@ class MixtureTest(tf.test.TestCase): self.assertEqual(x.shape[:-1], p_x_value.shape) cat_probs_value = _swap_first_last_axes(cat_probs_value) - total_prob = sum( - c_p_value * d_p_value for (c_p_value, d_p_value) - in zip(cat_probs_value, dist_probs_value)) + total_prob = sum(c_p_value * d_p_value + for (c_p_value, d_p_value + ) in zip(cat_probs_value, dist_probs_value)) self.assertAllClose(total_prob, p_x_value) @@ -320,7 +347,7 @@ class MixtureTest(tf.test.TestCase): n = 4 with _test_capture_normal_sample_outputs() as component_samples: samples = dist.sample_n(n, seed=123) - self.assertEqual(samples.dtype, tf.float32) + self.assertEqual(samples.dtype, dtypes.float32) self.assertEqual((4,), samples.get_shape()) cat_samples = dist.cat.sample_n(n, seed=123) sample_values, cat_sample_values, dist_sample_values = sess.run( @@ -344,17 +371,23 @@ class MixtureTest(tf.test.TestCase): with self.test_session(): n = 100 - tf.set_random_seed(654321) - components = [distributions_py.Normal( - mu=mu, sigma=sigma) for mu, sigma in zip(mus, sigmas)] - cat = distributions_py.Categorical(logits, dtype=tf.int32, name="cat1") + random_seed.set_random_seed(654321) + components = [ + distributions_py.Normal( + mu=mu, sigma=sigma) for mu, sigma in zip(mus, sigmas) + ] + cat = distributions_py.Categorical( + logits, dtype=dtypes.int32, name="cat1") dist1 = distributions_py.Mixture(cat, components, name="mixture1") samples1 = dist1.sample_n(n, seed=123456).eval() - tf.set_random_seed(654321) - components2 = [distributions_py.Normal( - mu=mu, sigma=sigma) for mu, sigma in zip(mus, sigmas)] - cat2 = distributions_py.Categorical(logits, dtype=tf.int32, name="cat2") + random_seed.set_random_seed(654321) + components2 = [ + distributions_py.Normal( + mu=mu, sigma=sigma) for mu, sigma in zip(mus, sigmas) + ] + cat2 = distributions_py.Categorical( + logits, dtype=dtypes.int32, name="cat2") dist2 = distributions_py.Mixture(cat2, components2, name="mixture2") samples2 = dist2.sample_n(n, seed=123456).eval() @@ -368,7 +401,7 @@ class MixtureTest(tf.test.TestCase): n = 4 with _test_capture_mvndiag_sample_outputs() as component_samples: samples = dist.sample_n(n, seed=123) - self.assertEqual(samples.dtype, tf.float32) + self.assertEqual(samples.dtype, dtypes.float32) self.assertEqual((4, 2), samples.get_shape()) cat_samples = dist.cat.sample_n(n, seed=123) sample_values, cat_sample_values, dist_sample_values = sess.run( @@ -389,7 +422,7 @@ class MixtureTest(tf.test.TestCase): n = 4 with _test_capture_normal_sample_outputs() as component_samples: samples = dist.sample_n(n, seed=123) - self.assertEqual(samples.dtype, tf.float32) + self.assertEqual(samples.dtype, dtypes.float32) self.assertEqual((4, 2, 3), samples.get_shape()) cat_samples = dist.cat.sample_n(n, seed=123) sample_values, cat_sample_values, dist_sample_values = sess.run( @@ -412,7 +445,7 @@ class MixtureTest(tf.test.TestCase): n = 5 with _test_capture_mvndiag_sample_outputs() as component_samples: samples = dist.sample_n(n, seed=123) - self.assertEqual(samples.dtype, tf.float32) + self.assertEqual(samples.dtype, dtypes.float32) self.assertEqual((5, 2, 3, 4), samples.get_shape()) cat_samples = dist.cat.sample_n(n, seed=123) sample_values, cat_sample_values, dist_sample_values = sess.run( @@ -436,7 +469,7 @@ class MixtureTest(tf.test.TestCase): entropy_lower_bound = dist.entropy_lower_bound() self.assertEqual(batch_shape, entropy_lower_bound.get_shape()) - cat_probs = tf.nn.softmax(dist.cat.logits) + cat_probs = nn_ops.softmax(dist.cat.logits) dist_entropy = [d.entropy() for d in dist.components] entropy_lower_bound_value, cat_probs_value, dist_entropy_value = ( @@ -453,32 +486,33 @@ class MixtureTest(tf.test.TestCase): self.assertAllClose(true_entropy_lower_bound, entropy_lower_bound_value) -class MixtureBenchmark(tf.test.Benchmark): +class MixtureBenchmark(test.Benchmark): - def _runSamplingBenchmark(self, name, - create_distribution, use_gpu, num_components, - batch_size, num_features, sample_size): - config = tf.ConfigProto() + def _runSamplingBenchmark(self, name, create_distribution, use_gpu, + num_components, batch_size, num_features, + sample_size): + config = config_pb2.ConfigProto() config.allow_soft_placement = True np.random.seed(127) - with tf.Session(config=config, graph=tf.Graph()) as sess: - tf.set_random_seed(0) - with tf.device("/gpu:0" if use_gpu else "/cpu:0"): + with session.Session(config=config, graph=ops.Graph()) as sess: + random_seed.set_random_seed(0) + with ops.device("/gpu:0" if use_gpu else "/cpu:0"): mixture = create_distribution( num_components=num_components, batch_size=batch_size, num_features=num_features) sample_op = mixture.sample(sample_size).op - sess.run(tf.global_variables_initializer()) + sess.run(variables.global_variables_initializer()) reported = self.run_op_benchmark( - sess, sample_op, + sess, + sample_op, min_iters=10, - name=("%s_%s_components_%d_batch_%d_features_%d_sample_%d" - % (name, use_gpu, num_components, - batch_size, num_features, sample_size))) - print("\t".join(["%s", "%d", "%d", "%d", "%d", "%g"]) - % (use_gpu, num_components, batch_size, - num_features, sample_size, reported["wall_time"])) + name=("%s_%s_components_%d_batch_%d_features_%d_sample_%d" % + (name, use_gpu, num_components, batch_size, num_features, + sample_size))) + print("\t".join(["%s", "%d", "%d", "%d", "%d", "%g"]) % + (use_gpu, num_components, batch_size, num_features, sample_size, + reported["wall_time"])) def benchmarkSamplingMVNDiag(self): print("mvn_diag\tuse_gpu\tcomponents\tbatch\tfeatures\tsample\twall_time") @@ -487,25 +521,28 @@ class MixtureBenchmark(tf.test.Benchmark): cat = distributions_py.Categorical( logits=np.random.randn(batch_size, num_components)) mus = [ - tf.Variable(np.random.randn(batch_size, num_features)) - for _ in range(num_components)] + variables.Variable(np.random.randn(batch_size, num_features)) + for _ in range(num_components) + ] sigmas = [ - tf.Variable(np.random.rand(batch_size, num_features)) - for _ in range(num_components)] + variables.Variable(np.random.rand(batch_size, num_features)) + for _ in range(num_components) + ] components = list( - distributions_py.MultivariateNormalDiag(mu=mu, diag_stdev=sigma) - for (mu, sigma) in zip(mus, sigmas)) + distributions_py.MultivariateNormalDiag( + mu=mu, diag_stdev=sigma) for (mu, sigma) in zip(mus, sigmas)) return distributions_py.Mixture(cat, components) for use_gpu in False, True: - if use_gpu and not tf.test.is_gpu_available(): + if use_gpu and not test.is_gpu_available(): continue for num_components in 1, 8, 16: for batch_size in 1, 32: for num_features in 1, 64, 512: for sample_size in 1, 32, 128: self._runSamplingBenchmark( - "mvn_diag", create_distribution=create_distribution, + "mvn_diag", + create_distribution=create_distribution, use_gpu=use_gpu, num_components=num_components, batch_size=batch_size, @@ -523,26 +560,29 @@ class MixtureBenchmark(tf.test.Benchmark): cat = distributions_py.Categorical( logits=np.random.randn(batch_size, num_components)) mus = [ - tf.Variable(np.random.randn(batch_size, num_features)) - for _ in range(num_components)] + variables.Variable(np.random.randn(batch_size, num_features)) + for _ in range(num_components) + ] sigmas = [ - tf.Variable( + variables.Variable( psd(np.random.rand(batch_size, num_features, num_features))) - for _ in range(num_components)] + for _ in range(num_components) + ] components = list( - distributions_py.MultivariateNormalFull(mu=mu, sigma=sigma) - for (mu, sigma) in zip(mus, sigmas)) + distributions_py.MultivariateNormalFull( + mu=mu, sigma=sigma) for (mu, sigma) in zip(mus, sigmas)) return distributions_py.Mixture(cat, components) for use_gpu in False, True: - if use_gpu and not tf.test.is_gpu_available(): + if use_gpu and not test.is_gpu_available(): continue for num_components in 1, 8, 16: for batch_size in 1, 32: for num_features in 1, 64, 512: for sample_size in 1, 32, 128: self._runSamplingBenchmark( - "mvn_full", create_distribution=create_distribution, + "mvn_full", + create_distribution=create_distribution, use_gpu=use_gpu, num_components=num_components, batch_size=batch_size, @@ -551,4 +591,4 @@ class MixtureBenchmark(tf.test.Benchmark): if __name__ == "__main__": - tf.test.main() + test.main() |