author    | Eugene Brevdo <ebrevdo@gmail.com>               | 2016-03-18 08:58:46 -0800
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-03-18 15:44:48 -0700
commit    | e9bc4eff06d6b5bd212788d6e4a68c7d104f33d6 (patch)
tree      | 9a24c93d46e1cdb392bf546af45dc8c029ba0a7b
parent    | 6f192656a623e1008ff4ab92c1c3d67751003e65 (diff)
Add TensorFlow Python Benchmark base class, registration mechanism, and test runner.
Outputs proto strings in a way similar to reporter.cc in tensorflow/core/util/.
Change: 117556944
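
As a rough illustration of the API this change adds (the class, method, and values below are hypothetical, not part of the diff): subclassing tf.test.Benchmark auto-registers the class in benchmark.GLOBAL_BENCHMARK_REGISTRY, and the runner invokes every method whose name starts with "benchmark" on registered classes whose module.ClassName matches the --benchmarks=<regex> filter.

import tensorflow as tf


class MatmulBenchmark(tf.test.Benchmark):
  """Hypothetical benchmark; subclassing registers it automatically."""

  def benchmarkSmallMatmul(self):
    # Only methods whose names start with "benchmark" are picked up by the
    # runner.  report_benchmark writes a BenchmarkEntry text proto when the
    # TEST_REPORT_FILE_PREFIX environment variable is set.
    self.report_benchmark(
        iters=100,
        wall_time=0.5,
        name="small_matmul",
        extras={"matrix_size": 256, "device": "cpu"})


if __name__ == "__main__":
  tf.test.main()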
-rw-r--r-- | tensorflow/python/BUILD                          |   2 |
-rw-r--r-- | tensorflow/python/framework/test_util.py         |   5 |
-rw-r--r-- | tensorflow/python/kernel_tests/benchmark_test.py | 182 |
-rw-r--r-- | tensorflow/python/platform/benchmark.py          | 203 |
-rw-r--r-- | tensorflow/python/platform/googletest.py         |  15 |
-rw-r--r-- | tensorflow/python/platform/test.py               |   4 |
6 files changed, 408 insertions, 3 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 370868dfc5..8a508bcd11 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -40,6 +40,7 @@ py_library(
     name = "platform",
     srcs = glob(["platform/**/*.py"]),
     srcs_version = "PY2AND3",
+    deps = ["//tensorflow/core:protos_all_py"],
 )
 
 py_library(
@@ -1034,6 +1035,7 @@ cpu_only_kernel_test_list = glob([
     "kernel_tests/attention_ops_test.py",
     "kernel_tests/barrier_ops_test.py",
     "kernel_tests/bcast_ops_test.py",
+    "kernel_tests/benchmark_test.py",
     "kernel_tests/candidate_sampler_ops_test.py",
     "kernel_tests/cholesky_op_test.py",
     "kernel_tests/clip_ops_test.py",
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 284db94d45..17f21d56af 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -165,9 +165,8 @@ class TensorFlowTestCase(googletest.TestCase):
       text_format.Merge(expected_message_maybe_ascii, expected_message)
       self._AssertProtoEquals(expected_message, message)
     else:
-      assert False, ("Can't compare protos of type " +
-                     type(expected_message_maybe_ascii) + " and " +
-                     type(message))
+      assert False, ("Can't compare protos of type %s and %s" %
+                     (type(expected_message_maybe_ascii), type(message)))
 
   def assertProtoEqualsVersion(
       self, expected, actual, producer=versions.GRAPH_DEF_VERSION,
diff --git a/tensorflow/python/kernel_tests/benchmark_test.py b/tensorflow/python/kernel_tests/benchmark_test.py
new file mode 100644
index 0000000000..181be9fffc
--- /dev/null
+++ b/tensorflow/python/kernel_tests/benchmark_test.py
@@ -0,0 +1,182 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for tensorflow.python.framework.importer."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import random
+
+import tensorflow as tf
+
+from google.protobuf import text_format
+from tensorflow.core.util import test_log_pb2
+from tensorflow.python.platform import benchmark
+
+
+# Used by SomeRandomBenchmark class below.
+_ran_somebenchmark_1 = [False]
+_ran_somebenchmark_2 = [False]
+_ran_somebenchmark_but_shouldnt = [False]
+
+
+class SomeRandomBenchmark(tf.test.Benchmark):
+  """This Benchmark should automatically be registered in the registry."""
+
+  def _dontRunThisBenchmark(self):
+    _ran_somebenchmark_but_shouldnt[0] = True
+
+  def notBenchmarkMethod(self):
+    _ran_somebenchmark_but_shouldnt[0] = True
+
+  def benchmark1(self):
+    _ran_somebenchmark_1[0] = True
+
+  def benchmark2(self):
+    _ran_somebenchmark_2[0] = True
+
+
+class TestReportingBenchmark(tf.test.Benchmark):
+  """This benchmark (maybe) reports some stuff."""
+
+  def benchmarkReport1(self):
+    self.report_benchmark(iters=1)
+
+  def benchmarkReport2(self):
+    self.report_benchmark(
+        iters=2, name="custom_benchmark_name",
+        extras={"number_key": 3, "other_key": "string"})
+
+
+class BenchmarkTest(tf.test.TestCase):
+
+  def testGlobalBenchmarkRegistry(self):
+    registry = list(benchmark.GLOBAL_BENCHMARK_REGISTRY)
+    self.assertEqual(len(registry), 2)
+    self.assertTrue(SomeRandomBenchmark in registry)
+    self.assertTrue(TestReportingBenchmark in registry)
+
+  def testRunSomeRandomBenchmark(self):
+    # Validate that SomeBenchmark has not run yet
+    self.assertFalse(_ran_somebenchmark_1[0])
+    self.assertFalse(_ran_somebenchmark_2[0])
+    self.assertFalse(_ran_somebenchmark_but_shouldnt[0])
+
+    # Don't run any benchmarks.
+    (exit_early, args, kwargs) = benchmark.run_benchmarks(
+        ["--flag1", "--flag3"], {"a": 3})
+
+    self.assertEqual(exit_early, False)
+    self.assertEqual(args, ["--flag1", "--flag3"])
+    self.assertEqual(kwargs, {"a": 3})
+
+    # Validate that SomeBenchmark has not run yet
+    self.assertFalse(_ran_somebenchmark_1[0])
+    self.assertFalse(_ran_somebenchmark_2[0])
+    self.assertFalse(_ran_somebenchmark_but_shouldnt[0])
+
+    # Run other benchmarks, but this wont run the one we care about
+    (exit_early, args, kwargs) = benchmark.run_benchmarks(
+        ["--flag1", "--benchmarks=__unrelated__", "--flag3"], {"a": 3})
+
+    self.assertEqual(exit_early, True)
+    self.assertEqual(args, ["--flag1", "--flag3"])
+    self.assertEqual(kwargs, {"a": 3})
+
+    # Validate that SomeBenchmark has not run yet
+    self.assertFalse(_ran_somebenchmark_1[0])
+    self.assertFalse(_ran_somebenchmark_2[0])
+    self.assertFalse(_ran_somebenchmark_but_shouldnt[0])
+
+    # Run all the benchmarks, avoid generating any reports
+    if benchmark.TEST_REPORTER_TEST_ENV in os.environ:
+      del os.environ[benchmark.TEST_REPORTER_TEST_ENV]
+    (exit_early, args, kwargs) = benchmark.run_benchmarks(
+        ["--flag1", "--benchmarks=.", "--flag3"], {"a": 3})
+
+    # Validate the output of run_benchmarks
+    self.assertEqual(exit_early, True)
+    self.assertEqual(args, ["--flag1", "--flag3"])
+    self.assertEqual(kwargs, {"a": 3})
+
+    # Validate that SomeRandomBenchmark ran correctly
+    self.assertTrue(_ran_somebenchmark_1[0])
+    self.assertTrue(_ran_somebenchmark_2[0])
+    self.assertFalse(_ran_somebenchmark_but_shouldnt[0])
+
+  def testReportingBenchmark(self):
+    tempdir = tf.test.get_temp_dir()
+    try:
+      tf.gfile.MakeDirs(tempdir)
+    except OSError as e:
+      # It's OK if the directory already exists.
+      if " exists:" not in str(e):
+        raise e
+
+    prefix = os.path.join(
+        tempdir, "reporting_bench_%016x_" % random.getrandbits(64))
+    expected_output_file = "%s%s" % (
+        prefix, "TestReportingBenchmark.benchmarkReport1")
+    expected_output_file_2 = "%s%s" % (
+        prefix, "TestReportingBenchmark.custom_benchmark_name")
+    try:
+      self.assertFalse(tf.gfile.Exists(expected_output_file))
+      # Run benchmark but without env, shouldn't write anything
+      if benchmark.TEST_REPORTER_TEST_ENV in os.environ:
+        del os.environ[benchmark.TEST_REPORTER_TEST_ENV]
+      reporting = TestReportingBenchmark()
+      reporting.benchmarkReport1()  # This should run without writing anything
+      self.assertFalse(tf.gfile.Exists(expected_output_file))
+
+      # Runbenchmark with env, should write
+      os.environ[benchmark.TEST_REPORTER_TEST_ENV] = prefix
+
+      reporting = TestReportingBenchmark()
+      reporting.benchmarkReport1()  # This should write
+      reporting.benchmarkReport2()  # This should write
+
+      # Check the files were written
+      self.assertTrue(tf.gfile.Exists(expected_output_file))
+      self.assertTrue(tf.gfile.Exists(expected_output_file_2))
+
+      # Check the contents are correct
+      expected_1 = test_log_pb2.BenchmarkEntry()
+      expected_1.name = "TestReportingBenchmark.benchmarkReport1"
+      expected_1.iters = 1
+
+      expected_2 = test_log_pb2.BenchmarkEntry()
+      expected_2.name = "TestReportingBenchmark.custom_benchmark_name"
+      expected_2.iters = 2
+      expected_2.extras["number_key"].double_value = 3
+      expected_2.extras["other_key"].string_value = "string"
+
+      read_benchmark_1 = tf.gfile.GFile(expected_output_file, "r").read()
+      read_benchmark_1 = text_format.Merge(
+          read_benchmark_1, test_log_pb2.BenchmarkEntry())
+      self.assertProtoEquals(expected_1, read_benchmark_1)
+
+      read_benchmark_2 = tf.gfile.GFile(expected_output_file_2, "r").read()
+      read_benchmark_2 = text_format.Merge(
+          read_benchmark_2, test_log_pb2.BenchmarkEntry())
+      self.assertProtoEquals(expected_2, read_benchmark_2)
+
+    finally:
+      tf.gfile.DeleteRecursively(tempdir)
+
+
+if __name__ == "__main__":
+  tf.test.main()
diff --git a/tensorflow/python/platform/benchmark.py b/tensorflow/python/platform/benchmark.py
new file mode 100644
index 0000000000..d6ae54a5e7
--- /dev/null
+++ b/tensorflow/python/platform/benchmark.py
@@ -0,0 +1,203 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities to run benchmarks."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import inspect
+import numbers
+import os
+import re
+
+
+from google.protobuf import text_format
+from tensorflow.core.util import test_log_pb2
+from tensorflow.python.platform import gfile
+
+# When a subclass of the Benchmark class is created, it is added to
+# the registry automatically
+GLOBAL_BENCHMARK_REGISTRY = set()
+
+# Environment variable that determines whether benchmarks are written.
+# See also tensorflow/core/util/reporter.h TestReporter::kTestReporterEnv.
+TEST_REPORTER_TEST_ENV = "TEST_REPORT_FILE_PREFIX"
+
+
+def _global_report_benchmark(
+    name, iters=None, cpu_time=None, wall_time=None,
+    throughput=None, extras=None):
+  """Method for recording a benchmark directly.
+
+  Args:
+    name: The BenchmarkEntry name.
+    iters: (optional) How many iterations were run
+    cpu_time: (optional) Total cpu time in seconds
+    wall_time: (optional) Total wall time in seconds
+    throughput: (optional) Throughput (in MB/s)
+    extras: (optional) Dict mapping string keys to additional benchmark info.
+
+  Raises:
+    TypeError: if extras is not a dict.
+    IOError: if the benchmark output file already exists.
+  """
+  if extras is not None:
+    if not isinstance(extras, dict):
+      raise TypeError("extras must be a dict")
+
+  test_env = os.environ.get(TEST_REPORTER_TEST_ENV, None)
+  if test_env is None:
+    # Reporting was not requested
+    return
+
+  entry = test_log_pb2.BenchmarkEntry()
+  entry.name = name
+  if iters is not None:
+    entry.iters = iters
+  if cpu_time is not None:
+    entry.cpu_time = cpu_time
+  if wall_time is not None:
+    entry.wall_time = wall_time
+  if throughput is not None:
+    entry.throughput = throughput
+  if extras is not None:
+    for (k, v) in extras.items():
+      if isinstance(v, numbers.Number):
+        entry.extras[k].double_value = v
+      else:
+        entry.extras[k].string_value = str(v)
+
+  serialized_entry = text_format.MessageToString(entry)
+
+  mangled_name = name.replace("/", "__")
+  output_path = "%s%s" % (test_env, mangled_name)
+  if gfile.Exists(output_path):
+    raise IOError("File already exists: %s" % output_path)
+  with gfile.GFile(output_path, "w") as out:
+    out.write(serialized_entry)
+
+
+class _BenchmarkRegistrar(type):
+  """The Benchmark class registrar.  Used by abstract Benchmark class."""
+
+  def __new__(mcs, clsname, base, attrs):
+    newclass = super(mcs, _BenchmarkRegistrar).__new__(
+        mcs, clsname, base, attrs)
+    if len(newclass.mro()) > 2:
+      # Only the base Benchmark abstract class has mro length 2.
+      # The rest subclass from it and are therefore registered.
+      GLOBAL_BENCHMARK_REGISTRY.add(newclass)
+    return newclass
+
+
+class Benchmark(object):
+  """Abstract class that provides helper functions for running benchmarks.
+
+  Any class subclassing this one is immediately registered in the global
+  benchmark registry.
+
+  Only methods whose names start with the word "benchmark" will be run during
+  benchmarking.
+  """
+  __metaclass__ = _BenchmarkRegistrar
+
+  def _get_name(self, overwrite_name):
+    """Returns full name of class and method calling report_benchmark."""
+
+    # Expect that the caller called report_benchmark, which called _get_name.
+    caller = inspect.stack()[2]
+    calling_class = caller[0].f_locals.get("self", None)
+    # Use the method name, or overwrite_name is provided.
+    name = overwrite_name if overwrite_name is not None else caller[3]
+    if calling_class is not None:
+      # Prefix the name with the class name.
+      class_name = type(calling_class).__name__
+      name = "%s.%s" % (class_name, name)
+    return name
+
+  def report_benchmark(
+      self,
+      iters=None,
+      cpu_time=None,
+      wall_time=None,
+      throughput=None,
+      extras=None,
+      name=None):
+    """Report a benchmark.
+
+    Args:
+      iters: (optional) How many iterations were run
+      cpu_time: (optional) Total cpu time in seconds
+      wall_time: (optional) Total wall time in seconds
+      throughput: (optional) Throughput (in MB/s)
+      extras: (optional) Dict mapping string keys to additional benchmark info.
+      name: (optional) Override the BenchmarkEntry name with `name`.
+        Otherwise it is inferred from the calling class and top-level
+        method name.
+    """
+    name = self._get_name(overwrite_name=name)
+    _global_report_benchmark(
+        name=name, iters=iters, cpu_time=cpu_time, wall_time=wall_time,
+        throughput=throughput, extras=extras)
+
+
+def _run_specific_benchmark(benchmark_class):
+  benchmark = benchmark_class()
+  attrs = dir(benchmark)
+  # Only run methods of this class whose names start with "benchmark"
+  for attr in attrs:
+    if not attr.startswith("benchmark"):
+      continue
+    benchmark_fn = getattr(benchmark, attr)
+    if not callable(benchmark_fn):
+      continue
+    # Call this benchmark method
+    benchmark_fn()
+
+
+def run_benchmarks(args, kwargs):
+  """Run benchmarks as declared in args.
+
+  Args:
+    args: List of args to main()
+    kwargs: List of kwargs to main()
+
+  Returns:
+    Tuple (early_exit, new_args, kwargs), where
+      early_exit: Bool, whether main() should now exit
+      new_args: Updated args for the remainder (having removed benchmark flags)
+      kwargs: Same as input kwargs.
+  """
+  exit_early = False
+
+  registry = list(GLOBAL_BENCHMARK_REGISTRY)
+
+  new_args = []
+  for arg in args:
+    if arg.startswith("--benchmarks="):
+      exit_early = True
+      regex = arg.split("=")[1]
+
+      # Match benchmarks in registry against regex
+      for benchmark in registry:
+        benchmark_name = "%s.%s" % (benchmark.__module__, benchmark.__name__)
+        if re.search(regex, benchmark_name):
+          # Found a match
+          _run_specific_benchmark(benchmark)
+    else:
+      new_args.append(arg)
+
+  return (exit_early, new_args, kwargs)
diff --git a/tensorflow/python/platform/googletest.py b/tensorflow/python/platform/googletest.py
index 2049bd2b1d..9bbaec2568 100644
--- a/tensorflow/python/platform/googletest.py
+++ b/tensorflow/python/platform/googletest.py
@@ -21,7 +21,22 @@ from __future__ import print_function
 # pylint: disable=g-import-not-at-top
 # pylint: disable=wildcard-import
 from . import control_imports
+from tensorflow.python.platform import benchmark
+
+# Import the Benchmark class
+Benchmark = benchmark.Benchmark  # pylint: disable=invalid-name
+
 if control_imports.USE_OSS and control_imports.OSS_GOOGLETEST:
   from tensorflow.python.platform.default._googletest import *
+  from tensorflow.python.platform.default._googletest import main as g_main
 else:
   from tensorflow.python.platform.google._googletest import *
+  from tensorflow.python.platform.google._googletest import main as g_main
+
+
+# Redefine main to allow running benchmarks
+def main(*args, **kwargs):
+  (exit_early, args, kwargs) = benchmark.run_benchmarks(args, kwargs)
+  if exit_early:
+    return 0
+  return g_main(*args, **kwargs)
diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py
index d2b9d1f974..6d78193233 100644
--- a/tensorflow/python/platform/test.py
+++ b/tensorflow/python/platform/test.py
@@ -72,6 +72,10 @@ from tensorflow.python.kernel_tests.gradient_checker import compute_gradient
 # pylint: enable=unused-import
 
 
+# Import Benchmark class
+Benchmark = googletest.Benchmark  # pylint: disable=invalid-name
+
+
 def main():
   """Runs all unit tests."""
   return googletest.main()
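
For reference, a minimal sketch of the reporting path added in benchmark.py (the file prefix, class, and values below are made up, and the exact text layout is simply whatever text_format.MessageToString produces for a BenchmarkEntry):

import os

from tensorflow.python.platform import benchmark

# Writing is enabled by the TEST_REPORT_FILE_PREFIX environment variable; one
# file is written per benchmark, named <prefix><benchmark name> (with any "/"
# in the name replaced by "__").  Assumes the target file does not exist yet,
# since _global_report_benchmark raises IOError on an existing file.
os.environ[benchmark.TEST_REPORTER_TEST_ENV] = "/tmp/bench_report_"


class ScratchBenchmark(benchmark.Benchmark):

  def benchmarkNothing(self):
    self.report_benchmark(iters=10, wall_time=0.25,
                          extras={"batch_size": 32, "device": "cpu"})


ScratchBenchmark().benchmarkNothing()
# /tmp/bench_report_ScratchBenchmark.benchmarkNothing now holds a text-format
# BenchmarkEntry, approximately:
#   name: "ScratchBenchmark.benchmarkNothing"
#   iters: 10
#   wall_time: 0.25
#   extras { key: "batch_size" value { double_value: 32 } }
#   extras { key: "device" value { string_value: "cpu" } }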