diff options
Diffstat (limited to 'tensorflow/python/eager/benchmarks_test.py')
-rw-r--r-- | tensorflow/python/eager/benchmarks_test.py | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index 3bdaf0b214..3fe79ef244 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -42,6 +42,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops @@ -717,6 +718,25 @@ class MicroBenchmarks(test.Benchmark): assert np.equal(func(), make_keras_model()(data)).all() self._run(func, 30000) + def benchmarkScan(self): + elems = math_ops.range(1600) + + def scan(): + return functional_ops.scan( + lambda a, x: a + x, elems, parallel_iterations=1) + + self._run(scan, 100) + + def benchmarkScanDefun(self): + elems = math_ops.range(1600) + + @function.defun + def scan(): + return functional_ops.scan( + lambda a, x: a + x, elems, parallel_iterations=1) + + self._run(scan, 100) + if __name__ == "__main__": test.main() |