diff options
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py index 971a2d94b9..803ff87924 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py @@ -105,15 +105,16 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase): def testOptimizationWithCapturedInputs(self): # Tests that vectorization works with captured inputs + y = constant_op.constant(1, shape=(2,)) + z = constant_op.constant(2, shape=(2,)) + def map_fn(x): - return x + y + return x, y, z - y = constant_op.constant(1, shape=(2,)) base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2], [3, 4]]).repeat(5) - # TODO(rachelim): when this optimization works, turn on expect_optimized unoptimized, optimized = self._get_test_datasets( - base_dataset, map_fn, expect_optimized=False) + base_dataset, map_fn, expect_optimized=True) self.assertDatasetsEqual(optimized, unoptimized) def testOptimizationIgnoreStateful(self): |