diff options
author | Rohan Jain <rohanj@google.com> | 2018-09-25 20:16:49 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-25 20:22:00 -0700 |
commit | 7f1d70d97f543d69a9f02cd6df0964f22f9278f3 (patch) | |
tree | 29612b6cd40203beba4f2b9689eef27a1f8da8d7 /tensorflow/contrib/distribute/python/values.py | |
parent | 3f4b8c138165cc9deb0ed931c5a6bb3d8ab556f0 (diff) |
Switching Distribution strategies to use MultiDeviceIterator. Currently only supported in Graph mode using initializable iterators. In a subsequent change, we'll add support for Eager mode as well.
This removes prefetching_ops_v2 code.
PiperOrigin-RevId: 214546754
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 50 |
1 file changed, 38 insertions, 12 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index fafa6384a1..a0cd029f51 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -26,7 +26,7 @@ import weakref import six from tensorflow.contrib.distribute.python import input_ops -from tensorflow.contrib.distribute.python import prefetching_ops_v2 +from tensorflow.python.data.ops import multi_device_iterator_ops from tensorflow.python.eager import context from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops @@ -683,7 +683,7 @@ class PerDeviceDataIterator(object): def get_next(self, name=None): """Scatter the input across devices.""" if self._prefetch_on_device: - data_list = self._iterator.get_next(name=name) + data_list = self._iterator.get_next() index = dict(zip(self._devices, data_list)) else: batch = self._iterator.get_next(name=name) @@ -703,21 +703,26 @@ class PerDeviceDataIterator(object): class PerDeviceDataset(object): """Like `tf.data.Dataset` split devices, producing `PerDevice` data.""" - def __init__(self, dataset, devices, prefetch_on_device=None): + def __init__( + self, + dataset, + devices, + prefetch_on_device=None, + source_device="/cpu:0", + ): self._devices = devices + self._source_device = source_device if source_device is not None else "/cpu:0" # Default to using prefetching in graph mode, unless specified. - # TODO(priyag): Enable prefetching in eager mode. + # TODO(rohanj): Enable prefetching in eager mode. 
self._prefetch_on_device = prefetch_on_device if self._prefetch_on_device is None: self._prefetch_on_device = not context.executing_eagerly() assert not (self._prefetch_on_device and context.executing_eagerly()), ( "Prefetching is only supported in graph mode currently") - if self._prefetch_on_device: - self._dataset = dataset.apply( - prefetching_ops_v2.prefetch_to_devices(self._devices)) - else: + self._dataset = dataset + if not self._prefetch_on_device: # TODO(priyag): If dropping remainder is not appropriate, find another # approach to distributing the dataset when not possible to divide evenly. # Possibly not an issue when we start using PartitionedDataset. @@ -725,15 +730,33 @@ class PerDeviceDataset(object): def make_one_shot_iterator(self): """Get a one time use iterator for the distributed PerDeviceDataset.""" + # Graph mode prefetching with one shot iterator is disabled. + if not context.executing_eagerly(): + raise ValueError("Cannot create a one shot iterator. Please use " + "`make_initializable_iterator()` instead.") + # Eager mode prefetching would error out in constructor. Only remaining + # cases are non-prefetching eager / graph mode. We delegate to + # PerDeviceDataIterator to handle them. dataset_iterator = self._dataset.make_one_shot_iterator() return PerDeviceDataIterator( - dataset_iterator, self._devices, self._prefetch_on_device) + dataset_iterator, self._devices, prefetch_on_device=False) def make_initializable_iterator(self): """Get an initializable iterator for the distributed PerDeviceDataset.""" - dataset_iterator = self._dataset.make_initializable_iterator() + # Eager mode generates already initialized iterators. Hence we cannot create + # an initializable iterator. + if context.executing_eagerly(): + raise ValueError("Cannot create initializable iterator in Eager mode. 
" + "Please use `make_one_shot_iterator` instead.") + if self._prefetch_on_device: + dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator( + self._dataset, self._devices, source_device=self._source_device) + else: + dataset_iterator = self._dataset.make_initializable_iterator() return PerDeviceDataIterator( - dataset_iterator, self._devices, self._prefetch_on_device) + dataset_iterator, + self._devices, + prefetch_on_device=self._prefetch_on_device) class MultiWorkerDataIterator(object): @@ -813,7 +836,10 @@ class MultiWorkerDataset(object): worker_input = input_ops.auto_shard_dataset( worker_input, len(worker_device_map), i) self._datasets[worker] = PerDeviceDataset( - worker_input, worker_devices, prefetch_on_device=prefetch_on_device) + worker_input, + worker_devices, + source_device=worker, + prefetch_on_device=prefetch_on_device) def make_one_shot_iterator(self): iterators = {} |