aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/eager/python/examples/scan/BUILD25
-rw-r--r--tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py54
-rw-r--r--tensorflow/contrib/eager/python/examples/scan/scan_test.py54
-rw-r--r--tensorflow/python/eager/benchmarks_test.py20
4 files changed, 20 insertions, 133 deletions
diff --git a/tensorflow/contrib/eager/python/examples/scan/BUILD b/tensorflow/contrib/eager/python/examples/scan/BUILD
deleted file mode 100644
index 638c57d1c9..0000000000
--- a/tensorflow/contrib/eager/python/examples/scan/BUILD
+++ /dev/null
@@ -1,25 +0,0 @@
-licenses(["notice"]) # Apache 2.0
-
-package(default_visibility = ["//tensorflow:internal"])
-
-load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-
-cuda_py_test(
- name = "scan_test",
- size = "small",
- srcs = ["scan_test.py"],
- additional_deps = [
- "//third_party/py/numpy",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-cuda_py_test(
- name = "scan_graph_test",
- size = "small",
- srcs = ["scan_graph_test.py"],
- additional_deps = [
- "//third_party/py/numpy",
- "//tensorflow:tensorflow_py",
- ],
-)
diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py
deleted file mode 100644
index d4b8c8941e..0000000000
--- a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Unit test for tf.scan under graph mode execution."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import time
-
-import numpy as np
-import tensorflow as tf
-
-
-class ScanBenchmark(tf.test.Benchmark):
-
- def runScan(self, n):
- elems = np.arange(n)
- start_time = time.time()
- sum_op = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1)
- with tf.Session() as sess:
- sess.run(sum_op)
- wall_time = time.time() - start_time
-
- self.report_benchmark(
- name='scan',
- iters=n,
- wall_time=wall_time)
-
- def benchmarkScan16000(self):
- self.runScan(16000)
-
- def benchmarkScan32000(self):
- self.runScan(32000)
-
- def benchmarkScan64000(self):
- self.runScan(64000)
-
- def benchmarkScan128000(self):
- self.runScan(128000)
-
-if __name__ == '__main__':
- tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_test.py
deleted file mode 100644
index a02fc24c79..0000000000
--- a/tensorflow/contrib/eager/python/examples/scan/scan_test.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Unit test for tf.scan under eager execution."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import time
-
-import numpy as np
-import tensorflow as tf
-
-
-class ScanBenchmark(tf.test.Benchmark):
-
- def runScan(self, n):
- elems = np.arange(n)
- start_time = time.time()
- _ = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1)
- wall_time = time.time() - start_time
-
- self.report_benchmark(
- name='scan',
- iters=n,
- wall_time=wall_time)
-
- def benchmarkScan16000(self):
- self.runScan(16000)
-
- def benchmarkScan32000(self):
- self.runScan(32000)
-
- def benchmarkScan64000(self):
- self.runScan(64000)
-
- def benchmarkScan128000(self):
- self.runScan(128000)
-
-
-if __name__ == '__main__':
- tf.enable_eager_execution()
- tf.test.main()
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index 3bdaf0b214..3fe79ef244 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -42,6 +42,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
@@ -717,6 +718,25 @@ class MicroBenchmarks(test.Benchmark):
assert np.equal(func(), make_keras_model()(data)).all()
self._run(func, 30000)
+ def benchmarkScan(self):
+ elems = math_ops.range(1600)
+
+ def scan():
+ return functional_ops.scan(
+ lambda a, x: a + x, elems, parallel_iterations=1)
+
+ self._run(scan, 100)
+
+ def benchmarkScanDefun(self):
+ elems = math_ops.range(1600)
+
+ @function.defun
+ def scan():
+ return functional_ops.scan(
+ lambda a, x: a + x, elems, parallel_iterations=1)
+
+ self._run(scan, 100)
+
if __name__ == "__main__":
test.main()