diff options
-rw-r--r-- | tensorflow/contrib/data/python/ops/prefetching_ops.py | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py index 0edd7c9fe9..0243c72c70 100644 --- a/tensorflow/contrib/data/python/ops/prefetching_ops.py +++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py @@ -633,6 +633,15 @@ class MultiDeviceIterator(object): devices, prefetch_buffer_size=1, source_device="/cpu:0"): + """Constructs a MultiDeviceIterator. + + Args: + dataset: The input dataset to be iterated over. + devices: The list of devices to fetch data to. + prefetch_buffer_size: if > 1, then we setup a buffer on each device + to prefetch into. + source_device: The host device to place the `dataset` on. + """ self._dataset = dataset self._devices = devices self._source_device = source_device @@ -673,7 +682,8 @@ class MultiDeviceIterator(object): i, self._multi_device_iterator_resource, self._incarnation_id, self._source_device_tensor, device, self._dataset.output_shapes, self._dataset.output_types, self._dataset.output_classes) - ds = ds.prefetch(prefetch_buffer_size) + if prefetch_buffer_size > 0: + ds = ds.prefetch(prefetch_buffer_size) with ops.device(device): self._device_iterators.append(ds.make_initializable_iterator()) i += 1 |