From 0e42fd6d0a88b30ab57959f38c79bea19d745ec3 Mon Sep 17 00:00:00 2001
From: Jiri Simsa
Date: Mon, 8 Oct 2018 10:14:58 -0700
Subject: [tf.data] Adding specialization for `MapDataset`, `ParallelMapDataset`,
 and `MapAndBatchDataset` whose user-provided functions have the property that
 each output argument takes its value directly from an input argument (e.g.
 `lambda x, y: (y, x)`).

This specialization can produce the result without having to schedule the
function using the executor. (An illustrative sketch of this property follows
the patch.)

PiperOrigin-RevId: 216206232
---
 .../kernel_tests/map_and_batch_test.py             | 31 +++++++
 .../data/kernel_tests/filter_dataset_op_test.py    |  2 +-
 .../data/kernel_tests/map_dataset_op_test.py       | 95 ++++++++++++++++++----
 tensorflow/python/data/kernel_tests/test_base.py   | 29 +++++++
 4 files changed, 141 insertions(+), 16 deletions(-)

(limited to 'tensorflow/python')

diff --git a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
index afd0fc3abf..d444c4082e 100644
--- a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
@@ -332,6 +332,37 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
       for _ in range(10):
         self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
 
+  @parameterized.named_parameters(
+      ("Identity", None, lambda x: x, None),
+      ("Replicate", None, lambda x: (x, x), None),
+      ("Swap", (None, None), lambda x, y: (y, x), None),
+      ("Project", (None, None), lambda x, y: x, None),
+  )
+  def testShortCircuit(self, structure, map_fn, num_parallel_calls):
+    dataset = self.structuredDataset(structure).repeat().apply(
+        batching.map_and_batch(map_fn, batch_size=10))
+    get_next = dataset.make_one_shot_iterator().get_next()
+
+    with self.cached_session() as sess:
+      if isinstance(structure, tuple):
+        expected = map_fn(
+            *sess.run(self.structuredElement(structure, shape=[10])))
+      else:
+        expected = map_fn(
+            sess.run(self.structuredElement(structure, shape=[10])))
+      self.assertAllEqual(expected, sess.run(get_next))
+
+  def testShortCircuitCapturedInput(self):
+    captured_t = array_ops.placeholder(dtypes.int64, shape=[])
+    dataset = self.structuredDataset(None).repeat().apply(
+        batching.map_and_batch(lambda x: captured_t, batch_size=10))
+    iterator = dataset.make_initializable_iterator()
+    get_next = iterator.get_next()
+
+    with self.cached_session() as sess:
+      sess.run(iterator.initializer, feed_dict={captured_t: 42})
+      self.assertAllEqual([42] * 10, sess.run(get_next))
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
index 6b7afafa5d..a0c6b37a6d 100644
--- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
@@ -156,7 +156,7 @@ class FilterDatasetTest(test_base.DatasetTestBase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
-  def testReturnComponent(self):
+  def testShortCircuit(self):
     iterator = (
         dataset_ops.Dataset.zip(
             (dataset_ops.Dataset.range(10),
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
index 0c372ebb10..4683b1db91 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
@@ -622,7 +622,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
       sess.run(init_op)
       for i in range(10):
         actual = sess.run(get_next)
-        self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
+        self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
         self.assertSparseValuesEqual(actual, _sparse(i))
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
@@ -649,7 +649,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
       sess.run(init_op)
       for i in range(10):
         actual = sess.run(get_next)
-        self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
+        self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
         self.assertSparseValuesEqual(actual, _check(_sparse(i)).eval())
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
@@ -783,19 +783,72 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
       self.assertTrue(all(tids[0] == tid for tid in tids))
   # pylint: enable=g-long-lambda
 
+  @parameterized.named_parameters(
+      ("SequentialIdentity", None, lambda x: x, None),
+      ("SequentialReplicate", None, lambda x: (x, x), None),
+      ("SequentialSwap", (None, None), lambda x, y: (y, x), None),
+      ("SequentialProject", (None, None), lambda x, y: x, None),
+      ("ParallelIdentity", None, lambda x: x, 10),
+      ("ParallelReplicate", None, lambda x: (x, x), 10),
+      ("ParallelSwap", (None, None), lambda x, y: (y, x), 10),
+      ("ParallelProject", (None, None), lambda x, y: x, 10),
+  )
+  def testShortCircuit(self, structure, map_fn, num_parallel_calls):
+    dataset = self.structuredDataset(structure).repeat().map(
+        map_fn, num_parallel_calls=num_parallel_calls)
+    get_next = dataset.make_one_shot_iterator().get_next()
+
+    with self.cached_session() as sess:
+      if isinstance(structure, tuple):
+        expected = map_fn(*sess.run(self.structuredElement(structure)))
+      else:
+        expected = map_fn(sess.run(self.structuredElement(structure)))
+      self.assertEqual(expected, sess.run(get_next))
+
+  @parameterized.named_parameters(
+      ("Sequential", None),
+      ("Parallel", 10),
+  )
+  def testShortCircuitCapturedInput(self, num_parallel_calls):
+    captured_t = array_ops.placeholder(dtypes.int64, shape=[])
+    dataset = self.structuredDataset(None).repeat().map(
+        lambda x: captured_t, num_parallel_calls=num_parallel_calls)
+    iterator = dataset.make_initializable_iterator()
+    get_next = iterator.get_next()
+
+    with self.cached_session() as sess:
+      sess.run(iterator.initializer, feed_dict={captured_t: 42})
+      self.assertEqual(42, sess.run(get_next))
+
 
 class MapDatasetBenchmark(test.Benchmark):
 
   def benchmarkChainOfMaps(self):
     chain_lengths = [0, 1, 2, 5, 10, 20, 50]
     for chain_length in chain_lengths:
-      for use_inter_op_parallelism in [False, True]:
+      for mode in ["general", "single-threaded", "short-circuit"]:
+        if mode == "general":
+          map_fn = lambda x: x + 1
+          use_inter_op_parallelism = True
+          print_label = ""
+          benchmark_label = ""
+        if mode == "single-threaded":
+          map_fn = lambda x: x + 1
+          use_inter_op_parallelism = False
+          print_label = " (single threaded mode)"
+          benchmark_label = "_single_threaded"
+        if mode == "short-circuit":
+          map_fn = lambda x: x
+          use_inter_op_parallelism = True  # should not have any significance
+          print_label = " (short circuit mode)"
+          benchmark_label = "_short_circuit"
+
         with ops.Graph().as_default():
           dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
           for _ in range(chain_length):
             dataset = dataset_ops.MapDataset(
                 dataset,
-                lambda x: x,
+                map_fn,
                 use_inter_op_parallelism=use_inter_op_parallelism)
           iterator = dataset.make_one_shot_iterator()
           next_element = iterator.get_next()
@@ -813,25 +866,39 @@ class MapDatasetBenchmark(test.Benchmark):
 
           median_wall_time = np.median(deltas) / 100
           print("Map dataset chain length%s: %d Median wall time: %f" %
-                (" (single threaded mode)" if not use_inter_op_parallelism
-                 else "", chain_length, median_wall_time))
+                (print_label, chain_length, median_wall_time))
           self.report_benchmark(
               iters=1000,
               wall_time=median_wall_time,
               name="benchmark_map_dataset_chain_latency_%d%s" %
-              (chain_length, "_single_threaded"
-               if not use_inter_op_parallelism else ""))
+              (chain_length, benchmark_label))
 
   def benchmarkMapFanOut(self):
     fan_outs = [1, 2, 5, 10, 20, 50, 100]
     for fan_out in fan_outs:
-      for use_inter_op_parallelism in [False, True]:
+      for mode in ["general", "single-threaded", "short-circuit"]:
+        if mode == "general":
+          map_fn = lambda *xs: [x + 1 for x in xs]
+          use_inter_op_parallelism = True
+          print_label = ""
+          benchmark_label = ""
+        if mode == "single-threaded":
+          map_fn = lambda *xs: [x + 1 for x in xs]
+          use_inter_op_parallelism = False
+          print_label = " (single threaded mode)"
+          benchmark_label = "_single_threaded"
+        if mode == "short-circuit":
+          map_fn = lambda *xs: xs
+          use_inter_op_parallelism = True  # should not have any significance
+          print_label = " (short circuit mode)"
+          benchmark_label = "_short_circuit"
+
         with ops.Graph().as_default():
           dataset = dataset_ops.Dataset.from_tensors(
               tuple(0 for _ in range(fan_out))).repeat(None)
           dataset = dataset_ops.MapDataset(
               dataset,
-              lambda *xs: xs,
+              map_fn,
              use_inter_op_parallelism=use_inter_op_parallelism)
           iterator = dataset.make_one_shot_iterator()
           next_element = iterator.get_next()
@@ -849,14 +916,12 @@ class MapDatasetBenchmark(test.Benchmark):
 
           median_wall_time = np.median(deltas) / 100
           print("Map dataset fan out%s: %d Median wall time: %f" %
-                (" (single threaded mode)" if not use_inter_op_parallelism
-                 else "", fan_out, median_wall_time))
+                (print_label, fan_out, median_wall_time))
           self.report_benchmark(
               iters=1000,
               wall_time=median_wall_time,
-              name="benchmark_map_dataset_fan_out_%d%s" %
-              (fan_out, "_single_threaded"
-               if not use_inter_op_parallelism else ""))
+              name="benchmark_map_dataset_fan_out_%d%s" % (fan_out,
+                                                           benchmark_label))
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py
index b730e10949..b73a94e683 100644
--- a/tensorflow/python/data/kernel_tests/test_base.py
+++ b/tensorflow/python/data/kernel_tests/test_base.py
@@ -19,10 +19,13 @@ from __future__ import print_function
 
 import re
 
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
 
 
@@ -107,3 +110,29 @@ class DatasetTestBase(test.TestCase):
       with self.assertRaisesRegexp(exception_class,
                                    re.escape(expected_message)):
         self.evaluate(next2())
+
+  def structuredDataset(self, structure, shape=None, dtype=dtypes.int64):
+    """Returns a singleton dataset with the given structure."""
+    if shape is None:
+      shape = []
+    if structure is None:
+      return dataset_ops.Dataset.from_tensors(
+          array_ops.zeros(shape, dtype=dtype))
+    else:
+      return dataset_ops.Dataset.zip(
+          tuple([
+              self.structuredDataset(substructure, shape, dtype)
+              for substructure in structure
+          ]))
+
+  def structuredElement(self, structure, shape=None, dtype=dtypes.int64):
+    """Returns an element with the given structure."""
+    if shape is None:
+      shape = []
+    if structure is None:
+      return array_ops.zeros(shape, dtype=dtype)
+    else:
+      return tuple([
+          self.structuredElement(substructure, shape, dtype)
+          for substructure in structure
+      ])
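
Note on the specialization itself: the "Identity", "Replicate", "Swap", and
"Project" cases parameterized in the tests above all share the property named
in the commit message, namely that every output of the map function is simply
one of its inputs. The sketch below is a hypothetical, Python-only illustration
of that structural check (the helper name `short_circuit_indices` is invented
here); the actual specialization lives in the tf.data C++ kernels and is not
reproduced by this snippet.

def short_circuit_indices(map_fn, num_args):
  """Returns an output->input index mapping if `map_fn` merely forwards its
  arguments (identity, replication, reordering, or projection), else None."""

  class _Marker(object):
    """Stand-in argument that supports no computation."""

    def __init__(self, index):
      self.index = index

  args = [_Marker(i) for i in range(num_args)]
  try:
    outputs = map_fn(*args)
  except Exception:  # pylint: disable=broad-except
    # The function tried to compute with its arguments, so it cannot be
    # short-circuited.
    return None
  if not isinstance(outputs, (tuple, list)):
    outputs = (outputs,)
  indices = []
  for output in outputs:
    if not isinstance(output, _Marker):
      # The output is computed rather than forwarded from an input.
      return None
    indices.append(output.index)
  return indices

# The four parameterized cases from the tests above:
assert short_circuit_indices(lambda x: x, 1) == [0]             # Identity
assert short_circuit_indices(lambda x: (x, x), 1) == [0, 0]     # Replicate
assert short_circuit_indices(lambda x, y: (y, x), 2) == [1, 0]  # Swap
assert short_circuit_indices(lambda x, y: x, 2) == [0]          # Project
# A function that does real work has to go through the executor.
assert short_circuit_indices(lambda x: x + 1, 1) is None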