diff options
author | Rohan Jain <rohanj@google.com> | 2018-09-28 13:50:12 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-28 14:32:29 -0700 |
commit | 1724d155f00b49bc817189247cbfb0df2092a9da (patch) | |
tree | 3f2606f84d779d8fca28a3d253c70176c4ed3fc1 /tensorflow/contrib/distribute/python/values.py | |
parent | 64be2ecc07c698df05d88051ec42a0409d1a9863 (diff) |
Automated rollback of commit 7f1d70d97f543d69a9f02cd6df0964f22f9278f3
PiperOrigin-RevId: 214989908
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 51 |
1 files changed, 14 insertions, 37 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 327775a729..4955ded4d5 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -26,7 +26,7 @@ import weakref import six from tensorflow.contrib.distribute.python import input_ops -from tensorflow.python.data.ops import multi_device_iterator_ops +from tensorflow.contrib.distribute.python import prefetching_ops_v2 from tensorflow.python.eager import context from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops @@ -683,7 +683,7 @@ class PerDeviceDataIterator(object): def get_next(self, name=None): """Scatter the input across devices.""" if self._prefetch_on_device: - data_list = self._iterator.get_next() + data_list = self._iterator.get_next(name=name) index = dict(zip(self._devices, data_list)) else: batch = self._iterator.get_next(name=name) @@ -703,24 +703,21 @@ class PerDeviceDataIterator(object): class PerDeviceDataset(object): """Like `tf.data.Dataset` split devices, producing `PerDevice` data.""" - def __init__( - self, - dataset, - devices, - prefetch_on_device=None, - ): + def __init__(self, dataset, devices, prefetch_on_device=None): self._devices = devices # Default to using prefetching in graph mode, unless specified. - # TODO(rohanj): Enable prefetching in eager mode. + # TODO(priyag): Enable prefetching in eager mode. self._prefetch_on_device = prefetch_on_device if self._prefetch_on_device is None: self._prefetch_on_device = not context.executing_eagerly() assert not (self._prefetch_on_device and context.executing_eagerly()), ( "Prefetching is only supported in graph mode currently") - self._dataset = dataset - if not self._prefetch_on_device: + if self._prefetch_on_device: + self._dataset = dataset.apply( + prefetching_ops_v2.prefetch_to_devices(self._devices)) + else: # TODO(priyag): If dropping remainder is not appropriate, find another # approach to distributing the dataset when not possible to divide evenly. # Possibly not an issue when we start using PartitionedDataset. @@ -728,33 +725,15 @@ class PerDeviceDataset(object): def make_one_shot_iterator(self): """Get a one time use iterator for the distributed PerDeviceDataset.""" - # Graph mode prefetching with one shot iterator is disabled. - if not context.executing_eagerly(): - raise ValueError("Cannot create a one shot iterator. Please use " - "`make_initializable_iterator()` instead.") - # Eager mode prefetching would error out in constructor. Only remaining - # cases are non-prefetching eager / graph mode. We delegate to - # PerDeviceDataIterator to handle them. dataset_iterator = self._dataset.make_one_shot_iterator() - return PerDeviceDataIterator( - dataset_iterator, self._devices, prefetch_on_device=False) + return PerDeviceDataIterator(dataset_iterator, self._devices, + self._prefetch_on_device) def make_initializable_iterator(self): """Get an initializable iterator for the distributed PerDeviceDataset.""" - # Eager mode generates already initialized iterators. Hence we cannot create - # an initializable iterator. - if context.executing_eagerly(): - raise ValueError("Cannot create initializable iterator in Eager mode. " - "Please use `make_one_shot_iterator` instead.") - if self._prefetch_on_device: - dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator( - self._dataset, self._devices) - else: - dataset_iterator = self._dataset.make_initializable_iterator() - return PerDeviceDataIterator( - dataset_iterator, - self._devices, - prefetch_on_device=self._prefetch_on_device) + dataset_iterator = self._dataset.make_initializable_iterator() + return PerDeviceDataIterator(dataset_iterator, self._devices, + self._prefetch_on_device) class MultiWorkerDataIterator(object): @@ -837,9 +816,7 @@ class MultiWorkerDataset(object): worker_input = input_ops.auto_shard_dataset( worker_input, len(worker_device_map), i) self._datasets[worker] = PerDeviceDataset( - worker_input, - worker_devices, - prefetch_on_device=prefetch_on_device) + worker_input, worker_devices, prefetch_on_device=prefetch_on_device) def make_one_shot_iterator(self): iterators = {} |