aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2018-08-07 09:31:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-07 09:35:46 -0700
commit60fc3dba6722753b546f20295fd73a841410fafb (patch)
tree149936a0995eff43016b24399cb40f2c71444d03
parent7f666bb652063874134ed60b77edb4ddc85ec488 (diff)
Small fix to MultiDeviceIterator to allow for no prefetching if needed.
PiperOrigin-RevId: 207728361
-rw-r--r--tensorflow/contrib/data/python/ops/prefetching_ops.py12
1 files changed, 11 insertions, 1 deletions
diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py
index 0edd7c9fe9..0243c72c70 100644
--- a/tensorflow/contrib/data/python/ops/prefetching_ops.py
+++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py
@@ -633,6 +633,15 @@ class MultiDeviceIterator(object):
devices,
prefetch_buffer_size=1,
source_device="/cpu:0"):
+ """Constructs a MultiDeviceIterator.
+
+ Args:
+ dataset: The input dataset to be iterated over.
+ devices: The list of devices to fetch data to.
+ prefetch_buffer_size: if > 1, then we setup a buffer on each device
+ to prefetch into.
+ source_device: The host device to place the `dataset` on.
+ """
self._dataset = dataset
self._devices = devices
self._source_device = source_device
@@ -673,7 +682,8 @@ class MultiDeviceIterator(object):
i, self._multi_device_iterator_resource, self._incarnation_id,
self._source_device_tensor, device, self._dataset.output_shapes,
self._dataset.output_types, self._dataset.output_classes)
- ds = ds.prefetch(prefetch_buffer_size)
+ if prefetch_buffer_size > 0:
+ ds = ds.prefetch(prefetch_buffer_size)
with ops.device(device):
self._device_iterators.append(ds.make_initializable_iterator())
i += 1