aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/values.py
diff options
context:
space:
mode:
authorGravatar Priya Gupta <priyag@google.com>2018-04-18 12:03:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-18 12:06:14 -0700
commit03d18ae232c3cff4c56d1efec7bf29f9b16c4f68 (patch)
treeb722f735de3f6ab2d3e8ef945baab452bd0b70ab /tensorflow/contrib/distribute/python/values.py
parent60444df318439654324ff797d66734c9920e48a2 (diff)
Add support for initializable iterator in distribution strategies. Use that in estimator.
PiperOrigin-RevId: 193394603
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--tensorflow/contrib/distribute/python/values.py22
1 files changed, 12 insertions, 10 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 87bf059038..18fedd2775 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -28,7 +28,6 @@ import six
from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.distribute.python import prefetching_ops_v2
-from tensorflow.contrib.eager.python import datasets
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -510,6 +509,10 @@ class PerDeviceDataIterator(object):
self._devices = devices
self._prefetch_on_device = prefetch_on_device
+ @property
+ def initializer(self):
+ return self._iterator.initializer
+
def get_next(self, name=None):
"""Scatter the input across devices."""
if self._prefetch_on_device:
@@ -545,7 +548,8 @@ class PerDeviceDataset(object):
"Prefetching is only supported in graph mode currently")
if self._prefetch_on_device:
- self._dataset = dataset
+ self._dataset = dataset.apply(
+ prefetching_ops_v2.prefetch_to_devices(self._devices))
else:
# TODO(priyag): If dropping remainder is not appropriate, find another
# approach to distributing the dataset when not possible to divide evenly.
@@ -555,15 +559,13 @@ class PerDeviceDataset(object):
def make_one_shot_iterator(self):
"""Get a one time use iterator for the distributed PerDeviceDataset."""
- if self._prefetch_on_device:
- on_device_dataset = self._dataset.apply(
- prefetching_ops_v2.prefetch_to_devices(self._devices))
- dataset_iterator = on_device_dataset.make_one_shot_iterator()
- elif context.executing_eagerly():
- dataset_iterator = datasets.Iterator(self._dataset)
- else:
- dataset_iterator = self._dataset.make_one_shot_iterator()
+ dataset_iterator = self._dataset.make_one_shot_iterator()
+ return PerDeviceDataIterator(
+ dataset_iterator, self._devices, self._prefetch_on_device)
+ def make_initializable_iterator(self):
+ """Get an initializable iterator for the distributed PerDeviceDataset."""
+ dataset_iterator = self._dataset.make_initializable_iterator()
return PerDeviceDataIterator(
dataset_iterator, self._devices, self._prefetch_on_device)