aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-10-08 10:14:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 10:23:50 -0700
commit0e42fd6d0a88b30ab57959f38c79bea19d745ec3 (patch)
tree177a8662e421aa2c4d57d1b1caf370077de0b55c /tensorflow/python
parent0e1ba8886b6a333b1ed8ed7548c55041c34e9623 (diff)
[tf.data] Adding specialization for `MapDataset`, `ParallelMapDataset`, and `MapAndBatchDataset` whose user-provided functions have the property that each output argument take its value directly from an input argument (e.g. `lambda x, y: y, x`). This specialization can produce the result without having to schedule the function using the executor.
PiperOrigin-RevId: 216206232
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py31
-rw-r--r--tensorflow/python/data/kernel_tests/filter_dataset_op_test.py2
-rw-r--r--tensorflow/python/data/kernel_tests/map_dataset_op_test.py95
-rw-r--r--tensorflow/python/data/kernel_tests/test_base.py29
4 files changed, 141 insertions, 16 deletions
diff --git a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
index afd0fc3abf..d444c4082e 100644
--- a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
@@ -332,6 +332,37 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
for _ in range(10):
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
+ @parameterized.named_parameters(
+ ("Identity", None, lambda x: x, None),
+ ("Replicate", None, lambda x: (x, x), None),
+ ("Swap", (None, None), lambda x, y: (y, x), None),
+ ("Project", (None, None), lambda x, y: x, None),
+ )
+ def testShortCircuit(self, structure, map_fn, num_parallel_calls):
+ dataset = self.structuredDataset(structure).repeat().apply(
+ batching.map_and_batch(map_fn, batch_size=10))
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ if isinstance(structure, tuple):
+ expected = map_fn(
+ *sess.run(self.structuredElement(structure, shape=[10])))
+ else:
+ expected = map_fn(
+ sess.run(self.structuredElement(structure, shape=[10])))
+ self.assertAllEqual(expected, sess.run(get_next))
+
+ def testShortCircuitCapturedInput(self):
+ captured_t = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = self.structuredDataset(None).repeat().apply(
+ batching.map_and_batch(lambda x: captured_t, batch_size=10))
+ iterator = dataset.make_initializable_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer, feed_dict={captured_t: 42})
+ self.assertAllEqual([42] * 10, sess.run(get_next))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
index 6b7afafa5d..a0c6b37a6d 100644
--- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
@@ -156,7 +156,7 @@ class FilterDatasetTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def testReturnComponent(self):
+ def testShortCircuit(self):
iterator = (
dataset_ops.Dataset.zip(
(dataset_ops.Dataset.range(10),
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
index 0c372ebb10..4683b1db91 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
@@ -622,7 +622,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(init_op)
for i in range(10):
actual = sess.run(get_next)
- self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
+ self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
self.assertSparseValuesEqual(actual, _sparse(i))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -649,7 +649,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(init_op)
for i in range(10):
actual = sess.run(get_next)
- self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
+ self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
self.assertSparseValuesEqual(actual, _check(_sparse(i)).eval())
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -783,19 +783,72 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertTrue(all(tids[0] == tid for tid in tids))
# pylint: enable=g-long-lambda
+ @parameterized.named_parameters(
+ ("SequentialIdentity", None, lambda x: x, None),
+ ("SequentialReplicate", None, lambda x: (x, x), None),
+ ("SequentialSwap", (None, None), lambda x, y: (y, x), None),
+ ("SequentialProject", (None, None), lambda x, y: x, None),
+ ("ParallelIdentity", None, lambda x: x, 10),
+ ("ParallelReplicate", None, lambda x: (x, x), 10),
+ ("ParallelSwap", (None, None), lambda x, y: (y, x), 10),
+ ("ParallelProject", (None, None), lambda x, y: x, 10),
+ )
+ def testShortCircuit(self, structure, map_fn, num_parallel_calls):
+ dataset = self.structuredDataset(structure).repeat().map(
+ map_fn, num_parallel_calls=num_parallel_calls)
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ if isinstance(structure, tuple):
+ expected = map_fn(*sess.run(self.structuredElement(structure)))
+ else:
+ expected = map_fn(sess.run(self.structuredElement(structure)))
+ self.assertEqual(expected, sess.run(get_next))
+
+ @parameterized.named_parameters(
+ ("Sequential", None),
+ ("Parallel", 10),
+ )
+ def testShortCircuitCapturedInput(self, num_parallel_calls):
+ captured_t = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = self.structuredDataset(None).repeat().map(
+ lambda x: captured_t, num_parallel_calls=num_parallel_calls)
+ iterator = dataset.make_initializable_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer, feed_dict={captured_t: 42})
+ self.assertEqual(42, sess.run(get_next))
+
class MapDatasetBenchmark(test.Benchmark):
def benchmarkChainOfMaps(self):
chain_lengths = [0, 1, 2, 5, 10, 20, 50]
for chain_length in chain_lengths:
- for use_inter_op_parallelism in [False, True]:
+ for mode in ["general", "single-threaded", "short-circuit"]:
+ if mode == "general":
+ map_fn = lambda x: x + 1
+ use_inter_op_parallelism = True
+ print_label = ""
+ benchmark_label = ""
+ if mode == "single-threaded":
+ map_fn = lambda x: x + 1
+ use_inter_op_parallelism = False
+ print_label = " (single threaded mode)"
+ benchmark_label = "_single_threaded"
+ if mode == "short-circuit":
+ map_fn = lambda x: x
+ use_inter_op_parallelism = True # should not have any significance
+ print_label = " (short circuit mode)"
+ benchmark_label = "_short_circuit"
+
with ops.Graph().as_default():
dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
for _ in range(chain_length):
dataset = dataset_ops.MapDataset(
dataset,
- lambda x: x,
+ map_fn,
use_inter_op_parallelism=use_inter_op_parallelism)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
@@ -813,25 +866,39 @@ class MapDatasetBenchmark(test.Benchmark):
median_wall_time = np.median(deltas) / 100
print("Map dataset chain length%s: %d Median wall time: %f" %
- (" (single threaded mode)" if not use_inter_op_parallelism
- else "", chain_length, median_wall_time))
+ (print_label, chain_length, median_wall_time))
self.report_benchmark(
iters=1000,
wall_time=median_wall_time,
name="benchmark_map_dataset_chain_latency_%d%s" %
- (chain_length, "_single_threaded"
- if not use_inter_op_parallelism else ""))
+ (chain_length, benchmark_label))
def benchmarkMapFanOut(self):
fan_outs = [1, 2, 5, 10, 20, 50, 100]
for fan_out in fan_outs:
- for use_inter_op_parallelism in [False, True]:
+ for mode in ["general", "single-threaded", "short-circuit"]:
+ if mode == "general":
+ map_fn = lambda *xs: [x + 1 for x in xs]
+ use_inter_op_parallelism = True
+ print_label = ""
+ benchmark_label = ""
+ if mode == "single-threaded":
+ map_fn = lambda *xs: [x + 1 for x in xs]
+ use_inter_op_parallelism = False
+ print_label = " (single threaded mode)"
+ benchmark_label = "_single_threaded"
+ if mode == "short-circuit":
+ map_fn = lambda *xs: xs
+ use_inter_op_parallelism = True # should not have any significance
+ print_label = " (short circuit mode)"
+ benchmark_label = "_short_circuit"
+
with ops.Graph().as_default():
dataset = dataset_ops.Dataset.from_tensors(
tuple(0 for _ in range(fan_out))).repeat(None)
dataset = dataset_ops.MapDataset(
dataset,
- lambda *xs: xs,
+ map_fn,
use_inter_op_parallelism=use_inter_op_parallelism)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
@@ -849,14 +916,12 @@ class MapDatasetBenchmark(test.Benchmark):
median_wall_time = np.median(deltas) / 100
print("Map dataset fan out%s: %d Median wall time: %f" %
- (" (single threaded mode)" if not use_inter_op_parallelism
- else "", fan_out, median_wall_time))
+ (print_label, fan_out, median_wall_time))
self.report_benchmark(
iters=1000,
wall_time=median_wall_time,
- name="benchmark_map_dataset_fan_out_%d%s" %
- (fan_out, "_single_threaded"
- if not use_inter_op_parallelism else ""))
+ name="benchmark_map_dataset_fan_out_%d%s" % (fan_out,
+ benchmark_label))
if __name__ == "__main__":
diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py
index b730e10949..b73a94e683 100644
--- a/tensorflow/python/data/kernel_tests/test_base.py
+++ b/tensorflow/python/data/kernel_tests/test_base.py
@@ -19,10 +19,13 @@ from __future__ import print_function
import re
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -107,3 +110,29 @@ class DatasetTestBase(test.TestCase):
with self.assertRaisesRegexp(exception_class,
re.escape(expected_message)):
self.evaluate(next2())
+
+ def structuredDataset(self, structure, shape=None, dtype=dtypes.int64):
+ """Returns a singleton dataset with the given structure."""
+ if shape is None:
+ shape = []
+ if structure is None:
+ return dataset_ops.Dataset.from_tensors(
+ array_ops.zeros(shape, dtype=dtype))
+ else:
+ return dataset_ops.Dataset.zip(
+ tuple([
+ self.structuredDataset(substructure, shape, dtype)
+ for substructure in structure
+ ]))
+
+ def structuredElement(self, structure, shape=None, dtype=dtypes.int64):
+ """Returns an element with the given structure."""
+ if shape is None:
+ shape = []
+ if structure is None:
+ return array_ops.zeros(shape, dtype=dtype)
+ else:
+ return tuple([
+ self.structuredElement(substructure, shape, dtype)
+ for substructure in structure
+ ])