aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
diff options
context:
space:
mode:
authorGravatar Rachel Lim <rachelim@google.com>2018-10-04 09:26:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 09:35:51 -0700
commit1fb84c2e41c454939a02a69093cb214673eab343 (patch)
treee3ee1c19e3a73e1d1cddbc76d5573b7800b1048b /tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
parentac22e1583aed390d78d2e87a4bf8a6ec39400ec4 (diff)
Add ability to vectorize nodes that do not derive from function arguments. (This indirectly handles "Const" outputs automagically, since they are always unstacked.)
PiperOrigin-RevId: 215749824
Diffstat (limited to 'tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py')
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py4
1 files changed, 4 insertions, 0 deletions
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
index 32ebc49c40..971a2d94b9 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
@@ -78,6 +78,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("Basic", lambda x: (x, x + 1), None),
+ ("Const", lambda x: 2, 12),
("Parallel", lambda x: (x, x + 1), 12),
("Gather", lambda x: array_ops.gather(x, 0), 12),
)
@@ -207,6 +208,9 @@ class MapVectorizationBenchmark(test.Benchmark):
def benchmarkAddConst(self):
self._benchmark_helper(lambda *args: [x + 1 for x in args], "add_const")
+ def benchmarkReturnConst(self):
+ self._benchmark_helper(lambda *args: [constant_op.constant(2)], "ret_const")
+
def benchmarkSelect(self):
self._benchmark_helper(lambda *args: args[0], "select")