aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-10-08 09:53:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 10:07:38 -0700
commitda3abf6afeaf781b932bce9ccb6c17da911e49b6 (patch)
treecfbcf8888da181403d2fe630b95d78156fec92f5 /tensorflow
parent07df147ab20c4a5329148e5fb5f7f6b187cb73a4 (diff)
Benchmark for comparing original cond and cond_v2 performance.
This benchmark creates many intermediate values, so we can make sure there's no performance overhead (it looks like there might be currently, or it might be from some other difference). It also runs in a defun and in legacy graph mode. Results from my machine: entry { name: "CondWithManyIntermediatesBenchmark.benchmark_cond_v1_defun" iters: 500 wall_time: 1.25822591782 } entry { name: "CondWithManyIntermediatesBenchmark.benchmark_cond_v2_defun" iters: 500 wall_time: 5.99376106262 } entry { name: "CondWithManyIntermediatesBenchmark.benchmark_cond_v1_graph" iters: 500 wall_time: 2.05277585983 } entry { name: "CondWithManyIntermediatesBenchmark.benchmark_cond_v2_graph" iters: 500 wall_time: 2.84808516502 } Clearly we have some work to do! I haven't looked into the time differences at all yet. PiperOrigin-RevId: 216202325
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/python/BUILD13
-rw-r--r--tensorflow/python/ops/control_flow_ops_benchmark.py122
2 files changed, 135 insertions, 0 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index da3c56db92..822d596995 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -5197,6 +5197,19 @@ cuda_py_test(
)
cuda_py_test(
+ name = "control_flow_ops_benchmark",
+ srcs = ["ops/control_flow_ops_benchmark.py"],
+ # NOTE(review): the benchmark source also imports session, eager context,
+ # dtypes, array_ops, math_ops and random_ops, none of which are listed here,
+ # while ":constant_op" is listed but not imported by the source. Presumably
+ # the missing modules arrive transitively -- confirm the deps are complete.
+ additional_deps = [
+ ":client_testlib",
+ ":constant_op",
+ ":control_flow_ops",
+ ":framework_ops",
+ "//tensorflow/python/eager:function",
+ ],
+ main = "ops/control_flow_ops_benchmark.py",
+)
+
+cuda_py_test(
name = "conv2d_benchmark",
size = "large",
srcs = ["ops/conv2d_benchmark.py"],
diff --git a/tensorflow/python/ops/control_flow_ops_benchmark.py b/tensorflow/python/ops/control_flow_ops_benchmark.py
new file mode 100644
index 0000000000..9ba5ff2c0f
--- /dev/null
+++ b/tensorflow/python/ops/control_flow_ops_benchmark.py
@@ -0,0 +1,122 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmark for control flow ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from tensorflow.python.client import session
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+
+
+class CondWithManyIntermediatesBenchmark(test.Benchmark):
+  """Checks the runtime performance of outputting all intermediates.
+
+  Each benchmark builds a cond whose taken branch creates NUM_INTERMEDIATES
+  random scalars, runs NUM_WARM_UP_ITERS untimed calls, and then reports the
+  total wall time of NUM_ITERS calls. The same cond is measured with
+  `control_flow_ops.ENABLE_COND_V2` set to False (v1) and True (v2), both
+  inside a defun and in legacy graph mode.
+  """
+
+  # Number of random scalars summed inside the cond's true branch.
+  NUM_INTERMEDIATES = 1000
+  # Number of timed calls per benchmark.
+  NUM_ITERS = 500
+  # Number of untimed calls made before timing starts.
+  NUM_WARM_UP_ITERS = 50
+
+  def _create_cond(self, x):
+    """Builds the cond under test.
+
+    Args:
+      x: Value the predicate compares against -1 and that the true branch adds
+        the random values to.
+
+    Returns:
+      The result of `control_flow_ops.cond`: `x` plus NUM_INTERMEDIATES random
+      scalars when `x != -1`, otherwise 0.0.
+    """
+
+    def branch_fn():
+      # Use a random value so the adds can't be constant folded.
+      return x + sum(random_ops.random_normal([])
+                     for _ in range(self.NUM_INTERMEDIATES))
+
+    # Use a dynamic predicate to make sure the cond isn't constant folded.
+    return control_flow_ops.cond(math_ops.not_equal(x, -1),
+                                 branch_fn, lambda: 0.0)
+
+  def _time_and_report(self, cond_fn):
+    """Warms up `cond_fn`, times NUM_ITERS calls to it and reports the result.
+
+    Args:
+      cond_fn: Callable accepting the cond input; it is always called with 0.0
+        so the predicate `x != -1` takes the expensive branch.
+    """
+    # Warm up
+    for _ in range(self.NUM_WARM_UP_ITERS):
+      cond_fn(0.0)
+
+    start_time = time.time()
+
+    for _ in range(self.NUM_ITERS):
+      cond_fn(0.0)
+
+    # wall_time is the total for all NUM_ITERS calls, not a per-call average.
+    self.report_benchmark(
+        wall_time=time.time() - start_time,
+        iters=self.NUM_ITERS)
+
+  def _run_with_cond_v2(self, enabled, benchmark_fn):
+    """Runs `benchmark_fn` with `ENABLE_COND_V2` temporarily set to `enabled`.
+
+    The previous flag value is restored even if `benchmark_fn` raises, so a
+    failing benchmark cannot change which cond implementation later benchmarks
+    in the same process measure.
+
+    Args:
+      enabled: Boolean value to assign to `control_flow_ops.ENABLE_COND_V2`.
+      benchmark_fn: Zero-argument callable that runs and reports a benchmark.
+    """
+    old_val = control_flow_ops.ENABLE_COND_V2
+    control_flow_ops.ENABLE_COND_V2 = enabled
+    try:
+      benchmark_fn()
+    finally:
+      control_flow_ops.ENABLE_COND_V2 = old_val
+
+  def _benchmark_defun(self):
+    """Benchmarks cond in a defun."""
+
+    @function.defun
+    def cond_fn(x):
+      return self._create_cond(x)
+
+    self._time_and_report(cond_fn)
+
+  def _benchmark_graph(self):
+    """Benchmarks cond in legacy graph mode."""
+    with context.graph_mode():
+      with ops.Graph().as_default():
+        x = array_ops.placeholder(dtypes.float32)
+        cond_val = self._create_cond(x)
+
+        # Timing happens while the session is still open.
+        with session.Session() as sess:
+          self._time_and_report(sess.make_callable(cond_val, [x]))
+
+  def benchmark_cond_v1_defun(self):
+    """Benchmarks the original cond implementation inside a defun."""
+    self._run_with_cond_v2(False, self._benchmark_defun)
+
+  def benchmark_cond_v2_defun(self):
+    """Benchmarks cond_v2 inside a defun."""
+    self._run_with_cond_v2(True, self._benchmark_defun)
+
+  def benchmark_cond_v1_graph(self):
+    """Benchmarks the original cond implementation in legacy graph mode."""
+    self._run_with_cond_v2(False, self._benchmark_graph)
+
+  def benchmark_cond_v2_graph(self):
+    """Benchmarks cond_v2 in legacy graph mode."""
+    self._run_with_cond_v2(True, self._benchmark_graph)
+
+if __name__ == "__main__":
+ # Eager execution is enabled before the benchmarks run; the graph-mode
+ # benchmarks switch modes explicitly through context.graph_mode().
+ ops.enable_eager_execution()
+ test.main()