diff options
Diffstat (limited to 'tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py')
-rw-r--r-- | tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py index afd0fc3abf..d444c4082e 100644 --- a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py @@ -332,6 +332,37 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase): for _ in range(10): self.assertAllEqual([element for _ in range(10)], sess.run(get_next)) + @parameterized.named_parameters( + ("Identity", None, lambda x: x, None), + ("Replicate", None, lambda x: (x, x), None), + ("Swap", (None, None), lambda x, y: (y, x), None), + ("Project", (None, None), lambda x, y: x, None), + ) + def testShortCircuit(self, structure, map_fn, num_parallel_calls): + dataset = self.structuredDataset(structure).repeat().apply( + batching.map_and_batch(map_fn, batch_size=10)) + get_next = dataset.make_one_shot_iterator().get_next() + + with self.cached_session() as sess: + if isinstance(structure, tuple): + expected = map_fn( + *sess.run(self.structuredElement(structure, shape=[10]))) + else: + expected = map_fn( + sess.run(self.structuredElement(structure, shape=[10]))) + self.assertAllEqual(expected, sess.run(get_next)) + + def testShortCircuitCapturedInput(self): + captured_t = array_ops.placeholder(dtypes.int64, shape=[]) + dataset = self.structuredDataset(None).repeat().apply( + batching.map_and_batch(lambda x: captured_t, batch_size=10)) + iterator = dataset.make_initializable_iterator() + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(iterator.initializer, feed_dict={captured_t: 42}) + self.assertAllEqual([42] * 10, sess.run(get_next)) + if __name__ == "__main__": test.main() |