aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Rachel Lim <rachelim@google.com>2018-10-09 10:40:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 10:45:19 -0700
commit11f32ebbdcd4eaf5e9e09fe27571e26ec0bd9dd8 (patch)
tree032a600b39f926c9ec6ab62625b5a5fd03f20e87 /tensorflow/python
parent1b4402137a76c8085c160edfcc0c3be3cfa8fa3a (diff)
[tf.data vectorization] Handle captured inputs in MapVectorization optimization
PiperOrigin-RevId: 216381943
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
index 971a2d94b9..803ff87924 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
@@ -105,15 +105,16 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
def testOptimizationWithCapturedInputs(self):
# Tests that vectorization works with captured inputs
+ y = constant_op.constant(1, shape=(2,))
+ z = constant_op.constant(2, shape=(2,))
+
def map_fn(x):
- return x + y
+ return x, y, z
- y = constant_op.constant(1, shape=(2,))
base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
[3, 4]]).repeat(5)
- # TODO(rachelim): when this optimization works, turn on expect_optimized
unoptimized, optimized = self._get_test_datasets(
- base_dataset, map_fn, expect_optimized=False)
+ base_dataset, map_fn, expect_optimized=True)
self.assertDatasetsEqual(optimized, unoptimized)
def testOptimizationIgnoreStateful(self):