aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2018-09-23 08:53:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-23 08:58:05 -0700
commitf1237459efb3a5578885b03d5b33c3fed350c348 (patch)
tree128b49a2112e773640a06b3b0578afa8cddd81f0 /tensorflow/python/data
parent646b3c237deaddddd087d39ab57130b08375c4c7 (diff)
Moving MultiDeviceIterator from contrib to core.
PiperOrigin-RevId: 214173896
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r--tensorflow/python/data/BUILD1
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD20
-rw-r--r--tensorflow/python/data/kernel_tests/multi_device_iterator_test.py190
-rw-r--r--tensorflow/python/data/ops/BUILD18
-rw-r--r--tensorflow/python/data/ops/multi_device_iterator_ops.py213
5 files changed, 442 insertions, 0 deletions
diff --git a/tensorflow/python/data/BUILD b/tensorflow/python/data/BUILD
index 3e08c1587e..138141f4fc 100644
--- a/tensorflow/python/data/BUILD
+++ b/tensorflow/python/data/BUILD
@@ -12,6 +12,7 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:multi_device_iterator_ops",
"//tensorflow/python/data/ops:readers",
],
)
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 17d4fec662..f97116cadd 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -408,6 +408,26 @@ cuda_py_test(
],
)
+cuda_py_test(
+ name = "multi_device_iterator_test",
+ size = "small",
+ srcs = ["multi_device_iterator_test.py"],
+ additional_deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:multi_device_iterator_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
+ ],
+ tags = [
+ "no_windows_gpu",
+ ],
+)
+
tf_py_test(
name = "window_dataset_op_test",
size = "small",
diff --git a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
new file mode 100644
index 0000000000..056664b83b
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
@@ -0,0 +1,190 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""MultiDeviceIterator tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import multi_device_iterator_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class MultiDeviceIteratorTest(test.TestCase):
+
+ def testNoGetNext(self):
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2"])
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+
+ def testBasic(self):
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2"])
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 10, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+ def testOneOnSameDevice(self):
+ with ops.device("/cpu:0"):
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:0", "/cpu:1"])
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 10, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+ def testRepeatDevices(self):
+ with ops.device("/cpu:0"):
+ dataset = dataset_ops.Dataset.range(20)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2", "/cpu:1", "/cpu:2"])
+ elements = multi_device_iterator.get_next()
+ elem_on_1, elem_on_2, elem_on_3, elem_on_4 = elements
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 20, 4):
+ self.assertEqual(i, sess.run(elem_on_1))
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ self.assertEqual(i + 2, sess.run(elem_on_3))
+ self.assertEqual(i + 3, sess.run(elem_on_4))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+ sess.run(elem_on_3)
+ sess.run(elem_on_4)
+
+ def testNotFullyDivisible(self):
+ dataset = dataset_ops.Dataset.range(9)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2"])
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 8, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ self.assertEqual(8, sess.run(elem_on_1))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+ def testUneven(self):
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2"], max_buffer_size=4)
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 10, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ for i in range(0, 10, 2):
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+ def testMultipleInitializations(self):
+ with ops.device("/cpu:0"):
+ epoch = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset1 = dataset_ops.Dataset.from_tensors(epoch).repeat(1000)
+ dataset2 = dataset_ops.Dataset.range(1000)
+ dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4)
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+ init_op = multi_device_iterator.initializer
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ for i in range(1000):
+ sess.run(init_op, feed_dict={epoch: i})
+ self.assertEqual([(i, 0), (i, 1)], sess.run([elem_on_1, elem_on_2]))
+
+ def testBasicGpu(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/gpu:0"])
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 10, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+ def testUnevenGpu(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/gpu:0"], max_buffer_size=4)
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 10, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ for i in range(0, 10, 2):
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD
index 57517afae8..9dffc38820 100644
--- a/tensorflow/python/data/ops/BUILD
+++ b/tensorflow/python/data/ops/BUILD
@@ -19,6 +19,7 @@ py_library(
"//tensorflow/python:math_ops",
"//tensorflow/python:random_seed",
"//tensorflow/python:script_ops",
+ "//tensorflow/python:smart_cond",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
@@ -83,3 +84,20 @@ py_library(
"//tensorflow/python/data/util:sparse",
],
)
+
+py_library(
+ name = "multi_device_iterator_ops",
+ srcs = ["multi_device_iterator_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:functional_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
diff --git a/tensorflow/python/data/ops/multi_device_iterator_ops.py b/tensorflow/python/data/ops/multi_device_iterator_ops.py
new file mode 100644
index 0000000000..84e8abbd83
--- /dev/null
+++ b/tensorflow/python/data/ops/multi_device_iterator_ops.py
@@ -0,0 +1,213 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrapper for prefetching_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gen_dataset_ops
+
+
+class _PerDeviceGenerator(dataset_ops.Dataset):
+ """A `dummy` generator dataset."""
+
+ def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
+ source_device, target_device, output_shapes, output_types,
+ output_classes):
+ self._target_device = target_device
+ self._output_types = output_types
+ self._output_shapes = output_shapes
+ self._output_classes = output_classes
+ self._flat_output_shapes = nest.flatten(
+ sparse.as_dense_shapes(self._output_shapes, self._output_classes))
+ self._flat_output_types = nest.flatten(
+ sparse.as_dense_types(self._output_types, self._output_classes))
+
+ multi_device_iterator_string_handle = (
+ gen_dataset_ops.multi_device_iterator_to_string_handle(
+ multi_device_iterator_resource))
+
+ @function.Defun()
+ def _init_func():
+ return multi_device_iterator_string_handle
+
+ @function.Defun()
+ def _remote_init_func():
+ return functional_ops.remote_call(
+ target=source_device,
+ args=_init_func.captured_inputs,
+ Tout=[dtypes.string],
+ f=_init_func)
+
+ self._init_func = _remote_init_func
+ self._init_captured_args = _remote_init_func.captured_inputs
+
+ @function.Defun(dtypes.string)
+ def _next_func(string_handle):
+ multi_device_iterator = (
+ gen_dataset_ops.multi_device_iterator_from_string_handle(
+ string_handle=string_handle,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes))
+ return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
+ multi_device_iterator=multi_device_iterator,
+ shard_num=shard_num,
+ incarnation_id=incarnation_id,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+
+ @function.Defun(dtypes.string)
+ def _remote_next_func(string_handle):
+ return functional_ops.remote_call(
+ target=source_device,
+ args=[string_handle] + _next_func.captured_inputs,
+ Tout=self._flat_output_types,
+ f=_next_func)
+
+ self._next_func = _remote_next_func
+ self._next_captured_args = _remote_next_func.captured_inputs
+
+ @function.Defun(dtypes.string)
+ def _finalize_func(unused_string_handle):
+ return array_ops.constant(0, dtypes.int64)
+
+ @function.Defun(dtypes.string)
+ def _remote_finalize_func(string_handle):
+ return functional_ops.remote_call(
+ target=source_device,
+ args=[string_handle] + _finalize_func.captured_inputs,
+ Tout=[dtypes.int64],
+ f=_finalize_func)
+
+ self._finalize_func = _remote_finalize_func
+ self._finalize_captured_args = _remote_finalize_func.captured_inputs
+
+ def _as_variant_tensor(self):
+ with ops.device(self._target_device):
+ return gen_dataset_ops.generator_dataset(
+ self._init_captured_args,
+ self._next_captured_args,
+ self._finalize_captured_args,
+ init_func=self._init_func,
+ next_func=self._next_func,
+ finalize_func=self._finalize_func,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+
+class MultiDeviceIterator(object):
+ """An iterator over multiple devices."""
+
+ def __init__(self,
+ dataset,
+ devices,
+ max_buffer_size=1,
+ prefetch_buffer_size=1,
+ source_device="/cpu:0"):
+ """Constructs a MultiDeviceIterator.
+
+ Args:
+ dataset: The input dataset to be iterated over.
+ devices: The list of devices to fetch data to.
+ max_buffer_size: Maximum size of the host side per device buffer to keep.
+ prefetch_buffer_size: if > 1, then we setup a buffer on each device
+ to prefetch into.
+ source_device: The host device to place the `dataset` on.
+ """
+ self._dataset = dataset
+ self._devices = devices
+ self._source_device = source_device
+ self._source_device_tensor = ops.convert_to_tensor(source_device)
+
+ self._flat_output_shapes = nest.flatten(
+ sparse.as_dense_shapes(self._dataset.output_shapes,
+ self._dataset.output_classes))
+ self._flat_output_types = nest.flatten(
+ sparse.as_dense_types(self._dataset.output_types,
+ self._dataset.output_classes))
+
+ # Create the MultiDeviceIterator.
+ with ops.device(self._source_device):
+ self._multi_device_iterator_resource = (
+ gen_dataset_ops.multi_device_iterator(
+ devices=self._devices,
+ shared_name="",
+ container="",
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes))
+
+ # The incarnation ID is used to ensure consistency between the per-device
+ # iterators and the multi-device iterator.
+ self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
+ self._dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._multi_device_iterator_resource,
+ max_buffer_size=max_buffer_size)
+
+ # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
+ # initialize the device side of the pipeline. This would allow the
+ # MultiDeviceIterator to choose, for example, to move some transformations
+ # into the device side from its input. It might be useful in rewriting.
+ # Create the per device iterators.
+ self._device_iterators = []
+ i = 0
+ for device in self._devices:
+ ds = _PerDeviceGenerator(
+ i, self._multi_device_iterator_resource, self._incarnation_id,
+ self._source_device_tensor, device, self._dataset.output_shapes,
+ self._dataset.output_types, self._dataset.output_classes)
+ if prefetch_buffer_size > 0:
+ ds = ds.prefetch(prefetch_buffer_size)
+ with ops.device(device):
+ self._device_iterators.append(ds.make_initializable_iterator())
+ i += 1
+
+ device_iterator_initializers = [
+ iterator.initializer for iterator in self._device_iterators
+ ]
+ self._initializer = control_flow_ops.group(*device_iterator_initializers)
+
+ def get_next(self):
+ result = []
+ i = 0
+ for device in self._devices:
+ with ops.device(device):
+ result.append(self._device_iterators[i].get_next())
+ i += 1
+ return result
+
+ @property
+ def initializer(self):
+ return self._initializer