about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/distribute/python/values.py
diff options
context:
space:
mode:
author: Rohan Jain <rohanj@google.com> 2018-09-28 13:50:12 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-09-28 14:32:29 -0700
commit: 1724d155f00b49bc817189247cbfb0df2092a9da (patch)
tree: 3f2606f84d779d8fca28a3d253c70176c4ed3fc1 /tensorflow/contrib/distribute/python/values.py
parent: 64be2ecc07c698df05d88051ec42a0409d1a9863 (diff)
Automated rollback of commit 7f1d70d97f543d69a9f02cd6df0964f22f9278f3
PiperOrigin-RevId: 214989908
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r-- tensorflow/contrib/distribute/python/values.py | 51
1 file changed, 14 insertions, 37 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 327775a729..4955ded4d5 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -26,7 +26,7 @@ import weakref
import six
from tensorflow.contrib.distribute.python import input_ops
-from tensorflow.python.data.ops import multi_device_iterator_ops
+from tensorflow.contrib.distribute.python import prefetching_ops_v2
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
@@ -683,7 +683,7 @@ class PerDeviceDataIterator(object):
def get_next(self, name=None):
"""Scatter the input across devices."""
if self._prefetch_on_device:
- data_list = self._iterator.get_next()
+ data_list = self._iterator.get_next(name=name)
index = dict(zip(self._devices, data_list))
else:
batch = self._iterator.get_next(name=name)
@@ -703,24 +703,21 @@ class PerDeviceDataIterator(object):
class PerDeviceDataset(object):
"""Like `tf.data.Dataset` split devices, producing `PerDevice` data."""
- def __init__(
- self,
- dataset,
- devices,
- prefetch_on_device=None,
- ):
+ def __init__(self, dataset, devices, prefetch_on_device=None):
self._devices = devices
# Default to using prefetching in graph mode, unless specified.
- # TODO(rohanj): Enable prefetching in eager mode.
+ # TODO(priyag): Enable prefetching in eager mode.
self._prefetch_on_device = prefetch_on_device
if self._prefetch_on_device is None:
self._prefetch_on_device = not context.executing_eagerly()
assert not (self._prefetch_on_device and context.executing_eagerly()), (
"Prefetching is only supported in graph mode currently")
- self._dataset = dataset
- if not self._prefetch_on_device:
+ if self._prefetch_on_device:
+ self._dataset = dataset.apply(
+ prefetching_ops_v2.prefetch_to_devices(self._devices))
+ else:
# TODO(priyag): If dropping remainder is not appropriate, find another
# approach to distributing the dataset when not possible to divide evenly.
# Possibly not an issue when we start using PartitionedDataset.
@@ -728,33 +725,15 @@ class PerDeviceDataset(object):
def make_one_shot_iterator(self):
"""Get a one time use iterator for the distributed PerDeviceDataset."""
- # Graph mode prefetching with one shot iterator is disabled.
- if not context.executing_eagerly():
- raise ValueError("Cannot create a one shot iterator. Please use "
- "`make_initializable_iterator()` instead.")
- # Eager mode prefetching would error out in constructor. Only remaining
- # cases are non-prefetching eager / graph mode. We delegate to
- # PerDeviceDataIterator to handle them.
dataset_iterator = self._dataset.make_one_shot_iterator()
- return PerDeviceDataIterator(
- dataset_iterator, self._devices, prefetch_on_device=False)
+ return PerDeviceDataIterator(dataset_iterator, self._devices,
+ self._prefetch_on_device)
def make_initializable_iterator(self):
"""Get an initializable iterator for the distributed PerDeviceDataset."""
- # Eager mode generates already initialized iterators. Hence we cannot create
- # an initializable iterator.
- if context.executing_eagerly():
- raise ValueError("Cannot create initializable iterator in Eager mode. "
- "Please use `make_one_shot_iterator` instead.")
- if self._prefetch_on_device:
- dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
- self._dataset, self._devices)
- else:
- dataset_iterator = self._dataset.make_initializable_iterator()
- return PerDeviceDataIterator(
- dataset_iterator,
- self._devices,
- prefetch_on_device=self._prefetch_on_device)
+ dataset_iterator = self._dataset.make_initializable_iterator()
+ return PerDeviceDataIterator(dataset_iterator, self._devices,
+ self._prefetch_on_device)
class MultiWorkerDataIterator(object):
@@ -837,9 +816,7 @@ class MultiWorkerDataset(object):
worker_input = input_ops.auto_shard_dataset(
worker_input, len(worker_device_map), i)
self._datasets[worker] = PerDeviceDataset(
- worker_input,
- worker_devices,
- prefetch_on_device=prefetch_on_device)
+ worker_input, worker_devices, prefetch_on_device=prefetch_on_device)
def make_one_shot_iterator(self):
iterators = {}