aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: Eugene Brevdo <ebrevdo@google.com> 2016-09-15 19:22:14 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2016-09-15 20:32:51 -0700
commit: 4e96e274443805df8afad5cb48f654fbf1776a4a (patch)
tree: 4bf89ab642c10f1d59b93e62cd990857bde2f6dd
parent: e5b7e1e846f3e35d90a6bb260284b041d0036059 (diff)
Add set of benchmarks for sampling from tf.contrib.distributions.Mixture
* Two benchmarks: one for MVNDiag (lightweight per-distribution construction) and one for MVNFull (complex per-distribution construction).
* Minor tweaks to benchmark platform code to add logging and return benchmark statistics.
Change: 133341455
-rw-r--r-- tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py | 99
-rw-r--r-- tensorflow/contrib/distributions/python/ops/mixture.py | 19
-rw-r--r-- tensorflow/python/platform/benchmark.py | 25
3 files changed, 125 insertions, 18 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
index 04cccdf073..0de7744f15 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
@@ -446,5 +446,104 @@ class MixtureTest(tf.test.TestCase):
self.assertAllClose(true_entropy_lower_bound, entropy_lower_bound_value)
+class MixtureBenchmark(tf.test.Benchmark):
+  """Benchmarks for sampling from `distributions_py.Mixture`."""
+
+  def _runSamplingBenchmark(self, name,
+                            create_distribution, use_gpu, num_components,
+                            batch_size, num_features, sample_size):
+    """Times one `Mixture.sample` op and prints a tab-separated result row.
+
+    Args:
+      name: Benchmark family name. Used as the prefix of the reported
+        benchmark name and as the first column of the printed row.
+      create_distribution: Callable accepting the keyword arguments
+        `num_components`, `batch_size` and `num_features`, and returning
+        the distribution whose `sample` op is timed.
+      use_gpu: If true the graph is built under "/gpu:0", else "/cpu:0".
+      num_components: Forwarded to `create_distribution`.
+      batch_size: Forwarded to `create_distribution`.
+      num_features: Forwarded to `create_distribution`.
+      sample_size: Argument passed to `sample` on the created distribution.
+    """
+    config = tf.ConfigProto()
+    config.allow_soft_placement = True
+    # Re-seed NumPy before each build; `create_distribution` draws its
+    # parameters from `np.random`.
+    np.random.seed(127)
+    with tf.Session(config=config, graph=tf.Graph()) as sess:
+      tf.set_random_seed(0)
+      with tf.device("/gpu:0" if use_gpu else "/cpu:0"):
+        mixture = create_distribution(
+            num_components=num_components,
+            batch_size=batch_size,
+            num_features=num_features)
+        sample_op = mixture.sample(sample_size).op
+        sess.run(tf.initialize_all_variables())
+        reported = self.run_op_benchmark(
+            sess, sample_op,
+            min_iters=10,
+            name=("%s_%s_components_%d_batch_%d_features_%d_sample_%d"
+                  % (name, use_gpu, num_components,
+                     batch_size, num_features, sample_size)))
+        # One value per header column; `name` comes first so the row lines
+        # up with the seven-column header printed by `_runSamplingSweep`.
+        print("%s\t%s\t%d\t%d\t%d\t%d\t%g"
+              % (name, use_gpu, num_components, batch_size,
+                 num_features, sample_size, reported["wall_time"]))
+
+  def _runSamplingSweep(self, name, create_distribution):
+    """Prints the table header and benchmarks every parameter combination.
+
+    Sweeps device (CPU, plus GPU only when `tf.test.is_gpu_available()`),
+    num_components in (1, 8, 16), batch_size in (1, 32), num_features in
+    (1, 64, 512) and sample_size in (1, 32, 128).
+
+    Args:
+      name: Benchmark family name; also the first header column.
+      create_distribution: Factory passed through to
+        `_runSamplingBenchmark`.
+    """
+    print("%s\tuse_gpu\tcomponents\tbatch\tfeatures\tsample\twall_time" % name)
+    for use_gpu in False, True:
+      if use_gpu and not tf.test.is_gpu_available():
+        continue
+      for num_components in 1, 8, 16:
+        for batch_size in 1, 32:
+          for num_features in 1, 64, 512:
+            for sample_size in 1, 32, 128:
+              self._runSamplingBenchmark(
+                  name, create_distribution=create_distribution,
+                  use_gpu=use_gpu,
+                  num_components=num_components,
+                  batch_size=batch_size,
+                  num_features=num_features,
+                  sample_size=sample_size)
+
+  def benchmarkSamplingMVNDiag(self):
+    """Benchmarks sampling from a mixture of `MultivariateNormalDiag`."""
+
+    def create_distribution(batch_size, num_components, num_features):
+      cat = distributions_py.Categorical(
+          logits=np.random.randn(batch_size, num_components))
+      mus = [
+          tf.Variable(np.random.randn(batch_size, num_features))
+          for _ in range(num_components)]
+      sigmas = [
+          tf.Variable(np.random.rand(batch_size, num_features))
+          for _ in range(num_components)]
+      components = [
+          (distributions_py.MultivariateNormalDiag,
+           {"mu": mu, "diag_stdev": sigma})
+          for (mu, sigma) in zip(mus, sigmas)]
+      return distributions_py.Mixture(cat, components)
+
+    self._runSamplingSweep("mvn_diag", create_distribution)
+
+  def benchmarkSamplingMVNFull(self):
+    """Benchmarks sampling from a mixture of `MultivariateNormalFull`."""
+
+    def psd(x):
+      """Construct batch-wise PSD matrices."""
+      return np.stack([np.dot(np.transpose(z), z) for z in x])
+
+    def create_distribution(batch_size, num_components, num_features):
+      cat = distributions_py.Categorical(
+          logits=np.random.randn(batch_size, num_components))
+      mus = [
+          tf.Variable(np.random.randn(batch_size, num_features))
+          for _ in range(num_components)]
+      sigmas = [
+          tf.Variable(
+              psd(np.random.rand(batch_size, num_features, num_features)))
+          for _ in range(num_components)]
+      components = [
+          (distributions_py.MultivariateNormalFull,
+           {"mu": mu, "sigma": sigma})
+          for (mu, sigma) in zip(mus, sigmas)]
+      return distributions_py.Mixture(cat, components)
+
+    self._runSamplingSweep("mvn_full", create_distribution)
+
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py
index ad1e99d08e..add31a5dd8 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture.py
@@ -371,19 +371,12 @@ class Mixture(distribution.Distribution):
lookup_partitioned_batch_indices = (
batch_size * math_ops.range(n_class) +
partitioned_batch_indices[c])
-
- # Try to avoid a reshape to make the sample + batch one
- # row (for array_ops.gather). This can be done only when
- # the batch shape is known and is rank 1.
- if static_batch_shape.ndims == 1:
- samples_class_c = array_ops.gather(
- samples_class_c, lookup_partitioned_batch_indices)
- else:
- samples_class_c = array_ops.reshape(
- samples_class_c,
- array_ops.concat(0, ([n_class * batch_size], event_shape)))
- samples_class_c = array_ops.gather(
- samples_class_c, lookup_partitioned_batch_indices)
+ samples_class_c = array_ops.reshape(
+ samples_class_c,
+ array_ops.concat(0, ([n_class * batch_size], event_shape)))
+ samples_class_c = array_ops.gather(
+ samples_class_c, lookup_partitioned_batch_indices,
+ name="samples_class_c_gather")
samples_class[c] = samples_class_c
# Stitch back together the samples across the components.
diff --git a/tensorflow/python/platform/benchmark.py b/tensorflow/python/platform/benchmark.py
index 23c03c38b1..949b64f517 100644
--- a/tensorflow/python/platform/benchmark.py
+++ b/tensorflow/python/platform/benchmark.py
@@ -34,6 +34,7 @@ from tensorflow.core.util import test_log_pb2
from tensorflow.python.client import timeline
from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
+from tensorflow.python.platform import tf_logging as logging
# When a subclass of the Benchmark class is created, it is added to
# the registry automatically
@@ -65,6 +66,15 @@ def _global_report_benchmark(
if not isinstance(extras, dict):
raise TypeError("extras must be a dict")
+  # Log the reported values before the reporter-environment check that
+  # follows. Missing (None) numeric fields are logged as -1.
+  logging.info(
+      "Benchmark [%s] iters: %d, wall_time: %g, cpu_time: %g, "
+      "throughput: %g",
+      name,
+      iters if iters is not None else -1,
+      wall_time if wall_time is not None else -1,
+      cpu_time if cpu_time is not None else -1,
+      throughput if throughput is not None else -1)
+
test_env = os.environ.get(TEST_REPORTER_TEST_ENV, None)
if test_env is None:
# Reporting was not requested
@@ -209,6 +219,10 @@ class TensorFlowBenchmark(Benchmark):
Otherwise it is inferred from the top-level method name.
extras: (optional) Dict mapping string keys to additional benchmark info.
Values may be either floats or values that are convertible to strings.
+
+ Returns:
+ A `dict` containing the key-value pairs that were passed to
+ `report_benchmark`.
"""
for _ in range(burn_iters):
sess.run(op_or_tensor, feed_dict=feed_dict)
@@ -242,11 +256,12 @@ class TensorFlowBenchmark(Benchmark):
median_delta = _median(deltas)
- self.report_benchmark(
- iters=min_iters,
- wall_time=median_delta,
- extras=extras,
- name=name)
+ benchmark_values = {"iters": min_iters,
+ "wall_time": median_delta,
+ "extras": extras,
+ "name": name}
+ self.report_benchmark(**benchmark_values)
+ return benchmark_values
def _run_benchmarks(regex):