diff options
author | Rachel Lim <rachelim@google.com> | 2018-10-09 10:40:23 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-09 10:45:19 -0700 |
commit | 11f32ebbdcd4eaf5e9e09fe27571e26ec0bd9dd8 (patch) | |
tree | 032a600b39f926c9ec6ab62625b5a5fd03f20e87 /tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py | |
parent | 1b4402137a76c8085c160edfcc0c3be3cfa8fa3a (diff) |
[tf.data vectorization] Handle captured inputs in MapVectorization optimization
PiperOrigin-RevId: 216381943
Diffstat (limited to 'tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py')
-rw-r--r-- | tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py index 971a2d94b9..803ff87924 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py @@ -105,15 +105,16 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase): def testOptimizationWithCapturedInputs(self): # Tests that vectorization works with captured inputs + y = constant_op.constant(1, shape=(2,)) + z = constant_op.constant(2, shape=(2,)) + def map_fn(x): - return x + y + return x, y, z - y = constant_op.constant(1, shape=(2,)) base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2], [3, 4]]).repeat(5) - # TODO(rachelim): when this optimization works, turn on expect_optimized unoptimized, optimized = self._get_test_datasets( - base_dataset, map_fn, expect_optimized=False) + base_dataset, map_fn, expect_optimized=True) self.assertDatasetsEqual(optimized, unoptimized) def testOptimizationIgnoreStateful(self): |