about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/distribute/python/values.py
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2018-09-25 20:16:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-25 20:22:00 -0700
commit7f1d70d97f543d69a9f02cd6df0964f22f9278f3 (patch)
tree29612b6cd40203beba4f2b9689eef27a1f8da8d7 /tensorflow/contrib/distribute/python/values.py
parent3f4b8c138165cc9deb0ed931c5a6bb3d8ab556f0 (diff)
Switching Distribution strategies to use MultiDeviceIterator. Currently only supported in Graph mode using initializable iterators. In a subsequent change, we'll add in support for Eager mode as well.
This removes prefetching_ops_v2 code. PiperOrigin-RevId: 214546754
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--tensorflow/contrib/distribute/python/values.py50
1 file changed, 38 insertions(+), 12 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index fafa6384a1..a0cd029f51 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -26,7 +26,7 @@ import weakref
import six
from tensorflow.contrib.distribute.python import input_ops
-from tensorflow.contrib.distribute.python import prefetching_ops_v2
+from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
@@ -683,7 +683,7 @@ class PerDeviceDataIterator(object):
def get_next(self, name=None):
"""Scatter the input across devices."""
if self._prefetch_on_device:
- data_list = self._iterator.get_next(name=name)
+ data_list = self._iterator.get_next()
index = dict(zip(self._devices, data_list))
else:
batch = self._iterator.get_next(name=name)
@@ -703,21 +703,26 @@ class PerDeviceDataIterator(object):
class PerDeviceDataset(object):
"""Like `tf.data.Dataset` split devices, producing `PerDevice` data."""
- def __init__(self, dataset, devices, prefetch_on_device=None):
+ def __init__(
+ self,
+ dataset,
+ devices,
+ prefetch_on_device=None,
+ source_device="/cpu:0",
+ ):
self._devices = devices
+ self._source_device = source_device if source_device is not None else "/cpu:0"
# Default to using prefetching in graph mode, unless specified.
- # TODO(priyag): Enable prefetching in eager mode.
+ # TODO(rohanj): Enable prefetching in eager mode.
self._prefetch_on_device = prefetch_on_device
if self._prefetch_on_device is None:
self._prefetch_on_device = not context.executing_eagerly()
assert not (self._prefetch_on_device and context.executing_eagerly()), (
"Prefetching is only supported in graph mode currently")
- if self._prefetch_on_device:
- self._dataset = dataset.apply(
- prefetching_ops_v2.prefetch_to_devices(self._devices))
- else:
+ self._dataset = dataset
+ if not self._prefetch_on_device:
# TODO(priyag): If dropping remainder is not appropriate, find another
# approach to distributing the dataset when not possible to divide evenly.
# Possibly not an issue when we start using PartitionedDataset.
@@ -725,15 +730,33 @@ class PerDeviceDataset(object):
def make_one_shot_iterator(self):
"""Get a one time use iterator for the distributed PerDeviceDataset."""
+ # Graph mode prefetching with one shot iterator is disabled.
+ if not context.executing_eagerly():
+ raise ValueError("Cannot create a one shot iterator. Please use "
+ "`make_initializable_iterator()` instead.")
+ # Eager mode prefetching would error out in constructor. Only remaining
+ # cases are non-prefetching eager / graph mode. We delegate to
+ # PerDeviceDataIterator to handle them.
dataset_iterator = self._dataset.make_one_shot_iterator()
return PerDeviceDataIterator(
- dataset_iterator, self._devices, self._prefetch_on_device)
+ dataset_iterator, self._devices, prefetch_on_device=False)
def make_initializable_iterator(self):
"""Get an initializable iterator for the distributed PerDeviceDataset."""
- dataset_iterator = self._dataset.make_initializable_iterator()
+ # Eager mode generates already initialized iterators. Hence we cannot create
+ # an initializable iterator.
+ if context.executing_eagerly():
+ raise ValueError("Cannot create initializable iterator in Eager mode. "
+ "Please use `make_one_shot_iterator` instead.")
+ if self._prefetch_on_device:
+ dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ self._dataset, self._devices, source_device=self._source_device)
+ else:
+ dataset_iterator = self._dataset.make_initializable_iterator()
return PerDeviceDataIterator(
- dataset_iterator, self._devices, self._prefetch_on_device)
+ dataset_iterator,
+ self._devices,
+ prefetch_on_device=self._prefetch_on_device)
class MultiWorkerDataIterator(object):
@@ -813,7 +836,10 @@ class MultiWorkerDataset(object):
worker_input = input_ops.auto_shard_dataset(
worker_input, len(worker_device_map), i)
self._datasets[worker] = PerDeviceDataset(
- worker_input, worker_devices, prefetch_on_device=prefetch_on_device)
+ worker_input,
+ worker_devices,
+ source_device=worker,
+ prefetch_on_device=prefetch_on_device)
def make_one_shot_iterator(self):
iterators = {}