aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/benchmarks_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/eager/benchmarks_test.py')
-rw-r--r--tensorflow/python/eager/benchmarks_test.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index 3bdaf0b214..3fe79ef244 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -42,6 +42,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
@@ -717,6 +718,25 @@ class MicroBenchmarks(test.Benchmark):
assert np.equal(func(), make_keras_model()(data)).all()
self._run(func, 30000)
+ def benchmarkScan(self):
+ elems = math_ops.range(1600)
+
+ def scan():
+ return functional_ops.scan(
+ lambda a, x: a + x, elems, parallel_iterations=1)
+
+ self._run(scan, 100)
+
+ def benchmarkScanDefun(self):
+ elems = math_ops.range(1600)
+
+ @function.defun
+ def scan():
+ return functional_ops.scan(
+ lambda a, x: a + x, elems, parallel_iterations=1)
+
+ self._run(scan, 100)
+
if __name__ == "__main__":
test.main()