Diffstat (limited to 'tensorflow/python/data/kernel_tests/map_dataset_op_test.py')
 tensorflow/python/data/kernel_tests/map_dataset_op_test.py | 80
 1 file changed, 15 insertions(+), 65 deletions(-)
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
index 6efbe31ca1..0c372ebb10 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
@@ -622,7 +622,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(init_op)
for i in range(10):
actual = sess.run(get_next)
- self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
+ self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
self.assertSparseValuesEqual(actual, _sparse(i))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -649,7 +649,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(init_op)
for i in range(10):
actual = sess.run(get_next)
- self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
+ self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
self.assertSparseValuesEqual(actual, _check(_sparse(i)).eval())
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -783,57 +783,19 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertTrue(all(tids[0] == tid for tid in tids))
# pylint: enable=g-long-lambda
- @parameterized.named_parameters(
- ("SequentialIdentity", None, lambda x: x, None),
- ("SequentialReplicate", None, lambda x: (x, x), None),
- ("SequentialSwap", (None, None), lambda x, y: (y, x), None),
- ("SequentialProject", (None, None), lambda x, y: x, None),
- ("ParallelIdentity", None, lambda x: x, 10),
- ("ParallelReplicate", None, lambda x: (x, x), 10),
- ("ParallelSwap", (None, None), lambda x, y: (y, x), 10),
- ("ParallelProject", (None, None), lambda x, y: x, 10),
- )
- def testShortCircuit(self, structure, map_fn, num_parallel_calls):
- dataset = self.structuredDataset(structure).repeat().map(
- map_fn, num_parallel_calls=num_parallel_calls)
- get_next = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- if isinstance(structure, tuple):
- expected = map_fn(*sess.run(self.structuredElement(structure)))
- else:
- expected = map_fn(sess.run(self.structuredElement(structure)))
- self.assertEqual(expected, sess.run(get_next))
-
class MapDatasetBenchmark(test.Benchmark):
def benchmarkChainOfMaps(self):
chain_lengths = [0, 1, 2, 5, 10, 20, 50]
for chain_length in chain_lengths:
- for mode in ["general", "single-threaded", "short-circuit"]:
- if mode == "general":
- map_fn = lambda x: x + 1
- use_inter_op_parallelism = True
- print_label = ""
- benchmark_label = ""
- if mode == "single-threaded":
- map_fn = lambda x: x + 1
- use_inter_op_parallelism = False
- print_label = " (single threaded mode)"
- benchmark_label = "_single_threaded"
- if mode == "short-circuit":
- map_fn = lambda x: x
- use_inter_op_parallelism = True # should not have any significance
- print_label = " (short circuit mode)"
- benchmark_label = "_short_circuit"
-
+ for use_inter_op_parallelism in [False, True]:
with ops.Graph().as_default():
dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
for _ in range(chain_length):
dataset = dataset_ops.MapDataset(
dataset,
- map_fn,
+ lambda x: x,
use_inter_op_parallelism=use_inter_op_parallelism)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
@@ -851,39 +813,25 @@ class MapDatasetBenchmark(test.Benchmark):
median_wall_time = np.median(deltas) / 100
print("Map dataset chain length%s: %d Median wall time: %f" %
- (print_label, chain_length, median_wall_time))
+ (" (single threaded mode)" if not use_inter_op_parallelism
+ else "", chain_length, median_wall_time))
self.report_benchmark(
iters=1000,
wall_time=median_wall_time,
name="benchmark_map_dataset_chain_latency_%d%s" %
- (chain_length, benchmark_label))
+ (chain_length, "_single_threaded"
+ if not use_inter_op_parallelism else ""))
def benchmarkMapFanOut(self):
fan_outs = [1, 2, 5, 10, 20, 50, 100]
for fan_out in fan_outs:
- for mode in ["general", "single-threaded", "short-circuit"]:
- if mode == "general":
- map_fn = lambda *xs: [x + 1 for x in xs]
- use_inter_op_parallelism = True
- print_label = ""
- benchmark_label = ""
- if mode == "single-threaded":
- map_fn = lambda *xs: [x + 1 for x in xs]
- use_inter_op_parallelism = False
- print_label = " (single threaded mode)"
- benchmark_label = "_single_threaded"
- if mode == "short-circuit":
- map_fn = lambda *xs: xs
- use_inter_op_parallelism = True # should not have any significance
- print_label = " (short circuit mode)"
- benchmark_label = "_short_circuit"
-
+ for use_inter_op_parallelism in [False, True]:
with ops.Graph().as_default():
dataset = dataset_ops.Dataset.from_tensors(
tuple(0 for _ in range(fan_out))).repeat(None)
dataset = dataset_ops.MapDataset(
dataset,
- map_fn,
+ lambda *xs: xs,
use_inter_op_parallelism=use_inter_op_parallelism)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
@@ -901,12 +849,14 @@ class MapDatasetBenchmark(test.Benchmark):
median_wall_time = np.median(deltas) / 100
print("Map dataset fan out%s: %d Median wall time: %f" %
- (print_label, fan_out, median_wall_time))
+ (" (single threaded mode)" if not use_inter_op_parallelism
+ else "", fan_out, median_wall_time))
self.report_benchmark(
iters=1000,
wall_time=median_wall_time,
- name="benchmark_map_dataset_fan_out_%d%s" % (fan_out,
- benchmark_label))
+ name="benchmark_map_dataset_fan_out_%d%s" %
+ (fan_out, "_single_threaded"
+ if not use_inter_op_parallelism else ""))
if __name__ == "__main__":
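
For local experimentation, the benchmark methods above can also be driven directly from a TensorFlow source checkout; tf.test.main() additionally picks them up through TensorFlow's benchmark harness when the --benchmarks regex flag is passed. A minimal sketch, assuming the module is importable from the source tree (the import path mirrors the file path in this diff):

    # Minimal sketch: call the benchmark methods directly, outside any test
    # runner; each method prints median wall times and calls
    # report_benchmark() itself.
    from tensorflow.python.data.kernel_tests import map_dataset_op_test

    bench = map_dataset_op_test.MapDatasetBenchmark()
    bench.benchmarkChainOfMaps()  # sweeps chain lengths 0..50
    bench.benchmarkMapFanOut()    # sweeps fan-outs 1..100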