aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-09-07 16:19:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-07 16:23:46 -0700
commit8d9e562d73d0c0fe6aa0ae70ea8c914dc1367592 (patch)
treedb04bab8dfb0bf74df55301c323c10e5720e24fb /tensorflow/python/data
parent75dec44b0aa24dc18f55925afd6eb12c103e1448 (diff)
[tf.data] Adding `use_inter_op_parallelism` attr to `ParallelMapDataset` and removing unused `graph_def_version` field
PiperOrigin-RevId: 212054031
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD1
-rw-r--r--tensorflow/python/data/kernel_tests/map_dataset_op_test.py32
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py10
3 files changed, 40 insertions, 3 deletions
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 5cd1484084..631b87a718 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -156,6 +156,7 @@ tf_py_test(
size = "small",
srcs = ["map_dataset_op_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
index df2c9b170a..fde785be6e 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
@@ -22,6 +22,7 @@ import threading
import time
import warnings
+from absl.testing import parameterized
import numpy as np
from tensorflow.core.framework import attr_value_pb2
@@ -46,7 +47,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
-class MapDatasetTest(test.TestCase):
+class MapDatasetTest(test.TestCase, parameterized.TestCase):
def _buildMapDataset(self, components, count):
def _map_fn(x, y, z):
@@ -705,6 +706,35 @@ class MapDatasetTest(test.TestCase):
with self.assertRaisesRegexp(errors.InvalidArgumentError, "BrokenConst"):
sess.run(iterator.initializer)
+# pylint: disable=g-long-lambda
+ @parameterized.named_parameters(
+ ("Map", lambda dataset, func:
+ dataset_ops.MapDataset(dataset, func, use_inter_op_parallelism=False)),
+ ("ParallelMap", lambda dataset, func:
+ dataset_ops.ParallelMapDataset(dataset, func, num_parallel_calls=1,
+ use_inter_op_parallelism=False)),
+ )
+ def testNoInterOpParallelism(self, make_dataset_fn):
+ dataset = dataset_ops.Dataset.from_tensors(0)
+
+ def _get_tid():
+ return np.int64(threading.current_thread().ident)
+
+ def _map_fn(_):
+ tids = []
+ for _ in range(10):
+ tids.append(script_ops.py_func(_get_tid, [], dtypes.int64))
+ return tids
+
+ dataset = make_dataset_fn(dataset, _map_fn)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ tids = sess.run(get_next)
+ self.assertTrue(all(tids[0] == tid for tid in tids))
+# pylint: enable=g-long-lambda
+
class MapDatasetBenchmark(test.Benchmark):
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 2c1aa22116..c985e00dd1 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -2257,9 +2257,14 @@ class MapDataset(Dataset):
class ParallelMapDataset(MapDataset):
"""A `Dataset` that maps a function over elements in its input in parallel."""
- def __init__(self, input_dataset, map_func, num_parallel_calls):
+ def __init__(self,
+ input_dataset,
+ map_func,
+ num_parallel_calls,
+ use_inter_op_parallelism=True):
"""See `Dataset.map()` for details."""
- super(ParallelMapDataset, self).__init__(input_dataset, map_func)
+ super(ParallelMapDataset, self).__init__(input_dataset, map_func,
+ use_inter_op_parallelism)
self._num_parallel_calls = ops.convert_to_tensor(
num_parallel_calls, dtype=dtypes.int32, name="num_parallel_calls")
@@ -2272,6 +2277,7 @@ class ParallelMapDataset(MapDataset):
self._map_func.captured_inputs,
f=self._map_func,
num_parallel_calls=self._num_parallel_calls,
+ use_inter_op_parallelism=self._use_inter_op_parallelism,
**flat_structure(self))
# pylint: enable=protected-access