diff options
author | Priya Gupta <priyag@google.com> | 2018-04-18 12:03:32 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-18 12:06:14 -0700 |
commit | 03d18ae232c3cff4c56d1efec7bf29f9b16c4f68 (patch) | |
tree | b722f735de3f6ab2d3e8ef945baab452bd0b70ab /tensorflow/contrib/distribute/python/values.py | |
parent | 60444df318439654324ff797d66734c9920e48a2 (diff) |
Add support for initializable iterator in distribution strategies. Use that in estimator.
PiperOrigin-RevId: 193394603
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 22 |
1 files changed, 12 insertions, 10 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 87bf059038..18fedd2775 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -28,7 +28,6 @@ import six from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.distribute.python import prefetching_ops_v2 -from tensorflow.contrib.eager.python import datasets from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -510,6 +509,10 @@ class PerDeviceDataIterator(object): self._devices = devices self._prefetch_on_device = prefetch_on_device + @property + def initializer(self): + return self._iterator.initializer + def get_next(self, name=None): """Scatter the input across devices.""" if self._prefetch_on_device: @@ -545,7 +548,8 @@ class PerDeviceDataset(object): "Prefetching is only supported in graph mode currently") if self._prefetch_on_device: - self._dataset = dataset + self._dataset = dataset.apply( + prefetching_ops_v2.prefetch_to_devices(self._devices)) else: # TODO(priyag): If dropping remainder is not appropriate, find another # approach to distributing the dataset when not possible to divide evenly. @@ -555,15 +559,13 @@ class PerDeviceDataset(object): def make_one_shot_iterator(self): """Get a one time use iterator for the distributed PerDeviceDataset.""" - if self._prefetch_on_device: - on_device_dataset = self._dataset.apply( - prefetching_ops_v2.prefetch_to_devices(self._devices)) - dataset_iterator = on_device_dataset.make_one_shot_iterator() - elif context.executing_eagerly(): - dataset_iterator = datasets.Iterator(self._dataset) - else: - dataset_iterator = self._dataset.make_one_shot_iterator() + dataset_iterator = self._dataset.make_one_shot_iterator() + return PerDeviceDataIterator( + dataset_iterator, self._devices, self._prefetch_on_device) + def make_initializable_iterator(self): + """Get an initializable iterator for the distributed PerDeviceDataset.""" + dataset_iterator = self._dataset.make_initializable_iterator() return PerDeviceDataIterator( dataset_iterator, self._devices, self._prefetch_on_device) |