aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/benchmark
diff options
context:
space:
mode:
authorGravatar Anna R <annarev@google.com>2017-05-11 13:45:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-11 13:49:30 -0700
commit20be1353372dbb6f2db4b152a37453a2d4e0af14 (patch)
treeadefb0c0489472a25608c75204fa602d7a70359b /tensorflow/examples/benchmark
parent3c2dc3baaae762b00d90761f47265411f54033b3 (diff)
Added sample benchmark.
PiperOrigin-RevId: 155791315
Diffstat (limited to 'tensorflow/examples/benchmark')
-rw-r--r--tensorflow/examples/benchmark/BUILD31
-rw-r--r--tensorflow/examples/benchmark/sample_benchmark.py50
2 files changed, 81 insertions, 0 deletions
diff --git a/tensorflow/examples/benchmark/BUILD b/tensorflow/examples/benchmark/BUILD
new file mode 100644
index 0000000000..c4bb0a5bd9
--- /dev/null
+++ b/tensorflow/examples/benchmark/BUILD
@@ -0,0 +1,31 @@
+# Description:
+# Examples of adding a benchmark to TensorFlow.
+
+load(
+ "//tensorflow/tools/test:performance.bzl",
+ "tf_py_logged_benchmark",
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_test(
+ name = "sample_benchmark",
+ srcs = ["sample_benchmark.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+tf_py_logged_benchmark(
+ name = "sample_logged_benchmark",
+ target = "//tensorflow/examples/benchmark:sample_benchmark",
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**/*"]),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/examples/benchmark/sample_benchmark.py b/tensorflow/examples/benchmark/sample_benchmark.py
new file mode 100644
index 0000000000..e98d7a2b5f
--- /dev/null
+++ b/tensorflow/examples/benchmark/sample_benchmark.py
@@ -0,0 +1,50 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Sample TensorFlow benchmark."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import tensorflow as tf
+
+
+# Define a class that extends from tf.test.Benchmark.
+class SampleBenchmark(tf.test.Benchmark):
+
+ # Note: benchmark method name must start with `benchmark`.
+ def benchmarkSum(self):
+ with tf.Session() as sess:
+ x = tf.constant(10)
+ y = tf.constant(5)
+ result = tf.add(x, y)
+
+ iters = 100
+ start_time = time.time()
+ for _ in range(iters):
+ sess.run(result)
+ total_wall_time = time.time() - start_time
+
+ # Call report_benchmark to report a metric value.
+ self.report_benchmark(
+ name="sum_wall_time",
+ # This value should always be per iteration.
+ wall_time=total_wall_time/iters,
+ iters=iters)
+
+
+if __name__ == "__main__":
+ tf.test.main()