aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/platform/benchmark.py
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-03-10 16:41:44 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-10 17:10:03 -0800
commitee8b0fe9ee6ef46f46b5517f57f6cf85ed13b978 (patch)
treef87cb595e529e37dd658b00ac72629cb5df915be /tensorflow/python/platform/benchmark.py
parentba7dd8ae0ae01a94826f4e21e38d6e5d12979915 (diff)
Allow external code to disable run_op_benchmark's memory logging.
Change: 149812908
Diffstat (limited to 'tensorflow/python/platform/benchmark.py')
-rw-r--r--tensorflow/python/platform/benchmark.py9
1 files changed, 7 insertions, 2 deletions
diff --git a/tensorflow/python/platform/benchmark.py b/tensorflow/python/platform/benchmark.py
index d91c19eeb4..ea29399ed2 100644
--- a/tensorflow/python/platform/benchmark.py
+++ b/tensorflow/python/platform/benchmark.py
@@ -42,6 +42,8 @@ GLOBAL_BENCHMARK_REGISTRY = set()
# See also tensorflow/core/util/reporter.h TestReporter::kTestReporterEnv.
TEST_REPORTER_TEST_ENV = "TEST_REPORT_FILE_PREFIX"
+_benchmark_tests_can_log_memory = app._benchmark_tests_can_log_memory # pylint: disable=protected-access
+
def _global_report_benchmark(
name, iters=None, cpu_time=None, wall_time=None,
@@ -212,8 +214,9 @@ class TensorFlowBenchmark(Benchmark):
store the trace of iteration in the benchmark report.
The trace will be stored as a string in Google Chrome trace format
in the extras field "full_trace_chrome_format".
- store_memory_usage: Boolean, whether to run an extra untimed iteration,
- calculate memory usage, and store that in extras fields.
+ store_memory_usage: Boolean, whether to run an extra
+ untimed iteration, calculate memory usage, and store that in extras
+ fields.
name: (optional) Override the BenchmarkEntry name with `name`.
Otherwise it is inferred from the top-level method name.
extras: (optional) Dict mapping string keys to additional benchmark info.
@@ -225,6 +228,8 @@ class TensorFlowBenchmark(Benchmark):
A `dict` containing the key-value pairs that were passed to
`report_benchmark`.
"""
+ store_memory_usage &= _benchmark_tests_can_log_memory()
+
for _ in range(burn_iters):
sess.run(op_or_tensor, feed_dict=feed_dict)