From 4e96e274443805df8afad5cb48f654fbf1776a4a Mon Sep 17 00:00:00 2001
From: Eugene Brevdo
Date: Thu, 15 Sep 2016 19:22:14 -0800
Subject: Add a set of benchmarks for sampling from tf.contrib.distributions.Mixture

* Two benchmarks: one for MVNDiag (lightweight per-distribution
  construction) and one for MVNFull (complex per-distribution construction).

* Minor tweaks to benchmark platform code to add logging and to return
  benchmark statistics.

Change: 133341455
---
 .../python/kernel_tests/mixture_test.py          | 99 ++++++++++++++++++++++
 .../contrib/distributions/python/ops/mixture.py  | 19 ++---
 tensorflow/python/platform/benchmark.py          | 25 ++++--
 3 files changed, 125 insertions(+), 18 deletions(-)

diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
index 04cccdf073..0de7744f15 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
@@ -446,5 +446,104 @@ class MixtureTest(tf.test.TestCase):
     self.assertAllClose(true_entropy_lower_bound,
                         entropy_lower_bound_value)
 
+class MixtureBenchmark(tf.test.Benchmark):
+
+  def _runSamplingBenchmark(self, name,
+                            create_distribution, use_gpu, num_components,
+                            batch_size, num_features, sample_size):
+    config = tf.ConfigProto()
+    config.allow_soft_placement = True
+    np.random.seed(127)
+    with tf.Session(config=config, graph=tf.Graph()) as sess:
+      tf.set_random_seed(0)
+      with tf.device("/gpu:0" if use_gpu else "/cpu:0"):
+        mixture = create_distribution(
+            num_components=num_components,
+            batch_size=batch_size,
+            num_features=num_features)
+        sample_op = mixture.sample(sample_size).op
+        sess.run(tf.initialize_all_variables())
+        reported = self.run_op_benchmark(
+            sess, sample_op,
+            min_iters=10,
+            name=("%s_%s_components_%d_batch_%d_features_%d_sample_%d"
+                  % (name, use_gpu, num_components,
+                     batch_size, num_features, sample_size)))
+        print("\t".join(["%s", "%s", "%d", "%d", "%d", "%d", "%g"])
+              % (name, use_gpu, num_components, batch_size,
+                 num_features, sample_size, reported["wall_time"]))
+
+  def benchmarkSamplingMVNDiag(self):
+    print("mvn_diag\tuse_gpu\tcomponents\tbatch\tfeatures\tsample\twall_time")
+
+    def create_distribution(batch_size, num_components, num_features):
+      cat = distributions_py.Categorical(
+          logits=np.random.randn(batch_size, num_components))
+      mus = [
+          tf.Variable(np.random.randn(batch_size, num_features))
+          for _ in range(num_components)]
+      sigmas = [
+          tf.Variable(np.random.rand(batch_size, num_features))
+          for _ in range(num_components)]
+      components = list(
+          (distributions_py.MultivariateNormalDiag,
+           {"mu": mu, "diag_stdev": sigma})
+          for (mu, sigma) in zip(mus, sigmas))
+      return distributions_py.Mixture(cat, components)
+
+    for use_gpu in False, True:
+      if use_gpu and not tf.test.is_gpu_available():
+        continue
+      for num_components in 1, 8, 16:
+        for batch_size in 1, 32:
+          for num_features in 1, 64, 512:
+            for sample_size in 1, 32, 128:
+              self._runSamplingBenchmark(
+                  "mvn_diag", create_distribution=create_distribution,
+                  use_gpu=use_gpu,
+                  num_components=num_components,
+                  batch_size=batch_size,
+                  num_features=num_features,
+                  sample_size=sample_size)
+
+  def benchmarkSamplingMVNFull(self):
+    print("mvn_full\tuse_gpu\tcomponents\tbatch\tfeatures\tsample\twall_time")
+
+    def psd(x):
+      """Construct batch-wise PSD matrices."""
+      return np.stack([np.dot(np.transpose(z), z) for z in x])
+
+    def create_distribution(batch_size, num_components, num_features):
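+      # Same construction as the MVNDiag factory above, except each
+      # component receives a full covariance matrix: psd() maps random
+      # matrices to a batch of PSD covariances.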
+      cat = distributions_py.Categorical(
+          logits=np.random.randn(batch_size, num_components))
+      mus = [
+          tf.Variable(np.random.randn(batch_size, num_features))
+          for _ in range(num_components)]
+      sigmas = [
+          tf.Variable(
+              psd(np.random.rand(batch_size, num_features, num_features)))
+          for _ in range(num_components)]
+      components = list(
+          (distributions_py.MultivariateNormalFull,
+           {"mu": mu, "sigma": sigma})
+          for (mu, sigma) in zip(mus, sigmas))
+      return distributions_py.Mixture(cat, components)
+
+    for use_gpu in False, True:
+      if use_gpu and not tf.test.is_gpu_available():
+        continue
+      for num_components in 1, 8, 16:
+        for batch_size in 1, 32:
+          for num_features in 1, 64, 512:
+            for sample_size in 1, 32, 128:
+              self._runSamplingBenchmark(
+                  "mvn_full", create_distribution=create_distribution,
+                  use_gpu=use_gpu,
+                  num_components=num_components,
+                  batch_size=batch_size,
+                  num_features=num_features,
+                  sample_size=sample_size)
+
+
 if __name__ == "__main__":
   tf.test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py
index ad1e99d08e..add31a5dd8 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture.py
@@ -371,19 +371,12 @@ class Mixture(distribution.Distribution):
         lookup_partitioned_batch_indices = (
             batch_size * math_ops.range(n_class) +
             partitioned_batch_indices[c])
-
-        # Try to avoid a reshape to make the sample + batch one
-        # row (for array_ops.gather). This can be done only when
-        # the batch shape is known and is rank 1.
-        if static_batch_shape.ndims == 1:
-          samples_class_c = array_ops.gather(
-              samples_class_c, lookup_partitioned_batch_indices)
-        else:
-          samples_class_c = array_ops.reshape(
-              samples_class_c,
-              array_ops.concat(0, ([n_class * batch_size], event_shape)))
-          samples_class_c = array_ops.gather(
-              samples_class_c, lookup_partitioned_batch_indices)
+        samples_class_c = array_ops.reshape(
+            samples_class_c,
+            array_ops.concat(0, ([n_class * batch_size], event_shape)))
+        samples_class_c = array_ops.gather(
+            samples_class_c, lookup_partitioned_batch_indices,
+            name="samples_class_c_gather")
         samples_class[c] = samples_class_c
 
     # Stitch back together the samples across the components.
diff --git a/tensorflow/python/platform/benchmark.py b/tensorflow/python/platform/benchmark.py
index 23c03c38b1..949b64f517 100644
--- a/tensorflow/python/platform/benchmark.py
+++ b/tensorflow/python/platform/benchmark.py
@@ -34,6 +34,7 @@ from tensorflow.core.util import test_log_pb2
 from tensorflow.python.client import timeline
 from tensorflow.python.platform import app
 from tensorflow.python.platform import gfile
+from tensorflow.python.platform import tf_logging as logging
 
 # When a subclass of the Benchmark class is created, it is added to
 # the registry automatically
@@ -65,6 +66,15 @@ def _global_report_benchmark(
   if not isinstance(extras, dict):
     raise TypeError("extras must be a dict")
 
+  logging.info(
+      "Benchmark [%s] iters: %d, wall_time: %g, cpu_time: %g, "
+      "throughput: %g" %
+      (name,
+       iters if iters is not None else -1,
+       wall_time if wall_time is not None else -1,
+       cpu_time if cpu_time is not None else -1,
+       throughput if throughput is not None else -1))
+
   test_env = os.environ.get(TEST_REPORTER_TEST_ENV, None)
   if test_env is None:
     # Reporting was not requested
@@ -209,6 +219,10 @@ class TensorFlowBenchmark(Benchmark):
         Otherwise it is inferred from the top-level method name.
       extras: (optional) Dict mapping string keys to additional benchmark info.
         Values may be either floats or values that are convertible to strings.
+
+    Returns:
+      A `dict` containing the key-value pairs that were passed to
+      `report_benchmark`.
     """
     for _ in range(burn_iters):
       sess.run(op_or_tensor, feed_dict=feed_dict)
@@ -242,11 +256,12 @@ class TensorFlowBenchmark(Benchmark):
 
     median_delta = _median(deltas)
 
-    self.report_benchmark(
-        iters=min_iters,
-        wall_time=median_delta,
-        extras=extras,
-        name=name)
+    benchmark_values = {"iters": min_iters,
+                        "wall_time": median_delta,
+                        "extras": extras,
+                        "name": name}
+    self.report_benchmark(**benchmark_values)
+    return benchmark_values
 
 
 def _run_benchmarks(regex):
-- 
cgit v1.2.3
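
Editor's note on the platform change: `run_op_benchmark` now returns the
statistics it reports, so a benchmark can consume its own timings directly,
as the Mixture benchmarks above do with reported["wall_time"]. A minimal
sketch of a caller is below; the class, graph, sizes, and the name
"matmul_1024" are illustrative assumptions, not part of this patch:

    import tensorflow as tf


    class MatMulSampleBenchmark(tf.test.Benchmark):
      """Hypothetical benchmark illustrating the new return value."""

      def benchmarkMatMul(self):
        with tf.Session(graph=tf.Graph()) as sess:
          x = tf.random_normal([1024, 1024])
          prod = tf.matmul(x, x)
          # run_op_benchmark now returns the dict it forwarded to
          # report_benchmark: "iters", "wall_time", "extras", and "name".
          stats = self.run_op_benchmark(
              sess, prod.op, min_iters=10, name="matmul_1024")
          print("matmul_1024: wall_time %g s over %d iters"
                % (stats["wall_time"], stats["iters"]))


    if __name__ == "__main__":
      tf.test.main()

Benchmarks are selected at runtime via the --benchmarks flag handled in
benchmark.py, so the sketch above would typically be run as
`python matmul_benchmark.py --benchmarks=MatMulSampleBenchmark`.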