diff options
author | 2018-09-07 16:19:47 -0700 | |
---|---|---|
committer | 2018-09-07 16:23:46 -0700 | |
commit | 8d9e562d73d0c0fe6aa0ae70ea8c914dc1367592 (patch) | |
tree | db04bab8dfb0bf74df55301c323c10e5720e24fb /tensorflow/python/data | |
parent | 75dec44b0aa24dc18f55925afd6eb12c103e1448 (diff) |
[tf.data] Adding `use_inter_op_parallelism` attr to `ParallelMapDataset` and removing unused `graph_def_version` field
PiperOrigin-RevId: 212054031
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r-- | tensorflow/python/data/kernel_tests/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/data/kernel_tests/map_dataset_op_test.py | 32 | ||||
-rw-r--r-- | tensorflow/python/data/ops/dataset_ops.py | 10 |
3 files changed, 40 insertions, 3 deletions
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 5cd1484084..631b87a718 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -156,6 +156,7 @@ tf_py_test( size = "small", srcs = ["map_dataset_op_test.py"], additional_deps = [ + "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py index df2c9b170a..fde785be6e 100644 --- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py @@ -22,6 +22,7 @@ import threading import time import warnings +from absl.testing import parameterized import numpy as np from tensorflow.core.framework import attr_value_pb2 @@ -46,7 +47,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test -class MapDatasetTest(test.TestCase): +class MapDatasetTest(test.TestCase, parameterized.TestCase): def _buildMapDataset(self, components, count): def _map_fn(x, y, z): @@ -705,6 +706,35 @@ class MapDatasetTest(test.TestCase): with self.assertRaisesRegexp(errors.InvalidArgumentError, "BrokenConst"): sess.run(iterator.initializer) +# pylint: disable=g-long-lambda + @parameterized.named_parameters( + ("Map", lambda dataset, func: + dataset_ops.MapDataset(dataset, func, use_inter_op_parallelism=False)), + ("ParallelMap", lambda dataset, func: + dataset_ops.ParallelMapDataset(dataset, func, num_parallel_calls=1, + use_inter_op_parallelism=False)), + ) + def testNoInterOpParallelism(self, make_dataset_fn): + dataset = dataset_ops.Dataset.from_tensors(0) + + def _get_tid(): + return np.int64(threading.current_thread().ident) + + def _map_fn(_): + tids = [] + for _ in range(10): + tids.append(script_ops.py_func(_get_tid, [], dtypes.int64)) + return tids + + dataset = make_dataset_fn(dataset, _map_fn) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + tids = sess.run(get_next) + self.assertTrue(all(tids[0] == tid for tid in tids)) +# pylint: enable=g-long-lambda + class MapDatasetBenchmark(test.Benchmark): diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 2c1aa22116..c985e00dd1 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -2257,9 +2257,14 @@ class MapDataset(Dataset): class ParallelMapDataset(MapDataset): """A `Dataset` that maps a function over elements in its input in parallel.""" - def __init__(self, input_dataset, map_func, num_parallel_calls): + def __init__(self, + input_dataset, + map_func, + num_parallel_calls, + use_inter_op_parallelism=True): """See `Dataset.map()` for details.""" - super(ParallelMapDataset, self).__init__(input_dataset, map_func) + super(ParallelMapDataset, self).__init__(input_dataset, map_func, + use_inter_op_parallelism) self._num_parallel_calls = ops.convert_to_tensor( num_parallel_calls, dtype=dtypes.int32, name="num_parallel_calls") @@ -2272,6 +2277,7 @@ class ParallelMapDataset(MapDataset): self._map_func.captured_inputs, f=self._map_func, num_parallel_calls=self._num_parallel_calls, + use_inter_op_parallelism=self._use_inter_op_parallelism, **flat_structure(self)) # pylint: enable=protected-access |