From 11f32ebbdcd4eaf5e9e09fe27571e26ec0bd9dd8 Mon Sep 17 00:00:00 2001 From: Rachel Lim Date: Tue, 9 Oct 2018 10:40:23 -0700 Subject: [tf.data vectorization] Handle captured inputs in MapVectorization optimization PiperOrigin-RevId: 216381943 --- .../kernel_tests/optimization/map_vectorization_test.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'tensorflow/python') diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py index 971a2d94b9..803ff87924 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py @@ -105,15 +105,16 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase): def testOptimizationWithCapturedInputs(self): # Tests that vectorization works with captured inputs + y = constant_op.constant(1, shape=(2,)) + z = constant_op.constant(2, shape=(2,)) + def map_fn(x): - return x + y + return x, y, z - y = constant_op.constant(1, shape=(2,)) base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2], [3, 4]]).repeat(5) - # TODO(rachelim): when this optimization works, turn on expect_optimized unoptimized, optimized = self._get_test_datasets( - base_dataset, map_fn, expect_optimized=False) + base_dataset, map_fn, expect_optimized=True) self.assertDatasetsEqual(optimized, unoptimized) def testOptimizationIgnoreStateful(self): -- cgit v1.2.3