diff options
author | Rohan Jain <rohanj@google.com> | 2018-09-28 13:50:12 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-28 14:32:29 -0700 |
commit | 1724d155f00b49bc817189247cbfb0df2092a9da (patch) | |
tree | 3f2606f84d779d8fca28a3d253c70176c4ed3fc1 /tensorflow/contrib/distribute | |
parent | 64be2ecc07c698df05d88051ec42a0409d1a9863 (diff) |
Automated rollback of commit 7f1d70d97f543d69a9f02cd6df0964f22f9278f3
PiperOrigin-RevId: 214989908
Diffstat (limited to 'tensorflow/contrib/distribute')
13 files changed, 396 insertions, 89 deletions
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index e329b964c4..422983dbef 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -22,6 +22,7 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ ":input_ops", + ":prefetching_ops_v2", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:device_util", @@ -29,7 +30,6 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:training", "//tensorflow/python:util", - "//tensorflow/python/data/ops:multi_device_iterator_ops", "//tensorflow/python/eager:context", "//tensorflow/python/training/checkpointable:base", "@six_archive//:six", @@ -648,6 +648,32 @@ cuda_py_test( ) py_library( + name = "prefetching_ops_v2", + srcs = ["prefetching_ops_v2.py"], + deps = [ + "//tensorflow/contrib/data/python/ops:prefetching_ops", + "//tensorflow/python:experimental_dataset_ops_gen", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +cuda_py_test( + name = "prefetching_ops_v2_test", + srcs = ["prefetching_ops_v2_test.py"], + additional_deps = [ + ":prefetching_ops_v2", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + ], +) + +py_library( name = "input_ops", srcs = ["input_ops.py"], visibility = ["//tensorflow:internal"], diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py index f7773aff4f..8163494c8e 100644 --- a/tensorflow/contrib/distribute/python/metrics_v1_test.py +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -86,11 +86,10 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn): with ops.Graph().as_default(), distribution.scope(): iterator = distribution.distribute_dataset( - dataset_fn).make_initializable_iterator() + dataset_fn).make_one_shot_iterator() value, update = distribution.call_for_each_tower( metric_fn, iterator.get_next()) update = distribution.group(update) - self.evaluate(iterator.initializer) self.evaluate(variables.local_variables_initializer()) # TODO(josh11b): Once we switch to using a global batch size for input, # replace "distribution.num_towers" with "1". diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index d082d5c419..ba147e7824 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -41,14 +41,6 @@ from tensorflow.python.ops.losses import losses_impl class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): - def _get_iterator(self, ds): - if context.executing_eagerly(): - iterator = ds.make_one_shot_iterator() - else: - iterator = ds.make_initializable_iterator() - self.evaluate(iterator.initializer) - return iterator - @combinations.generate( combinations.times( combinations.distributions_and_v1_optimizers(), @@ -70,7 +62,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): distribution.call_for_each_tower( model_fn, *inputs, run_concurrently=layer.built)) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() def run_step(): return distribution.run_steps_on_dataset( @@ -106,7 +99,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() def run_step(): return distribution.group( @@ -165,7 +159,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): distribution.call_for_each_tower( model_fn, *inputs, run_concurrently=layer.built)) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() def run_step(): return distribution.run_steps_on_dataset( @@ -249,7 +244,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS) return control_flow_ops.group(fetches) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() def run_step(): return distribution.run_steps_on_dataset( @@ -342,7 +338,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): distribution.call_for_each_tower( model_fn, x, y, run_concurrently=False)) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() def run_step(): return distribution.run_steps_on_dataset( @@ -435,7 +432,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): output=loss) return distribution.group(train_op) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() def run_step(): initial_loss = lambda: constant_op.constant(1e7) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 93d42e09a2..4d7516063c 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -484,8 +484,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): self._prefetch_on_device, self._auto_shard_dataset) else: return values.PerDeviceDataset( - self._call_dataset_fn(dataset_fn), - self._devices, + self._call_dataset_fn(dataset_fn), self._devices, self._prefetch_on_device) # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 04c712ce1d..f51e543624 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -300,15 +300,9 @@ class MirroredStrategyVariableCreationTest(test.TestCase): dist = mirrored_strategy.MirroredStrategy( ["/device:GPU:0", "/device:CPU:0"]) - ds = dist.distribute_dataset( - lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)) - if context.executing_eagerly(): - iterator = ds.make_one_shot_iterator() - else: - iterator = ds.make_initializable_iterator() - self.evaluate([iterator.initializer]) - - features = iterator.get_next() + features = dist.distribute_dataset( + lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10) + ).make_one_shot_iterator().get_next() with dist.scope(): result = dist.call_for_each_tower( diff --git a/tensorflow/contrib/distribute/python/monitor.py b/tensorflow/contrib/distribute/python/monitor.py index 17b7ab74f6..7644acedc9 100644 --- a/tensorflow/contrib/distribute/python/monitor.py +++ b/tensorflow/contrib/distribute/python/monitor.py @@ -51,7 +51,6 @@ class Monitor(object): else: if session is None: raise ValueError("Should provide a `session` in Graph mode.") - session.run(step_callable._iterator.initializer) # pylint: disable=protected-access self._run_step = session.make_callable(step_callable()) session.run(variables.global_variables_initializer()) diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py index 3064433129..6e9ba37a19 100644 --- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -42,11 +42,8 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) - ds = distribution.distribute_dataset(dataset_fn) - if context.executing_eagerly(): - iterator = ds.make_one_shot_iterator() - else: - iterator = ds.make_initializable_iterator() + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() def run_step(): return control_flow_ops.group(distribution.unwrap( @@ -55,7 +52,6 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): if not context.executing_eagerly(): with self.cached_session() as sess: - sess.run(iterator.initializer) run_step = sess.make_callable(run_step()) self.evaluate(variables.global_variables_initializer()) diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py new file mode 100644 index 0000000000..8d949943b7 --- /dev/null +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py @@ -0,0 +1,232 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Extension of prefetching_ops to support more than one device.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import warnings + +from tensorflow.contrib.data.python.ops import prefetching_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.util import nest as data_nest +from tensorflow.python.data.util import sparse +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops +from tensorflow.python.util import nest + + +# pylint: disable=protected-access +class _PrefetchToDeviceIterator(object): + """A replacement for `tf.data.Iterator` that prefetches to another device. + + Args: + input_dataset: The input dataset. + one_shot: If true, we make a one shot iterator that's already initialized. + devices: Devices on which to prefetch. + buffer_size: Size of the prefetching buffer. + shared_name: (Optional.) If non-empty, the returned iterator will be shared + under the given name across multiple sessions that share the same devices + (e.g. when using a remote server). Only used if one_shot is False. + + Returns: + An Iterator type object. + """ + + def __init__(self, + input_dataset, + one_shot, + devices, + buffer_size, + shared_name=None): + self._input_dataset = input_dataset + self._get_next_call_count = 0 + self._one_shot = one_shot + if shared_name is None: + shared_name = "" + self._devices = devices + + if self._one_shot: + self._input_iterator = input_dataset.make_one_shot_iterator() + else: + self._input_iterator = iterator_ops.Iterator.from_structure( + self._input_dataset.output_types, self._input_dataset.output_shapes, + shared_name, self._input_dataset.output_classes) + input_iterator_handle = self._input_iterator.string_handle() + + @function.Defun(dtypes.string) + def _prefetch_fn(handle): + """Prefetches one element from `input_iterator`.""" + remote_iterator = iterator_ops.Iterator.from_string_handle( + handle, self._input_iterator.output_types, + self._input_iterator.output_shapes, + self._input_iterator.output_classes) + ret = remote_iterator.get_next() + return nest.flatten(sparse.serialize_sparse_tensors(ret)) + + target_device = ged_ops.experimental_iterator_get_device( + self._input_iterator._iterator_resource) + self._buffering_resources = [] + for device in nest.flatten(self._devices): + with ops.device(device): + buffer_resource_handle = prefetching_ops.function_buffering_resource( + f=_prefetch_fn, + output_types=data_nest.flatten( + sparse.as_dense_types(self._input_dataset.output_types, + self._input_dataset.output_classes)), + target_device=target_device, + string_arg=input_iterator_handle, + buffer_size=buffer_size, + shared_name=shared_name) + self._buffering_resources.append(buffer_resource_handle) + + if not self._one_shot: + reset_ops = [] + for buffer_resource in self._buffering_resources: + reset_ops.append( + ged_ops.experimental_function_buffering_resource_reset( + buffer_resource)) + with ops.control_dependencies(reset_ops): + self._initializer = self._input_iterator.make_initializer( + self._input_dataset) + + def get_next(self, name=None): + """See `tf.data.Iterator.get_next`.""" + self._get_next_call_count += 1 + if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD: + warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE) + + flat_result = [] + # TODO(priyag): This will fail if the input size (typically number of + # batches) is not divisible by number of devices. + # How do we handle that more gracefully / let the user know? + for buffer_resource in self._buffering_resources: + flat_ret = ged_ops.experimental_function_buffering_resource_get_next( + buffer_resource, + output_types=data_nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + name=name) + + ret = sparse.deserialize_sparse_tensors( + data_nest.pack_sequence_as(self.output_types, flat_ret), + self.output_types, self.output_shapes, self.output_classes) + + for tensor, shape in zip( + data_nest.flatten(ret), data_nest.flatten(self.output_shapes)): + if isinstance(tensor, ops.Tensor): + tensor.set_shape(shape) + flat_result.append(ret) + + return nest.pack_sequence_as(self._devices, flat_result) + + @property + def initializer(self): + if self._one_shot: + raise NotImplementedError("Can't initialize a one_shot_iterator") + return self._initializer + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + +# pylint: enable=protected-access + + +class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset): + """A `Dataset` whose iterator prefetches elements to other device(s).""" + + def __init__(self, input_dataset, devices, buffer_size): + super(_PrefetchToDeviceDataset, self).__init__(input_dataset) + self._input_dataset = input_dataset + self._devices = devices + self._buffer_size = buffer_size if buffer_size is not None else 1 + + def make_one_shot_iterator(self): + return _PrefetchToDeviceIterator( + self._input_dataset, + one_shot=True, + devices=self._devices, + buffer_size=self._buffer_size) + + def make_initializable_iterator(self, shared_name=None): + if context.executing_eagerly(): + raise RuntimeError( + "make_initializable_iterator is not supported when eager " + "execution is enabled.") + + return _PrefetchToDeviceIterator( + self._input_dataset, + one_shot=False, + devices=self._devices, + buffer_size=self._buffer_size, + shared_name=shared_name) + + def _as_variant_tensor(self): + # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset + # transformation methods is called. + # TODO(mrry): Investigate support for chaining further transformations after + # the prefetch, including GPU support. + raise NotImplementedError("`prefetch_to_devices()` must be the last " + "transformation in a dataset pipeline.") + + # TODO(priyag): Fix the output types, shapes and classes to match the result + # of get_next (which has the additional nesting layer of devices now). + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_classes(self): + return self._input_dataset.output_classes + + +def prefetch_to_devices(devices, buffer_size=None): + """A transformation that prefetches dataset values to the given `devices`. + + NOTE: Although the transformation creates a `tf.data.Dataset`, the + transformation must be the final `Dataset` in the input pipeline. + + Args: + devices: A nested structure of devices on which to prefetch the data. It can + be a single device name, or a tuple or list of device names. + buffer_size: (Optional.) The number of elements to buffer on each device. + Defaults to an automatically chosen value. + + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. + """ + + def _apply_fn(dataset): + return _PrefetchToDeviceDataset(dataset, devices, buffer_size) + + return _apply_fn diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py new file mode 100644 index 0000000000..16799104e8 --- /dev/null +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py @@ -0,0 +1,90 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for prefetching_ops_v2.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import prefetching_ops_v2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class PrefetchingOpsV2Test(test.TestCase): + + def testPrefetchToOneDevice(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops_v2.prefetch_to_devices("/gpu:0")) + + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchToTwoDevicesInAList(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"])) + + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + output = [] + # TODO(rohanj): Modify test to go till the end of the dataset when we + # switch to MultiDeviceIterator. + with self.cached_session() as sess: + for _ in range(4): + result = sess.run(next_element) + self.assertEqual(2, len(result)) + output.extend(result) + self.assertEquals(set(range(8)), set(output)) + + def testPrefetchToTwoDevicesWithReinit(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"])) + + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + # TODO(rohanj): Modify test to go till the end of the dataset when we + # switch to MultiDeviceIterator. + with self.cached_session() as sess: + sess.run(iterator.initializer) + for _ in range(4): + sess.run(next_element) + sess.run(iterator.initializer) + for _ in range(4): + sess.run(next_element) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py index 23bf36184f..1b5a4f64e5 100644 --- a/tensorflow/contrib/distribute/python/step_fn.py +++ b/tensorflow/contrib/distribute/python/step_fn.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.python.eager import backprop -from tensorflow.python.eager import context from tensorflow.python.training import optimizer as optimizer_lib @@ -51,11 +50,7 @@ class StandardInputStep(Step): def __init__(self, dataset_fn, distribution): super(StandardInputStep, self).__init__(distribution) self._distributed_input = distribution.distribute_dataset(dataset_fn) - if context.executing_eagerly(): - self._iterator = self._distributed_input.make_one_shot_iterator() - else: - # TODO(priyag): Expose initializer via some initializer property. - self._iterator = self._distributed_input.make_initializable_iterator() + self._iterator = self._distributed_input.make_one_shot_iterator() class StandardSingleLossStep(StandardInputStep): diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py index 1ff9b9ceec..f1ada49fa3 100644 --- a/tensorflow/contrib/distribute/python/step_fn_test.py +++ b/tensorflow/contrib/distribute/python/step_fn_test.py @@ -50,7 +50,6 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase): run_step = single_loss_step else: with self.cached_session() as sess: - sess.run(single_loss_step._iterator.initializer) run_step = sess.make_callable(single_loss_step()) self.evaluate(variables.global_variables_initializer()) diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 327775a729..4955ded4d5 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -26,7 +26,7 @@ import weakref import six from tensorflow.contrib.distribute.python import input_ops -from tensorflow.python.data.ops import multi_device_iterator_ops +from tensorflow.contrib.distribute.python import prefetching_ops_v2 from tensorflow.python.eager import context from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops @@ -683,7 +683,7 @@ class PerDeviceDataIterator(object): def get_next(self, name=None): """Scatter the input across devices.""" if self._prefetch_on_device: - data_list = self._iterator.get_next() + data_list = self._iterator.get_next(name=name) index = dict(zip(self._devices, data_list)) else: batch = self._iterator.get_next(name=name) @@ -703,24 +703,21 @@ class PerDeviceDataIterator(object): class PerDeviceDataset(object): """Like `tf.data.Dataset` split devices, producing `PerDevice` data.""" - def __init__( - self, - dataset, - devices, - prefetch_on_device=None, - ): + def __init__(self, dataset, devices, prefetch_on_device=None): self._devices = devices # Default to using prefetching in graph mode, unless specified. - # TODO(rohanj): Enable prefetching in eager mode. + # TODO(priyag): Enable prefetching in eager mode. self._prefetch_on_device = prefetch_on_device if self._prefetch_on_device is None: self._prefetch_on_device = not context.executing_eagerly() assert not (self._prefetch_on_device and context.executing_eagerly()), ( "Prefetching is only supported in graph mode currently") - self._dataset = dataset - if not self._prefetch_on_device: + if self._prefetch_on_device: + self._dataset = dataset.apply( + prefetching_ops_v2.prefetch_to_devices(self._devices)) + else: # TODO(priyag): If dropping remainder is not appropriate, find another # approach to distributing the dataset when not possible to divide evenly. # Possibly not an issue when we start using PartitionedDataset. @@ -728,33 +725,15 @@ class PerDeviceDataset(object): def make_one_shot_iterator(self): """Get a one time use iterator for the distributed PerDeviceDataset.""" - # Graph mode prefetching with one shot iterator is disabled. - if not context.executing_eagerly(): - raise ValueError("Cannot create a one shot iterator. Please use " - "`make_initializable_iterator()` instead.") - # Eager mode prefetching would error out in constructor. Only remaining - # cases are non-prefetching eager / graph mode. We delegate to - # PerDeviceDataIterator to handle them. dataset_iterator = self._dataset.make_one_shot_iterator() - return PerDeviceDataIterator( - dataset_iterator, self._devices, prefetch_on_device=False) + return PerDeviceDataIterator(dataset_iterator, self._devices, + self._prefetch_on_device) def make_initializable_iterator(self): """Get an initializable iterator for the distributed PerDeviceDataset.""" - # Eager mode generates already initialized iterators. Hence we cannot create - # an initializable iterator. - if context.executing_eagerly(): - raise ValueError("Cannot create initializable iterator in Eager mode. " - "Please use `make_one_shot_iterator` instead.") - if self._prefetch_on_device: - dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator( - self._dataset, self._devices) - else: - dataset_iterator = self._dataset.make_initializable_iterator() - return PerDeviceDataIterator( - dataset_iterator, - self._devices, - prefetch_on_device=self._prefetch_on_device) + dataset_iterator = self._dataset.make_initializable_iterator() + return PerDeviceDataIterator(dataset_iterator, self._devices, + self._prefetch_on_device) class MultiWorkerDataIterator(object): @@ -837,9 +816,7 @@ class MultiWorkerDataset(object): worker_input = input_ops.auto_shard_dataset( worker_input, len(worker_device_map), i) self._datasets[worker] = PerDeviceDataset( - worker_input, - worker_devices, - prefetch_on_device=prefetch_on_device) + worker_input, worker_devices, prefetch_on_device=prefetch_on_device) def make_one_shot_iterator(self): iterators = {} diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index 002d61f46e..ae3e134333 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -349,11 +349,7 @@ class PerDeviceDatasetTest(test.TestCase): def _test_iterator_no_prefetch(self, devices, dataset, expected_values): per_device_dataset = values.PerDeviceDataset( dataset, devices, prefetch_on_device=False) - if context.executing_eagerly(): - iterator = per_device_dataset.make_one_shot_iterator() - else: - iterator = per_device_dataset.make_initializable_iterator() - self.evaluate([iterator.initializer]) + iterator = per_device_dataset.make_one_shot_iterator() for expected_value in expected_values: next_element = iterator.get_next() @@ -370,14 +366,21 @@ class PerDeviceDatasetTest(test.TestCase): if not context.executing_eagerly(): per_device_dataset = values.PerDeviceDataset( dataset, devices, prefetch_on_device=True) - iterator = per_device_dataset.make_initializable_iterator() - self.evaluate([iterator.initializer]) + iterator = per_device_dataset.make_one_shot_iterator() + # With prefetching, we cannot guarantee which input ends up on which + # device, so we verify that the complete set seen on all devices is + # correct, and equal numbers are distributed to each device. + combined_actual = [] + combined_expected = [] for expected_value in expected_values: next_element = iterator.get_next() - computed_value = self.evaluate( - [values.select_device(d, next_element) for d in devices]) - self.assertEqual(expected_value, computed_value) + combined_actual.extend( + self.evaluate( + [values.select_device(d, next_element) for d in devices])) + combined_expected.extend(expected_value) + + self.assertEqual(set(combined_expected), set(combined_actual)) with self.assertRaises(errors.OutOfRangeError): next_element = iterator.get_next() |