diff options
4 files changed, 20 insertions, 133 deletions
diff --git a/tensorflow/contrib/eager/python/examples/scan/BUILD b/tensorflow/contrib/eager/python/examples/scan/BUILD deleted file mode 100644 index 638c57d1c9..0000000000 --- a/tensorflow/contrib/eager/python/examples/scan/BUILD +++ /dev/null @@ -1,25 +0,0 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) - -load("//tensorflow:tensorflow.bzl", "cuda_py_test") - -cuda_py_test( - name = "scan_test", - size = "small", - srcs = ["scan_test.py"], - additional_deps = [ - "//third_party/py/numpy", - "//tensorflow:tensorflow_py", - ], -) - -cuda_py_test( - name = "scan_graph_test", - size = "small", - srcs = ["scan_graph_test.py"], - additional_deps = [ - "//third_party/py/numpy", - "//tensorflow:tensorflow_py", - ], -) diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py deleted file mode 100644 index d4b8c8941e..0000000000 --- a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Unit test for tf.scan under graph mode execution.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import time - -import numpy as np -import tensorflow as tf - - -class ScanBenchmark(tf.test.Benchmark): - - def runScan(self, n): - elems = np.arange(n) - start_time = time.time() - sum_op = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1) - with tf.Session() as sess: - sess.run(sum_op) - wall_time = time.time() - start_time - - self.report_benchmark( - name='scan', - iters=n, - wall_time=wall_time) - - def benchmarkScan16000(self): - self.runScan(16000) - - def benchmarkScan32000(self): - self.runScan(32000) - - def benchmarkScan64000(self): - self.runScan(64000) - - def benchmarkScan128000(self): - self.runScan(128000) - -if __name__ == '__main__': - tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_test.py deleted file mode 100644 index a02fc24c79..0000000000 --- a/tensorflow/contrib/eager/python/examples/scan/scan_test.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Unit test for tf.scan under eager execution.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import time - -import numpy as np -import tensorflow as tf - - -class ScanBenchmark(tf.test.Benchmark): - - def runScan(self, n): - elems = np.arange(n) - start_time = time.time() - _ = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1) - wall_time = time.time() - start_time - - self.report_benchmark( - name='scan', - iters=n, - wall_time=wall_time) - - def benchmarkScan16000(self): - self.runScan(16000) - - def benchmarkScan32000(self): - self.runScan(32000) - - def benchmarkScan64000(self): - self.runScan(64000) - - def benchmarkScan128000(self): - self.runScan(128000) - - -if __name__ == '__main__': - tf.enable_eager_execution() - tf.test.main() diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index 3bdaf0b214..3fe79ef244 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -42,6 +42,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops @@ -717,6 +718,25 @@ class MicroBenchmarks(test.Benchmark): assert np.equal(func(), make_keras_model()(data)).all() self._run(func, 30000) + def benchmarkScan(self): + elems = math_ops.range(1600) + + def scan(): + return functional_ops.scan( + lambda a, x: a + x, elems, parallel_iterations=1) + + self._run(scan, 100) + + def benchmarkScanDefun(self): + elems = math_ops.range(1600) + + @function.defun + def scan(): + return functional_ops.scan( + lambda a, x: a + x, elems, parallel_iterations=1) + + self._run(scan, 100) + if __name__ == "__main__": test.main() |