Diffstat (limited to 'tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py')
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py | 244
1 file changed, 142 insertions(+), 102 deletions(-)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
index f4e63d79cd..6e72f1ca31 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
@@ -21,9 +21,19 @@ from __future__ import print_function
import contextlib
import numpy as np
-import tensorflow as tf
+from tensorflow.contrib import distributions
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
-distributions_py = tf.contrib.distributions
+distributions_py = distributions
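
The hunk above replaces the single `import tensorflow as tf` with direct imports of the internal modules behind the public aliases, and rebinds `distributions_py` so the test bodies need no further changes. The correspondence used throughout the rest of the diff:

    # Public alias                     -> internal module (as imported above)
    # tf.float32 / tf.int32            -> dtypes.float32 / dtypes.int32
    # tf.random_uniform                -> random_ops.random_uniform
    # tf.nn.softmax                    -> nn_ops.softmax
    # tf.placeholder                   -> array_ops.placeholder
    # tf.set_random_seed               -> random_seed.set_random_seed
    # tf.global_variables_initializer  -> variables.global_variables_initializer
    # tf.ConfigProto                   -> config_pb2.ConfigProto
    # tf.Session / tf.Graph            -> session.Session / ops.Graph
    # tf.test.TestCase / tf.test.main  -> test.TestCase / test.main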
def _swap_first_last_axes(array):
@@ -65,33 +75,32 @@ def _test_capture_normal_sample_outputs():
def make_univariate_mixture(batch_shape, num_components):
- logits = tf.random_uniform(
- list(batch_shape) + [num_components], -1, 1, dtype=tf.float32) - 50.
+ logits = random_ops.random_uniform(
+ list(batch_shape) + [num_components], -1, 1, dtype=dtypes.float32) - 50.
components = [
distributions_py.Normal(
mu=np.float32(np.random.randn(*list(batch_shape))),
sigma=np.float32(10 * np.random.rand(*list(batch_shape))))
for _ in range(num_components)
]
- cat = distributions_py.Categorical(logits, dtype=tf.int32)
+ cat = distributions_py.Categorical(logits, dtype=dtypes.int32)
return distributions_py.Mixture(cat, components)
def make_multivariate_mixture(batch_shape, num_components, event_shape):
- logits = tf.random_uniform(
- list(batch_shape) + [num_components], -1, 1, dtype=tf.float32) - 50.
+ logits = random_ops.random_uniform(
+ list(batch_shape) + [num_components], -1, 1, dtype=dtypes.float32) - 50.
components = [
distributions_py.MultivariateNormalDiag(
mu=np.float32(np.random.randn(*list(batch_shape + event_shape))),
diag_stdev=np.float32(10 * np.random.rand(
- *list(batch_shape + event_shape))))
- for _ in range(num_components)
+ *list(batch_shape + event_shape)))) for _ in range(num_components)
]
- cat = distributions_py.Categorical(logits, dtype=tf.int32)
+ cat = distributions_py.Categorical(logits, dtype=dtypes.int32)
return distributions_py.Mixture(cat, components)
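
For reference, a minimal sketch of how these factories behave (illustrative only, using the same 1.x-era `contrib.distributions` API as the tests below):

    dist = make_univariate_mixture(batch_shape=[2, 3], num_components=4)
    # cat.logits has shape [2, 3, 4]; each Normal component has batch shape
    # [2, 3], so the mixture has batch shape (2, 3) and scalar event shape ().
    samples = dist.sample_n(5)      # Tensor of shape [5, 2, 3]
    log_p = dist.log_prob(samples)  # Tensor of shape [5, 2, 3]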
-class MixtureTest(tf.test.TestCase):
+class MixtureTest(test.TestCase):
def testShapes(self):
with self.test_session():
@@ -115,7 +124,8 @@ class MixtureTest(tf.test.TestCase):
r"cat.num_classes != len"):
distributions_py.Mixture(
distributions_py.Categorical([0.1, 0.5]), # 2 classes
- [distributions_py.Normal(mu=1.0, sigma=2.0)])
+ [distributions_py.Normal(
+ mu=1.0, sigma=2.0)])
with self.assertRaisesWithPredicateMatch(
ValueError, r"\(\) and \(2,\) are not compatible"):
# The value error is raised because the batch shapes of the
@@ -123,22 +133,29 @@ class MixtureTest(tf.test.TestCase):
# vector of size (2,).
distributions_py.Mixture(
distributions_py.Categorical([-0.5, 0.5]), # scalar batch
- [distributions_py.Normal(mu=1.0, sigma=2.0), # scalar dist
- distributions_py.Normal(mu=[1.0, 1.0], sigma=[2.0, 2.0])])
+ [
+ distributions_py.Normal(
+ mu=1.0, sigma=2.0), # scalar dist
+ distributions_py.Normal(
+ mu=[1.0, 1.0], sigma=[2.0, 2.0])
+ ])
with self.assertRaisesWithPredicateMatch(ValueError, r"Could not infer"):
- cat_logits = tf.placeholder(shape=[1, None], dtype=tf.float32)
+ cat_logits = array_ops.placeholder(shape=[1, None], dtype=dtypes.float32)
distributions_py.Mixture(
distributions_py.Categorical(cat_logits),
- [distributions_py.Normal(mu=[1.0], sigma=[2.0])])
+ [distributions_py.Normal(
+ mu=[1.0], sigma=[2.0])])
def testBrokenShapesDynamic(self):
with self.test_session():
- d0_param = tf.placeholder(dtype=tf.float32)
- d1_param = tf.placeholder(dtype=tf.float32)
+ d0_param = array_ops.placeholder(dtype=dtypes.float32)
+ d1_param = array_ops.placeholder(dtype=dtypes.float32)
d = distributions_py.Mixture(
- distributions_py.Categorical([0.1, 0.2]),
- [distributions_py.Normal(mu=d0_param, sigma=d0_param),
- distributions_py.Normal(mu=d1_param, sigma=d1_param)],
+ distributions_py.Categorical([0.1, 0.2]), [
+ distributions_py.Normal(
+ mu=d0_param, sigma=d0_param), distributions_py.Normal(
+ mu=d1_param, sigma=d1_param)
+ ],
validate_args=True)
with self.assertRaisesOpError(r"batch shape must match"):
d.sample().eval(feed_dict={d0_param: [2.0, 3.0], d1_param: [1.0]})
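
The placeholder variant can only fail at run time: the batch shapes are unknown while the `Mixture` is built, so `validate_args=True` inserts assertion ops that fire during evaluation, which is what `assertRaisesOpError` matches above. The static-shape mismatches earlier in this class, by contrast, raise `ValueError` at construction time. In sketch form:

    # With placeholder-backed parameters, the shape check is an op in the graph,
    # so the mismatch surfaces only when concrete values are fed:
    d.sample().eval(feed_dict={d0_param: [2.0, 3.0],
                               d1_param: [1.0]})  # fails here, not at build time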
@@ -155,18 +172,21 @@ class MixtureTest(tf.test.TestCase):
distributions_py.Mixture(cat, [None])
with self.assertRaisesWithPredicateMatch(TypeError, "same dtype"):
distributions_py.Mixture(
- cat,
- [distributions_py.Normal(mu=[1.0], sigma=[2.0]),
- distributions_py.Normal(mu=[np.float16(1.0)],
- sigma=[np.float16(2.0)])])
+ cat, [
+ distributions_py.Normal(
+ mu=[1.0], sigma=[2.0]), distributions_py.Normal(
+ mu=[np.float16(1.0)], sigma=[np.float16(2.0)])
+ ])
with self.assertRaisesWithPredicateMatch(ValueError, "non-empty list"):
distributions_py.Mixture(distributions_py.Categorical([0.3, 0.2]), None)
with self.assertRaisesWithPredicateMatch(TypeError,
"either be continuous or not"):
distributions_py.Mixture(
- cat,
- [distributions_py.Normal(mu=[1.0], sigma=[2.0]),
- distributions_py.Bernoulli(dtype=tf.float32, logits=[1.0])])
+ cat, [
+ distributions_py.Normal(
+ mu=[1.0], sigma=[2.0]), distributions_py.Bernoulli(
+ dtype=dtypes.float32, logits=[1.0])
+ ])
def testMeanUnivariate(self):
with self.test_session() as sess:
@@ -176,7 +196,7 @@ class MixtureTest(tf.test.TestCase):
mean = dist.mean()
self.assertEqual(batch_shape, mean.get_shape())
- cat_probs = tf.nn.softmax(dist.cat.logits)
+ cat_probs = nn_ops.softmax(dist.cat.logits)
dist_means = [d.mean() for d in dist.components]
mean_value, cat_probs_value, dist_means_value = sess.run(
@@ -197,7 +217,7 @@ class MixtureTest(tf.test.TestCase):
mean = dist.mean()
self.assertEqual(batch_shape + (4,), mean.get_shape())
- cat_probs = tf.nn.softmax(dist.cat.logits)
+ cat_probs = nn_ops.softmax(dist.cat.logits)
dist_means = [d.mean() for d in dist.components]
mean_value, cat_probs_value, dist_means_value = sess.run(
@@ -217,23 +237,25 @@ class MixtureTest(tf.test.TestCase):
def testProbScalarUnivariate(self):
with self.test_session() as sess:
dist = make_univariate_mixture(batch_shape=[], num_components=2)
- for x in [np.array(
- [1.0, 2.0], dtype=np.float32), np.array(
- 1.0, dtype=np.float32), np.random.randn(3, 4).astype(np.float32)]:
+ for x in [
+ np.array(
+ [1.0, 2.0], dtype=np.float32), np.array(
+ 1.0, dtype=np.float32),
+ np.random.randn(3, 4).astype(np.float32)
+ ]:
p_x = dist.prob(x)
self.assertEqual(x.shape, p_x.get_shape())
- cat_probs = tf.nn.softmax([dist.cat.logits])[0]
+ cat_probs = nn_ops.softmax([dist.cat.logits])[0]
dist_probs = [d.prob(x) for d in dist.components]
p_x_value, cat_probs_value, dist_probs_value = sess.run(
[p_x, cat_probs, dist_probs])
self.assertEqual(x.shape, p_x_value.shape)
- total_prob = sum(
- c_p_value * d_p_value
- for (c_p_value, d_p_value)
- in zip(cat_probs_value, dist_probs_value))
+ total_prob = sum(c_p_value * d_p_value
+ for (c_p_value, d_p_value
+ ) in zip(cat_probs_value, dist_probs_value))
self.assertAllClose(total_prob, p_x_value)
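
The quantity being checked is the defining mixture identity p(x) = sum_k softmax(logits)_k * p_k(x). A self-contained NumPy sketch of the same computation (illustrative names; scalar batch, two Normal components):

    import numpy as np

    def normal_pdf(x, mu, sigma):
      return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

    logits = np.array([0.2, -0.4])
    cat_probs = np.exp(logits) / np.exp(logits).sum()   # softmax
    mus, sigmas = [0.0, 1.0], [1.0, 2.0]

    x = 0.3
    p_x = sum(w * normal_pdf(x, m, s)
              for w, m, s in zip(cat_probs, mus, sigmas))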
@@ -241,15 +263,17 @@ class MixtureTest(tf.test.TestCase):
with self.test_session() as sess:
dist = make_multivariate_mixture(
batch_shape=[], num_components=2, event_shape=[3])
- for x in [np.array(
- [[-1.0, 0.0, 1.0], [0.5, 1.0, -0.3]], dtype=np.float32), np.array(
- [-1.0, 0.0, 1.0], dtype=np.float32),
- np.random.randn(2, 2, 3).astype(np.float32)]:
+ for x in [
+ np.array(
+ [[-1.0, 0.0, 1.0], [0.5, 1.0, -0.3]], dtype=np.float32), np.array(
+ [-1.0, 0.0, 1.0], dtype=np.float32),
+ np.random.randn(2, 2, 3).astype(np.float32)
+ ]:
p_x = dist.prob(x)
self.assertEqual(x.shape[:-1], p_x.get_shape())
- cat_probs = tf.nn.softmax([dist.cat.logits])[0]
+ cat_probs = nn_ops.softmax([dist.cat.logits])[0]
dist_probs = [d.prob(x) for d in dist.components]
p_x_value, cat_probs_value, dist_probs_value = sess.run(
@@ -267,12 +291,14 @@ class MixtureTest(tf.test.TestCase):
with self.test_session() as sess:
dist = make_univariate_mixture(batch_shape=[2, 3], num_components=2)
- for x in [np.random.randn(2, 3).astype(np.float32),
- np.random.randn(4, 2, 3).astype(np.float32)]:
+ for x in [
+ np.random.randn(2, 3).astype(np.float32),
+ np.random.randn(4, 2, 3).astype(np.float32)
+ ]:
p_x = dist.prob(x)
self.assertEqual(x.shape, p_x.get_shape())
- cat_probs = tf.nn.softmax(dist.cat.logits)
+ cat_probs = nn_ops.softmax(dist.cat.logits)
dist_probs = [d.prob(x) for d in dist.components]
p_x_value, cat_probs_value, dist_probs_value = sess.run(
@@ -281,10 +307,9 @@ class MixtureTest(tf.test.TestCase):
cat_probs_value = _swap_first_last_axes(cat_probs_value)
- total_prob = sum(
- c_p_value * d_p_value
- for (c_p_value, d_p_value)
- in zip(cat_probs_value, dist_probs_value))
+ total_prob = sum(c_p_value * d_p_value
+ for (c_p_value, d_p_value
+ ) in zip(cat_probs_value, dist_probs_value))
self.assertAllClose(total_prob, p_x_value)
@@ -293,12 +318,14 @@ class MixtureTest(tf.test.TestCase):
dist = make_multivariate_mixture(
batch_shape=[2, 3], num_components=2, event_shape=[4])
- for x in [np.random.randn(2, 3, 4).astype(np.float32),
- np.random.randn(4, 2, 3, 4).astype(np.float32)]:
+ for x in [
+ np.random.randn(2, 3, 4).astype(np.float32),
+ np.random.randn(4, 2, 3, 4).astype(np.float32)
+ ]:
p_x = dist.prob(x)
self.assertEqual(x.shape[:-1], p_x.get_shape())
- cat_probs = tf.nn.softmax(dist.cat.logits)
+ cat_probs = nn_ops.softmax(dist.cat.logits)
dist_probs = [d.prob(x) for d in dist.components]
p_x_value, cat_probs_value, dist_probs_value = sess.run(
@@ -306,9 +333,9 @@ class MixtureTest(tf.test.TestCase):
self.assertEqual(x.shape[:-1], p_x_value.shape)
cat_probs_value = _swap_first_last_axes(cat_probs_value)
- total_prob = sum(
- c_p_value * d_p_value for (c_p_value, d_p_value)
- in zip(cat_probs_value, dist_probs_value))
+ total_prob = sum(c_p_value * d_p_value
+ for (c_p_value, d_p_value
+ ) in zip(cat_probs_value, dist_probs_value))
self.assertAllClose(total_prob, p_x_value)
@@ -320,7 +347,7 @@ class MixtureTest(tf.test.TestCase):
n = 4
with _test_capture_normal_sample_outputs() as component_samples:
samples = dist.sample_n(n, seed=123)
- self.assertEqual(samples.dtype, tf.float32)
+ self.assertEqual(samples.dtype, dtypes.float32)
self.assertEqual((4,), samples.get_shape())
cat_samples = dist.cat.sample_n(n, seed=123)
sample_values, cat_sample_values, dist_sample_values = sess.run(
@@ -344,17 +371,23 @@ class MixtureTest(tf.test.TestCase):
with self.test_session():
n = 100
- tf.set_random_seed(654321)
- components = [distributions_py.Normal(
- mu=mu, sigma=sigma) for mu, sigma in zip(mus, sigmas)]
- cat = distributions_py.Categorical(logits, dtype=tf.int32, name="cat1")
+ random_seed.set_random_seed(654321)
+ components = [
+ distributions_py.Normal(
+ mu=mu, sigma=sigma) for mu, sigma in zip(mus, sigmas)
+ ]
+ cat = distributions_py.Categorical(
+ logits, dtype=dtypes.int32, name="cat1")
dist1 = distributions_py.Mixture(cat, components, name="mixture1")
samples1 = dist1.sample_n(n, seed=123456).eval()
- tf.set_random_seed(654321)
- components2 = [distributions_py.Normal(
- mu=mu, sigma=sigma) for mu, sigma in zip(mus, sigmas)]
- cat2 = distributions_py.Categorical(logits, dtype=tf.int32, name="cat2")
+ random_seed.set_random_seed(654321)
+ components2 = [
+ distributions_py.Normal(
+ mu=mu, sigma=sigma) for mu, sigma in zip(mus, sigmas)
+ ]
+ cat2 = distributions_py.Categorical(
+ logits, dtype=dtypes.int32, name="cat2")
dist2 = distributions_py.Mixture(cat2, components2, name="mixture2")
samples2 = dist2.sample_n(n, seed=123456).eval()
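
Reproducibility here rests on TensorFlow's two-level seeding: `set_random_seed` fixes the graph-level seed, and the explicit `seed=123456` fixes the op-level seed, so the two independently built pipelines draw identical randomness. In sketch form (same 1.x-era API):

    random_seed.set_random_seed(654321)    # graph-level seed
    s1 = dist1.sample_n(n, seed=123456)    # op-level seed
    random_seed.set_random_seed(654321)    # reset before building the second path
    s2 = dist2.sample_n(n, seed=123456)
    # With both seed levels fixed, s1 and s2 evaluate to the same values.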
@@ -368,7 +401,7 @@ class MixtureTest(tf.test.TestCase):
n = 4
with _test_capture_mvndiag_sample_outputs() as component_samples:
samples = dist.sample_n(n, seed=123)
- self.assertEqual(samples.dtype, tf.float32)
+ self.assertEqual(samples.dtype, dtypes.float32)
self.assertEqual((4, 2), samples.get_shape())
cat_samples = dist.cat.sample_n(n, seed=123)
sample_values, cat_sample_values, dist_sample_values = sess.run(
@@ -389,7 +422,7 @@ class MixtureTest(tf.test.TestCase):
n = 4
with _test_capture_normal_sample_outputs() as component_samples:
samples = dist.sample_n(n, seed=123)
- self.assertEqual(samples.dtype, tf.float32)
+ self.assertEqual(samples.dtype, dtypes.float32)
self.assertEqual((4, 2, 3), samples.get_shape())
cat_samples = dist.cat.sample_n(n, seed=123)
sample_values, cat_sample_values, dist_sample_values = sess.run(
@@ -412,7 +445,7 @@ class MixtureTest(tf.test.TestCase):
n = 5
with _test_capture_mvndiag_sample_outputs() as component_samples:
samples = dist.sample_n(n, seed=123)
- self.assertEqual(samples.dtype, tf.float32)
+ self.assertEqual(samples.dtype, dtypes.float32)
self.assertEqual((5, 2, 3, 4), samples.get_shape())
cat_samples = dist.cat.sample_n(n, seed=123)
sample_values, cat_sample_values, dist_sample_values = sess.run(
@@ -436,7 +469,7 @@ class MixtureTest(tf.test.TestCase):
entropy_lower_bound = dist.entropy_lower_bound()
self.assertEqual(batch_shape, entropy_lower_bound.get_shape())
- cat_probs = tf.nn.softmax(dist.cat.logits)
+ cat_probs = nn_ops.softmax(dist.cat.logits)
dist_entropy = [d.entropy() for d in dist.components]
entropy_lower_bound_value, cat_probs_value, dist_entropy_value = (
@@ -453,32 +486,33 @@ class MixtureTest(tf.test.TestCase):
self.assertAllClose(true_entropy_lower_bound, entropy_lower_bound_value)
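
The bound being verified is Jensen's inequality H(mixture) >= sum_k pi_k * H(p_k); for a Normal component, H = 0.5 * log(2 * pi * e * sigma^2). A NumPy sketch under those assumptions (illustrative values):

    import numpy as np

    cat_probs = np.array([0.3, 0.7])
    sigmas = np.array([1.0, 2.5])
    component_entropies = 0.5 * np.log(2 * np.pi * np.e * sigmas ** 2)
    entropy_lower_bound = np.sum(cat_probs * component_entropies)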
-class MixtureBenchmark(tf.test.Benchmark):
+class MixtureBenchmark(test.Benchmark):
- def _runSamplingBenchmark(self, name,
- create_distribution, use_gpu, num_components,
- batch_size, num_features, sample_size):
- config = tf.ConfigProto()
+ def _runSamplingBenchmark(self, name, create_distribution, use_gpu,
+ num_components, batch_size, num_features,
+ sample_size):
+ config = config_pb2.ConfigProto()
config.allow_soft_placement = True
np.random.seed(127)
- with tf.Session(config=config, graph=tf.Graph()) as sess:
- tf.set_random_seed(0)
- with tf.device("/gpu:0" if use_gpu else "/cpu:0"):
+ with session.Session(config=config, graph=ops.Graph()) as sess:
+ random_seed.set_random_seed(0)
+ with ops.device("/gpu:0" if use_gpu else "/cpu:0"):
mixture = create_distribution(
num_components=num_components,
batch_size=batch_size,
num_features=num_features)
sample_op = mixture.sample(sample_size).op
- sess.run(tf.global_variables_initializer())
+ sess.run(variables.global_variables_initializer())
reported = self.run_op_benchmark(
- sess, sample_op,
+ sess,
+ sample_op,
min_iters=10,
- name=("%s_%s_components_%d_batch_%d_features_%d_sample_%d"
- % (name, use_gpu, num_components,
- batch_size, num_features, sample_size)))
- print("\t".join(["%s", "%d", "%d", "%d", "%d", "%g"])
- % (use_gpu, num_components, batch_size,
- num_features, sample_size, reported["wall_time"]))
+ name=("%s_%s_components_%d_batch_%d_features_%d_sample_%d" %
+ (name, use_gpu, num_components, batch_size, num_features,
+ sample_size)))
+ print("\t".join(["%s", "%d", "%d", "%d", "%d", "%g"]) %
+ (use_gpu, num_components, batch_size, num_features, sample_size,
+ reported["wall_time"]))
def benchmarkSamplingMVNDiag(self):
print("mvn_diag\tuse_gpu\tcomponents\tbatch\tfeatures\tsample\twall_time")
@@ -487,25 +521,28 @@ class MixtureBenchmark(tf.test.Benchmark):
cat = distributions_py.Categorical(
logits=np.random.randn(batch_size, num_components))
mus = [
- tf.Variable(np.random.randn(batch_size, num_features))
- for _ in range(num_components)]
+ variables.Variable(np.random.randn(batch_size, num_features))
+ for _ in range(num_components)
+ ]
sigmas = [
- tf.Variable(np.random.rand(batch_size, num_features))
- for _ in range(num_components)]
+ variables.Variable(np.random.rand(batch_size, num_features))
+ for _ in range(num_components)
+ ]
components = list(
- distributions_py.MultivariateNormalDiag(mu=mu, diag_stdev=sigma)
- for (mu, sigma) in zip(mus, sigmas))
+ distributions_py.MultivariateNormalDiag(
+ mu=mu, diag_stdev=sigma) for (mu, sigma) in zip(mus, sigmas))
return distributions_py.Mixture(cat, components)
for use_gpu in False, True:
- if use_gpu and not tf.test.is_gpu_available():
+ if use_gpu and not test.is_gpu_available():
continue
for num_components in 1, 8, 16:
for batch_size in 1, 32:
for num_features in 1, 64, 512:
for sample_size in 1, 32, 128:
self._runSamplingBenchmark(
- "mvn_diag", create_distribution=create_distribution,
+ "mvn_diag",
+ create_distribution=create_distribution,
use_gpu=use_gpu,
num_components=num_components,
batch_size=batch_size,
@@ -523,26 +560,29 @@ class MixtureBenchmark(tf.test.Benchmark):
cat = distributions_py.Categorical(
logits=np.random.randn(batch_size, num_components))
mus = [
- tf.Variable(np.random.randn(batch_size, num_features))
- for _ in range(num_components)]
+ variables.Variable(np.random.randn(batch_size, num_features))
+ for _ in range(num_components)
+ ]
sigmas = [
- tf.Variable(
+ variables.Variable(
psd(np.random.rand(batch_size, num_features, num_features)))
- for _ in range(num_components)]
+ for _ in range(num_components)
+ ]
components = list(
- distributions_py.MultivariateNormalFull(mu=mu, sigma=sigma)
- for (mu, sigma) in zip(mus, sigmas))
+ distributions_py.MultivariateNormalFull(
+ mu=mu, sigma=sigma) for (mu, sigma) in zip(mus, sigmas))
return distributions_py.Mixture(cat, components)
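
`psd` is a helper defined elsewhere in this file (outside the hunks shown); it must turn each random matrix in the batch into a valid covariance for `MultivariateNormalFull`. A hypothetical implementation of such a helper:

    def psd(x):
      # x @ x^T per batch element is symmetric positive semi-definite.
      return np.matmul(x, np.swapaxes(x, -1, -2))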
for use_gpu in False, True:
- if use_gpu and not tf.test.is_gpu_available():
+ if use_gpu and not test.is_gpu_available():
continue
for num_components in 1, 8, 16:
for batch_size in 1, 32:
for num_features in 1, 64, 512:
for sample_size in 1, 32, 128:
self._runSamplingBenchmark(
- "mvn_full", create_distribution=create_distribution,
+ "mvn_full",
+ create_distribution=create_distribution,
use_gpu=use_gpu,
num_components=num_components,
batch_size=batch_size,
@@ -551,4 +591,4 @@ class MixtureBenchmark(tf.test.Benchmark):
if __name__ == "__main__":
- tf.test.main()
+ test.main()