aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/data/python/ops/prefetching_ops.py12
1 files changed, 11 insertions, 1 deletions
diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py
index 0edd7c9fe9..0243c72c70 100644
--- a/tensorflow/contrib/data/python/ops/prefetching_ops.py
+++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py
@@ -633,6 +633,15 @@ class MultiDeviceIterator(object):
devices,
prefetch_buffer_size=1,
source_device="/cpu:0"):
+ """Constructs a MultiDeviceIterator.
+
+ Args:
+ dataset: The input dataset to be iterated over.
+ devices: The list of devices to fetch data to.
+ prefetch_buffer_size: if > 1, then we setup a buffer on each device
+ to prefetch into.
+ source_device: The host device to place the `dataset` on.
+ """
self._dataset = dataset
self._devices = devices
self._source_device = source_device
@@ -673,7 +682,8 @@ class MultiDeviceIterator(object):
i, self._multi_device_iterator_resource, self._incarnation_id,
self._source_device_tensor, device, self._dataset.output_shapes,
self._dataset.output_types, self._dataset.output_classes)
- ds = ds.prefetch(prefetch_buffer_size)
+ if prefetch_buffer_size > 0:
+ ds = ds.prefetch(prefetch_buffer_size)
with ops.device(device):
self._device_iterators.append(ds.make_initializable_iterator())
i += 1