aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py')
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py31
1 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
index afd0fc3abf..d444c4082e 100644
--- a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
@@ -332,6 +332,37 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
for _ in range(10):
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
+ @parameterized.named_parameters(
+ ("Identity", None, lambda x: x, None),
+ ("Replicate", None, lambda x: (x, x), None),
+ ("Swap", (None, None), lambda x, y: (y, x), None),
+ ("Project", (None, None), lambda x, y: x, None),
+ )
+ def testShortCircuit(self, structure, map_fn, num_parallel_calls):
+ dataset = self.structuredDataset(structure).repeat().apply(
+ batching.map_and_batch(map_fn, batch_size=10))
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ if isinstance(structure, tuple):
+ expected = map_fn(
+ *sess.run(self.structuredElement(structure, shape=[10])))
+ else:
+ expected = map_fn(
+ sess.run(self.structuredElement(structure, shape=[10])))
+ self.assertAllEqual(expected, sess.run(get_next))
+
+ def testShortCircuitCapturedInput(self):
+ captured_t = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = self.structuredDataset(None).repeat().apply(
+ batching.map_and_batch(lambda x: captured_t, batch_size=10))
+ iterator = dataset.make_initializable_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer, feed_dict={captured_t: 42})
+ self.assertAllEqual([42] * 10, sess.run(get_next))
+
if __name__ == "__main__":
test.main()