diff options
author | Rohan Jain <rohanj@google.com> | 2018-09-23 08:53:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-23 08:58:05 -0700 |
commit | f1237459efb3a5578885b03d5b33c3fed350c348 (patch) | |
tree | 128b49a2112e773640a06b3b0578afa8cddd81f0 /tensorflow/python/data | |
parent | 646b3c237deaddddd087d39ab57130b08375c4c7 (diff) |
Moving MultiDeviceIterator from contrib to core.
PiperOrigin-RevId: 214173896
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r-- | tensorflow/python/data/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/data/kernel_tests/BUILD | 20 | ||||
-rw-r--r-- | tensorflow/python/data/kernel_tests/multi_device_iterator_test.py | 190 | ||||
-rw-r--r-- | tensorflow/python/data/ops/BUILD | 18 | ||||
-rw-r--r-- | tensorflow/python/data/ops/multi_device_iterator_ops.py | 213 |
5 files changed, 442 insertions, 0 deletions
diff --git a/tensorflow/python/data/BUILD b/tensorflow/python/data/BUILD index 3e08c1587e..138141f4fc 100644 --- a/tensorflow/python/data/BUILD +++ b/tensorflow/python/data/BUILD @@ -12,6 +12,7 @@ py_library( "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/ops:multi_device_iterator_ops", "//tensorflow/python/data/ops:readers", ], ) diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 17d4fec662..f97116cadd 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -408,6 +408,26 @@ cuda_py_test( ], ) +cuda_py_test( + name = "multi_device_iterator_test", + size = "small", + srcs = ["multi_device_iterator_test.py"], + additional_deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:multi_device_iterator_ops", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_test_lib", + ], + tags = [ + "no_windows_gpu", + ], +) + tf_py_test( name = "window_dataset_op_test", size = "small", diff --git a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py new file mode 100644 index 0000000000..056664b83b --- /dev/null +++ b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py @@ -0,0 +1,190 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""MultiDeviceIterator tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import multi_device_iterator_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class MultiDeviceIteratorTest(test.TestCase): + + def testNoGetNext(self): + dataset = dataset_ops.Dataset.range(10) + multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/cpu:2"]) + + config = config_pb2.ConfigProto(device_count={"CPU": 3}) + with self.test_session(config=config) as sess: + sess.run(multi_device_iterator.initializer) + + def testBasic(self): + dataset = dataset_ops.Dataset.range(10) + multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/cpu:2"]) + elem_on_1, elem_on_2 = multi_device_iterator.get_next() + + config = config_pb2.ConfigProto(device_count={"CPU": 3}) + with self.test_session(config=config) as sess: + sess.run(multi_device_iterator.initializer) + for i in range(0, 10, 2): + self.assertEqual(i, sess.run(elem_on_1)) + self.assertEqual(i + 1, sess.run(elem_on_2)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elem_on_1) + sess.run(elem_on_2) + + def testOneOnSameDevice(self): + with ops.device("/cpu:0"): + dataset = dataset_ops.Dataset.range(10) + multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator( + dataset, ["/cpu:0", "/cpu:1"]) + elem_on_1, elem_on_2 = multi_device_iterator.get_next() + + config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=config) as sess: + sess.run(multi_device_iterator.initializer) + for i in range(0, 10, 2): + self.assertEqual(i, sess.run(elem_on_1)) + self.assertEqual(i + 1, sess.run(elem_on_2)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elem_on_1) + sess.run(elem_on_2) + + def testRepeatDevices(self): + with ops.device("/cpu:0"): + dataset = dataset_ops.Dataset.range(20) + multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/cpu:2", "/cpu:1", "/cpu:2"]) + elements = multi_device_iterator.get_next() + elem_on_1, elem_on_2, elem_on_3, elem_on_4 = elements + + config = config_pb2.ConfigProto(device_count={"CPU": 3}) + with self.test_session(config=config) as sess: + sess.run(multi_device_iterator.initializer) + for i in range(0, 20, 4): + self.assertEqual(i, sess.run(elem_on_1)) + self.assertEqual(i + 1, sess.run(elem_on_2)) + self.assertEqual(i + 2, sess.run(elem_on_3)) + self.assertEqual(i + 3, sess.run(elem_on_4)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elem_on_1) + sess.run(elem_on_2) + sess.run(elem_on_3) + sess.run(elem_on_4) + + def testNotFullyDivisible(self): + dataset = dataset_ops.Dataset.range(9) + multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/cpu:2"]) + elem_on_1, elem_on_2 = multi_device_iterator.get_next() + + config = config_pb2.ConfigProto(device_count={"CPU": 3}) + with self.test_session(config=config) as sess: + sess.run(multi_device_iterator.initializer) + for i in range(0, 8, 2): + self.assertEqual(i, sess.run(elem_on_1)) + self.assertEqual(i + 1, sess.run(elem_on_2)) + self.assertEqual(8, sess.run(elem_on_1)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elem_on_1) + sess.run(elem_on_2) + + def testUneven(self): + dataset = dataset_ops.Dataset.range(10) + multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/cpu:2"], max_buffer_size=4) + elem_on_1, elem_on_2 = multi_device_iterator.get_next() + + config = config_pb2.ConfigProto(device_count={"CPU": 3}) + with self.test_session(config=config) as sess: + sess.run(multi_device_iterator.initializer) + for i in range(0, 10, 2): + self.assertEqual(i, sess.run(elem_on_1)) + for i in range(0, 10, 2): + self.assertEqual(i + 1, sess.run(elem_on_2)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elem_on_1) + sess.run(elem_on_2) + + def testMultipleInitializations(self): + with ops.device("/cpu:0"): + epoch = array_ops.placeholder(dtypes.int64, shape=[]) + dataset1 = dataset_ops.Dataset.from_tensors(epoch).repeat(1000) + dataset2 = dataset_ops.Dataset.range(1000) + dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) + multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4) + elem_on_1, elem_on_2 = multi_device_iterator.get_next() + init_op = multi_device_iterator.initializer + + config = config_pb2.ConfigProto(device_count={"CPU": 3}) + with self.test_session(config=config) as sess: + for i in range(1000): + sess.run(init_op, feed_dict={epoch: i}) + self.assertEqual([(i, 0), (i, 1)], sess.run([elem_on_1, elem_on_2])) + + def testBasicGpu(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + dataset = dataset_ops.Dataset.range(10) + multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/gpu:0"]) + elem_on_1, elem_on_2 = multi_device_iterator.get_next() + + config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1}) + with self.test_session(config=config) as sess: + sess.run(multi_device_iterator.initializer) + for i in range(0, 10, 2): + self.assertEqual(i, sess.run(elem_on_1)) + self.assertEqual(i + 1, sess.run(elem_on_2)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elem_on_1) + sess.run(elem_on_2) + + def testUnevenGpu(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + dataset = dataset_ops.Dataset.range(10) + multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/gpu:0"], max_buffer_size=4) + elem_on_1, elem_on_2 = multi_device_iterator.get_next() + + config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1}) + with self.test_session(config=config) as sess: + sess.run(multi_device_iterator.initializer) + for i in range(0, 10, 2): + self.assertEqual(i, sess.run(elem_on_1)) + for i in range(0, 10, 2): + self.assertEqual(i + 1, sess.run(elem_on_2)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elem_on_1) + sess.run(elem_on_2) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD index 57517afae8..9dffc38820 100644 --- a/tensorflow/python/data/ops/BUILD +++ b/tensorflow/python/data/ops/BUILD @@ -19,6 +19,7 @@ py_library( "//tensorflow/python:math_ops", "//tensorflow/python:random_seed", "//tensorflow/python:script_ops", + "//tensorflow/python:smart_cond", "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", @@ -83,3 +84,20 @@ py_library( "//tensorflow/python/data/util:sparse", ], ) + +py_library( + name = "multi_device_iterator_ops", + srcs = ["multi_device_iterator_ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:functional_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) diff --git a/tensorflow/python/data/ops/multi_device_iterator_ops.py b/tensorflow/python/data/ops/multi_device_iterator_ops.py new file mode 100644 index 0000000000..84e8abbd83 --- /dev/null +++ b/tensorflow/python/data/ops/multi_device_iterator_ops.py @@ -0,0 +1,213 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python wrapper for prefetching_ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import gen_dataset_ops + + +class _PerDeviceGenerator(dataset_ops.Dataset): + """A `dummy` generator dataset.""" + + def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id, + source_device, target_device, output_shapes, output_types, + output_classes): + self._target_device = target_device + self._output_types = output_types + self._output_shapes = output_shapes + self._output_classes = output_classes + self._flat_output_shapes = nest.flatten( + sparse.as_dense_shapes(self._output_shapes, self._output_classes)) + self._flat_output_types = nest.flatten( + sparse.as_dense_types(self._output_types, self._output_classes)) + + multi_device_iterator_string_handle = ( + gen_dataset_ops.multi_device_iterator_to_string_handle( + multi_device_iterator_resource)) + + @function.Defun() + def _init_func(): + return multi_device_iterator_string_handle + + @function.Defun() + def _remote_init_func(): + return functional_ops.remote_call( + target=source_device, + args=_init_func.captured_inputs, + Tout=[dtypes.string], + f=_init_func) + + self._init_func = _remote_init_func + self._init_captured_args = _remote_init_func.captured_inputs + + @function.Defun(dtypes.string) + def _next_func(string_handle): + multi_device_iterator = ( + gen_dataset_ops.multi_device_iterator_from_string_handle( + string_handle=string_handle, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes)) + return gen_dataset_ops.multi_device_iterator_get_next_from_shard( + multi_device_iterator=multi_device_iterator, + shard_num=shard_num, + incarnation_id=incarnation_id, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + + @function.Defun(dtypes.string) + def _remote_next_func(string_handle): + return functional_ops.remote_call( + target=source_device, + args=[string_handle] + _next_func.captured_inputs, + Tout=self._flat_output_types, + f=_next_func) + + self._next_func = _remote_next_func + self._next_captured_args = _remote_next_func.captured_inputs + + @function.Defun(dtypes.string) + def _finalize_func(unused_string_handle): + return array_ops.constant(0, dtypes.int64) + + @function.Defun(dtypes.string) + def _remote_finalize_func(string_handle): + return functional_ops.remote_call( + target=source_device, + args=[string_handle] + _finalize_func.captured_inputs, + Tout=[dtypes.int64], + f=_finalize_func) + + self._finalize_func = _remote_finalize_func + self._finalize_captured_args = _remote_finalize_func.captured_inputs + + def _as_variant_tensor(self): + with ops.device(self._target_device): + return gen_dataset_ops.generator_dataset( + self._init_captured_args, + self._next_captured_args, + self._finalize_captured_args, + init_func=self._init_func, + next_func=self._next_func, + finalize_func=self._finalize_func, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + + @property + def output_types(self): + return self._output_types + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_classes(self): + return self._output_classes + + +class MultiDeviceIterator(object): + """An iterator over multiple devices.""" + + def __init__(self, + dataset, + devices, + max_buffer_size=1, + prefetch_buffer_size=1, + source_device="/cpu:0"): + """Constructs a MultiDeviceIterator. + + Args: + dataset: The input dataset to be iterated over. + devices: The list of devices to fetch data to. + max_buffer_size: Maximum size of the host side per device buffer to keep. + prefetch_buffer_size: if > 1, then we setup a buffer on each device + to prefetch into. + source_device: The host device to place the `dataset` on. + """ + self._dataset = dataset + self._devices = devices + self._source_device = source_device + self._source_device_tensor = ops.convert_to_tensor(source_device) + + self._flat_output_shapes = nest.flatten( + sparse.as_dense_shapes(self._dataset.output_shapes, + self._dataset.output_classes)) + self._flat_output_types = nest.flatten( + sparse.as_dense_types(self._dataset.output_types, + self._dataset.output_classes)) + + # Create the MultiDeviceIterator. + with ops.device(self._source_device): + self._multi_device_iterator_resource = ( + gen_dataset_ops.multi_device_iterator( + devices=self._devices, + shared_name="", + container="", + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes)) + + # The incarnation ID is used to ensure consistency between the per-device + # iterators and the multi-device iterator. + self._incarnation_id = gen_dataset_ops.multi_device_iterator_init( + self._dataset._as_variant_tensor(), # pylint: disable=protected-access + self._multi_device_iterator_resource, + max_buffer_size=max_buffer_size) + + # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to + # initialize the device side of the pipeline. This would allow the + # MultiDeviceIterator to choose, for example, to move some transformations + # into the device side from its input. It might be useful in rewriting. + # Create the per device iterators. + self._device_iterators = [] + i = 0 + for device in self._devices: + ds = _PerDeviceGenerator( + i, self._multi_device_iterator_resource, self._incarnation_id, + self._source_device_tensor, device, self._dataset.output_shapes, + self._dataset.output_types, self._dataset.output_classes) + if prefetch_buffer_size > 0: + ds = ds.prefetch(prefetch_buffer_size) + with ops.device(device): + self._device_iterators.append(ds.make_initializable_iterator()) + i += 1 + + device_iterator_initializers = [ + iterator.initializer for iterator in self._device_iterators + ] + self._initializer = control_flow_ops.group(*device_iterator_initializers) + + def get_next(self): + result = [] + i = 0 + for device in self._devices: + with ops.device(device): + result.append(self._device_iterators[i].get_next()) + i += 1 + return result + + @property + def initializer(self): + return self._initializer |