author    Eugene Brevdo <ebrevdo@gmail.com>  2016-03-18 08:58:46 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-03-18 15:44:48 -0700
commit    e9bc4eff06d6b5bd212788d6e4a68c7d104f33d6 (patch)
tree      9a24c93d46e1cdb392bf546af45dc8c029ba0a7b
parent    6f192656a623e1008ff4ab92c1c3d67751003e65 (diff)
Add TensorFlow python Benchmark base class, registration mechanism, and test runner.
Outputs proto strings in a way similar to the reporter.cc in tensorflow/core/util/.

Change: 117556944
-rw-r--r--  tensorflow/python/BUILD                            |   2
-rw-r--r--  tensorflow/python/framework/test_util.py           |   5
-rw-r--r--  tensorflow/python/kernel_tests/benchmark_test.py   | 182
-rw-r--r--  tensorflow/python/platform/benchmark.py            | 203
-rw-r--r--  tensorflow/python/platform/googletest.py           |  15
-rw-r--r--  tensorflow/python/platform/test.py                 |   4
6 files changed, 408 insertions, 3 deletions
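
For orientation, here is a minimal usage sketch of the API this commit introduces. The MyMatmulBenchmark class, the matrix size, and the timing loop are hypothetical; tf.test.Benchmark, report_benchmark, and the --benchmarks handling in the test runner are the additions shown in the diffs below.

# Hypothetical example of the new benchmark API (not part of this commit).
# Subclassing tf.test.Benchmark registers the class in the global registry;
# methods whose names start with "benchmark" are run by the test runner when
# a matching --benchmarks=<regex> flag is passed.
import time

import tensorflow as tf


class MyMatmulBenchmark(tf.test.Benchmark):  # hypothetical class name

  def benchmarkMatmul(self):
    iters = 10
    with tf.Session() as sess:
      x = tf.random_uniform([512, 512])
      prod = tf.matmul(x, x)
      start = time.time()
      for _ in range(iters):
        sess.run(prod)
      # Writes a BenchmarkEntry text proto only if TEST_REPORT_FILE_PREFIX is set.
      self.report_benchmark(iters=iters, wall_time=time.time() - start,
                            extras={"matrix_size": 512})


if __name__ == "__main__":
  tf.test.main()

Running this file with --benchmarks=MyMatmulBenchmark (and TEST_REPORT_FILE_PREFIX set) would execute the benchmark method and write a BenchmarkEntry text proto instead of running unit tests.
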
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 370868dfc5..8a508bcd11 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -40,6 +40,7 @@ py_library(
name = "platform",
srcs = glob(["platform/**/*.py"]),
srcs_version = "PY2AND3",
+ deps = ["//tensorflow/core:protos_all_py"],
)
py_library(
@@ -1034,6 +1035,7 @@ cpu_only_kernel_test_list = glob([
"kernel_tests/attention_ops_test.py",
"kernel_tests/barrier_ops_test.py",
"kernel_tests/bcast_ops_test.py",
+ "kernel_tests/benchmark_test.py",
"kernel_tests/candidate_sampler_ops_test.py",
"kernel_tests/cholesky_op_test.py",
"kernel_tests/clip_ops_test.py",
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 284db94d45..17f21d56af 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -165,9 +165,8 @@ class TensorFlowTestCase(googletest.TestCase):
text_format.Merge(expected_message_maybe_ascii, expected_message)
self._AssertProtoEquals(expected_message, message)
else:
- assert False, ("Can't compare protos of type " +
- type(expected_message_maybe_ascii) + " and " +
- type(message))
+ assert False, ("Can't compare protos of type %s and %s" %
+ (type(expected_message_maybe_ascii), type(message)))
def assertProtoEqualsVersion(
self, expected, actual, producer=versions.GRAPH_DEF_VERSION,
diff --git a/tensorflow/python/kernel_tests/benchmark_test.py b/tensorflow/python/kernel_tests/benchmark_test.py
new file mode 100644
index 0000000000..181be9fffc
--- /dev/null
+++ b/tensorflow/python/kernel_tests/benchmark_test.py
@@ -0,0 +1,182 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for tensorflow.python.framework.importer."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import random
+
+import tensorflow as tf
+
+from google.protobuf import text_format
+from tensorflow.core.util import test_log_pb2
+from tensorflow.python.platform import benchmark
+
+
+# Used by SomeRandomBenchmark class below.
+_ran_somebenchmark_1 = [False]
+_ran_somebenchmark_2 = [False]
+_ran_somebenchmark_but_shouldnt = [False]
+
+
+class SomeRandomBenchmark(tf.test.Benchmark):
+ """This Benchmark should automatically be registered in the registry."""
+
+ def _dontRunThisBenchmark(self):
+ _ran_somebenchmark_but_shouldnt[0] = True
+
+ def notBenchmarkMethod(self):
+ _ran_somebenchmark_but_shouldnt[0] = True
+
+ def benchmark1(self):
+ _ran_somebenchmark_1[0] = True
+
+ def benchmark2(self):
+ _ran_somebenchmark_2[0] = True
+
+
+class TestReportingBenchmark(tf.test.Benchmark):
+ """This benchmark (maybe) reports some stuff."""
+
+ def benchmarkReport1(self):
+ self.report_benchmark(iters=1)
+
+ def benchmarkReport2(self):
+ self.report_benchmark(
+ iters=2, name="custom_benchmark_name",
+ extras={"number_key": 3, "other_key": "string"})
+
+
+class BenchmarkTest(tf.test.TestCase):
+
+ def testGlobalBenchmarkRegistry(self):
+ registry = list(benchmark.GLOBAL_BENCHMARK_REGISTRY)
+ self.assertEqual(len(registry), 2)
+ self.assertTrue(SomeRandomBenchmark in registry)
+ self.assertTrue(TestReportingBenchmark in registry)
+
+ def testRunSomeRandomBenchmark(self):
+ # Validate that SomeRandomBenchmark has not run yet.
+ self.assertFalse(_ran_somebenchmark_1[0])
+ self.assertFalse(_ran_somebenchmark_2[0])
+ self.assertFalse(_ran_somebenchmark_but_shouldnt[0])
+
+ # Don't run any benchmarks.
+ (exit_early, args, kwargs) = benchmark.run_benchmarks(
+ ["--flag1", "--flag3"], {"a": 3})
+
+ self.assertEqual(exit_early, False)
+ self.assertEqual(args, ["--flag1", "--flag3"])
+ self.assertEqual(kwargs, {"a": 3})
+
+ # Validate that SomeRandomBenchmark has not run yet.
+ self.assertFalse(_ran_somebenchmark_1[0])
+ self.assertFalse(_ran_somebenchmark_2[0])
+ self.assertFalse(_ran_somebenchmark_but_shouldnt[0])
+
+ # Run other benchmarks, but this won't run the one we care about.
+ (exit_early, args, kwargs) = benchmark.run_benchmarks(
+ ["--flag1", "--benchmarks=__unrelated__", "--flag3"], {"a": 3})
+
+ self.assertEqual(exit_early, True)
+ self.assertEqual(args, ["--flag1", "--flag3"])
+ self.assertEqual(kwargs, {"a": 3})
+
+ # Validate that SomeRandomBenchmark has not run yet.
+ self.assertFalse(_ran_somebenchmark_1[0])
+ self.assertFalse(_ran_somebenchmark_2[0])
+ self.assertFalse(_ran_somebenchmark_but_shouldnt[0])
+
+ # Run all the benchmarks without generating any reports.
+ if benchmark.TEST_REPORTER_TEST_ENV in os.environ:
+ del os.environ[benchmark.TEST_REPORTER_TEST_ENV]
+ (exit_early, args, kwargs) = benchmark.run_benchmarks(
+ ["--flag1", "--benchmarks=.", "--flag3"], {"a": 3})
+
+ # Validate the output of run_benchmarks
+ self.assertEqual(exit_early, True)
+ self.assertEqual(args, ["--flag1", "--flag3"])
+ self.assertEqual(kwargs, {"a": 3})
+
+ # Validate that SomeRandomBenchmark ran correctly
+ self.assertTrue(_ran_somebenchmark_1[0])
+ self.assertTrue(_ran_somebenchmark_2[0])
+ self.assertFalse(_ran_somebenchmark_but_shouldnt[0])
+
+ def testReportingBenchmark(self):
+ tempdir = tf.test.get_temp_dir()
+ try:
+ tf.gfile.MakeDirs(tempdir)
+ except OSError as e:
+ # It's OK if the directory already exists.
+ if " exists:" not in str(e):
+ raise e
+
+ prefix = os.path.join(
+ tempdir, "reporting_bench_%016x_" % random.getrandbits(64))
+ expected_output_file = "%s%s" % (
+ prefix, "TestReportingBenchmark.benchmarkReport1")
+ expected_output_file_2 = "%s%s" % (
+ prefix, "TestReportingBenchmark.custom_benchmark_name")
+ try:
+ self.assertFalse(tf.gfile.Exists(expected_output_file))
+ # Run benchmark without the env var set; it shouldn't write anything.
+ if benchmark.TEST_REPORTER_TEST_ENV in os.environ:
+ del os.environ[benchmark.TEST_REPORTER_TEST_ENV]
+ reporting = TestReportingBenchmark()
+ reporting.benchmarkReport1() # This should run without writing anything
+ self.assertFalse(tf.gfile.Exists(expected_output_file))
+
+ # Run benchmark with the env var set; it should write.
+ os.environ[benchmark.TEST_REPORTER_TEST_ENV] = prefix
+
+ reporting = TestReportingBenchmark()
+ reporting.benchmarkReport1() # This should write
+ reporting.benchmarkReport2() # This should write
+
+ # Check the files were written
+ self.assertTrue(tf.gfile.Exists(expected_output_file))
+ self.assertTrue(tf.gfile.Exists(expected_output_file_2))
+
+ # Check the contents are correct
+ expected_1 = test_log_pb2.BenchmarkEntry()
+ expected_1.name = "TestReportingBenchmark.benchmarkReport1"
+ expected_1.iters = 1
+
+ expected_2 = test_log_pb2.BenchmarkEntry()
+ expected_2.name = "TestReportingBenchmark.custom_benchmark_name"
+ expected_2.iters = 2
+ expected_2.extras["number_key"].double_value = 3
+ expected_2.extras["other_key"].string_value = "string"
+
+ read_benchmark_1 = tf.gfile.GFile(expected_output_file, "r").read()
+ read_benchmark_1 = text_format.Merge(
+ read_benchmark_1, test_log_pb2.BenchmarkEntry())
+ self.assertProtoEquals(expected_1, read_benchmark_1)
+
+ read_benchmark_2 = tf.gfile.GFile(expected_output_file_2, "r").read()
+ read_benchmark_2 = text_format.Merge(
+ read_benchmark_2, test_log_pb2.BenchmarkEntry())
+ self.assertProtoEquals(expected_2, read_benchmark_2)
+
+ finally:
+ tf.gfile.DeleteRecursively(tempdir)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/platform/benchmark.py b/tensorflow/python/platform/benchmark.py
new file mode 100644
index 0000000000..d6ae54a5e7
--- /dev/null
+++ b/tensorflow/python/platform/benchmark.py
@@ -0,0 +1,203 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities to run benchmarks."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import inspect
+import numbers
+import os
+import re
+
+
+from google.protobuf import text_format
+from tensorflow.core.util import test_log_pb2
+from tensorflow.python.platform import gfile
+
+# When a subclass of the Benchmark class is created, it is added to
+# the registry automatically
+GLOBAL_BENCHMARK_REGISTRY = set()
+
+# Environment variable that determines whether benchmarks are written.
+# See also tensorflow/core/util/reporter.h TestReporter::kTestReporterEnv.
+TEST_REPORTER_TEST_ENV = "TEST_REPORT_FILE_PREFIX"
+
+
+def _global_report_benchmark(
+ name, iters=None, cpu_time=None, wall_time=None,
+ throughput=None, extras=None):
+ """Method for recording a benchmark directly.
+
+ Args:
+ name: The BenchmarkEntry name.
+ iters: (optional) How many iterations were run
+ cpu_time: (optional) Total cpu time in seconds
+ wall_time: (optional) Total wall time in seconds
+ throughput: (optional) Throughput (in MB/s)
+ extras: (optional) Dict mapping string keys to additional benchmark info.
+
+ Raises:
+ TypeError: if extras is not a dict.
+ IOError: if the benchmark output file already exists.
+ """
+ if extras is not None:
+ if not isinstance(extras, dict):
+ raise TypeError("extras must be a dict")
+
+ test_env = os.environ.get(TEST_REPORTER_TEST_ENV, None)
+ if test_env is None:
+ # Reporting was not requested
+ return
+
+ entry = test_log_pb2.BenchmarkEntry()
+ entry.name = name
+ if iters is not None:
+ entry.iters = iters
+ if cpu_time is not None:
+ entry.cpu_time = cpu_time
+ if wall_time is not None:
+ entry.wall_time = wall_time
+ if throughput is not None:
+ entry.throughput = throughput
+ if extras is not None:
+ for (k, v) in extras.items():
+ if isinstance(v, numbers.Number):
+ entry.extras[k].double_value = v
+ else:
+ entry.extras[k].string_value = str(v)
+
+ serialized_entry = text_format.MessageToString(entry)
+
+ mangled_name = name.replace("/", "__")
+ output_path = "%s%s" % (test_env, mangled_name)
+ if gfile.Exists(output_path):
+ raise IOError("File already exists: %s" % output_path)
+ with gfile.GFile(output_path, "w") as out:
+ out.write(serialized_entry)
+
+
+class _BenchmarkRegistrar(type):
+ """The Benchmark class registrar. Used by abstract Benchmark class."""
+
+ def __new__(mcs, clsname, base, attrs):
+ newclass = super(mcs, _BenchmarkRegistrar).__new__(
+ mcs, clsname, base, attrs)
+ if len(newclass.mro()) > 2:
+ # Only the base Benchmark abstract class has mro length 2.
+ # The rest subclass from it and are therefore registered.
+ GLOBAL_BENCHMARK_REGISTRY.add(newclass)
+ return newclass
+
+
+class Benchmark(object):
+ """Abstract class that provides helper functions for running benchmarks.
+
+ Any class subclassing this one is immediately registered in the global
+ benchmark registry.
+
+ Only methods whose names start with the word "benchmark" will be run during
+ benchmarking.
+ """
+ __metaclass__ = _BenchmarkRegistrar
+
+ def _get_name(self, overwrite_name):
+ """Returns full name of class and method calling report_benchmark."""
+
+ # Expect that the caller called report_benchmark, which called _get_name.
+ caller = inspect.stack()[2]
+ calling_class = caller[0].f_locals.get("self", None)
+ # Use the method name, or overwrite_name if provided.
+ name = overwrite_name if overwrite_name is not None else caller[3]
+ if calling_class is not None:
+ # Prefix the name with the class name.
+ class_name = type(calling_class).__name__
+ name = "%s.%s" % (class_name, name)
+ return name
+
+ def report_benchmark(
+ self,
+ iters=None,
+ cpu_time=None,
+ wall_time=None,
+ throughput=None,
+ extras=None,
+ name=None):
+ """Report a benchmark.
+
+ Args:
+ iters: (optional) How many iterations were run
+ cpu_time: (optional) Total cpu time in seconds
+ wall_time: (optional) Total wall time in seconds
+ throughput: (optional) Throughput (in MB/s)
+ extras: (optional) Dict mapping string keys to additional benchmark info.
+ name: (optional) Override the BenchmarkEntry name with `name`.
+ Otherwise it is inferred from the calling class and top-level
+ method name.
+ """
+ name = self._get_name(overwrite_name=name)
+ _global_report_benchmark(
+ name=name, iters=iters, cpu_time=cpu_time, wall_time=wall_time,
+ throughput=throughput, extras=extras)
+
+
+def _run_specific_benchmark(benchmark_class):
+ benchmark = benchmark_class()
+ attrs = dir(benchmark)
+ # Only run methods of this class whose names start with "benchmark"
+ for attr in attrs:
+ if not attr.startswith("benchmark"):
+ continue
+ benchmark_fn = getattr(benchmark, attr)
+ if not callable(benchmark_fn):
+ continue
+ # Call this benchmark method
+ benchmark_fn()
+
+
+def run_benchmarks(args, kwargs):
+ """Run benchmarks as declared in args.
+
+ Args:
+ args: List of args to main()
+ kwargs: Dict of kwargs to main()
+
+ Returns:
+ Tuple (early_exit, new_args, kwargs), where
+ early_exit: Bool, whether main() should now exit
+ new_args: Updated args for the remainder (having removed benchmark flags)
+ kwargs: Same as input kwargs.
+ """
+ exit_early = False
+
+ registry = list(GLOBAL_BENCHMARK_REGISTRY)
+
+ new_args = []
+ for arg in args:
+ if arg.startswith("--benchmarks="):
+ exit_early = True
+ regex = arg.split("=")[1]
+
+ # Match benchmarks in registry against regex
+ for benchmark in registry:
+ benchmark_name = "%s.%s" % (benchmark.__module__, benchmark.__name__)
+ if re.search(regex, benchmark_name):
+ # Found a match
+ _run_specific_benchmark(benchmark)
+ else:
+ new_args.append(arg)
+
+ return (exit_early, new_args, kwargs)
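
To illustrate the flag handling in run_benchmarks above, here is a rough sketch of how main() args are consumed; the flag values and the benchmark name in the regex are hypothetical.

# Sketch of run_benchmarks consuming a --benchmarks flag (hypothetical argv).
from tensorflow.python.platform import benchmark

argv = ["--alsologtostderr", "--benchmarks=MyMatmulBenchmark", "--num_runs=3"]
exit_early, new_args, kwargs = benchmark.run_benchmarks(argv, {})

# Every registered Benchmark subclass whose "<module>.<class>" name matches
# the regex was run; the flag itself is stripped from the returned args, and
# main() is expected to exit early instead of running unit tests.
assert exit_early
assert new_args == ["--alsologtostderr", "--num_runs=3"]
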
diff --git a/tensorflow/python/platform/googletest.py b/tensorflow/python/platform/googletest.py
index 2049bd2b1d..9bbaec2568 100644
--- a/tensorflow/python/platform/googletest.py
+++ b/tensorflow/python/platform/googletest.py
@@ -21,7 +21,22 @@ from __future__ import print_function
# pylint: disable=g-import-not-at-top
# pylint: disable=wildcard-import
from . import control_imports
+from tensorflow.python.platform import benchmark
+
+# Import the Benchmark class
+Benchmark = benchmark.Benchmark # pylint: disable=invalid-name
+
if control_imports.USE_OSS and control_imports.OSS_GOOGLETEST:
from tensorflow.python.platform.default._googletest import *
+ from tensorflow.python.platform.default._googletest import main as g_main
else:
from tensorflow.python.platform.google._googletest import *
+ from tensorflow.python.platform.google._googletest import main as g_main
+
+
+# Redefine main to allow running benchmarks
+def main(*args, **kwargs):
+ (exit_early, args, kwargs) = benchmark.run_benchmarks(args, kwargs)
+ if exit_early:
+ return 0
+ return g_main(*args, **kwargs)
diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py
index d2b9d1f974..6d78193233 100644
--- a/tensorflow/python/platform/test.py
+++ b/tensorflow/python/platform/test.py
@@ -72,6 +72,10 @@ from tensorflow.python.kernel_tests.gradient_checker import compute_gradient
# pylint: enable=unused-import
+# Import Benchmark class
+Benchmark = googletest.Benchmark # pylint: disable=invalid-name
+
+
def main():
"""Runs all unit tests."""
return googletest.main()
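
Finally, a sketch of reading back one of the text-proto reports that report_benchmark writes when TEST_REPORT_FILE_PREFIX is set; the prefix path and benchmark name are hypothetical, and the file naming follows _global_report_benchmark above.

# Read back a BenchmarkEntry written by report_benchmark. Assumes the
# benchmark process ran with TEST_REPORT_FILE_PREFIX=/tmp/bench_ (hypothetical
# prefix); each entry lands at "<prefix><ClassName>.<benchmark name>", with
# "/" in the name mangled to "__".
from google.protobuf import text_format
from tensorflow.core.util import test_log_pb2

path = "/tmp/bench_MyMatmulBenchmark.benchmarkMatmul"  # hypothetical report file
with open(path) as f:
  entry = text_format.Merge(f.read(), test_log_pb2.BenchmarkEntry())
print(entry.name, entry.iters, entry.wall_time)
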