aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/platform
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-08 13:46:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 13:51:07 -0700
commitb052c51374f558c25a29c70918d79205dfec808b (patch)
treefd2eb631b27c5facc6c27a17cfa3c6feb539136d /tensorflow/python/platform
parent76ab96c8a5b2d77dfc191c94ff54fd5e52c561f2 (diff)
Add tf.BenchmarkConfig that returns a session config appropriate for benchmarking. At the moment, it returns a default config with only Grappler dependency optimizer disabled. Many benchmarks wrap the subgraph they want to time in control_flow_ops.group() to avoid including the overhead of copying the output back to the Python client in the measurement. In the graph, this only adds a control dependency between the subgraph output and the fetch node, which in turn (often) causes the dependency optimizer to turn all nodes in the graph into no-ops.
PiperOrigin-RevId: 216242463
Diffstat (limited to 'tensorflow/python/platform')
-rw-r--r--tensorflow/python/platform/benchmark.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/python/platform/benchmark.py b/tensorflow/python/platform/benchmark.py
index fa17b17d10..4f7abb311a 100644
--- a/tensorflow/python/platform/benchmark.py
+++ b/tensorflow/python/platform/benchmark.py
@@ -27,6 +27,7 @@ import time
import six
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.core.util import test_log_pb2
from tensorflow.python.client import timeline
from tensorflow.python.platform import app
@@ -182,6 +183,19 @@ class Benchmark(six.with_metaclass(_BenchmarkRegistrar, object)):
throughput=throughput, extras=extras)
+@tf_export("test.benchmark_config")
+def benchmark_config():
+ """Returns a tf.ConfigProto for disabling the dependency optimizer.
+
+ Returns:
+ A TensorFlow ConfigProto object.
+ """
+ config = config_pb2.ConfigProto()
+ config.graph_options.rewrite_options.dependency_optimization = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ return config
+
+
@tf_export("test.Benchmark")
class TensorFlowBenchmark(Benchmark):
"""Abstract class that provides helpers for TensorFlow benchmarks."""