aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/values.py
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2018-09-27 20:59:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 21:02:52 -0700
commit97cba0b88cb3ce6a3f3cc66a8c4fd414bd3ac1a8 (patch)
treef7ce9c26ebb93cbef7aacc3f725ba5a304470407 /tensorflow/contrib/distribute/python/values.py
parent370d385c3029a7972ba201c8303942b30f09521c (diff)
Allowing source_device to be set to /cpu:0 for multi device iterator in distribution strategies. That is always the appropriate option.
In the existing code, we would set it to a partially specified "worker" name that was ambiguous and end up on the GPU. PiperOrigin-RevId: 214882658
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--tensorflow/contrib/distribute/python/values.py5
1 files changed, 1 insertions, 4 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index a0cd029f51..cce41e7717 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -708,10 +708,8 @@ class PerDeviceDataset(object):
dataset,
devices,
prefetch_on_device=None,
- source_device="/cpu:0",
):
self._devices = devices
- self._source_device = source_device if source_device is not None else "/cpu:0"
# Default to using prefetching in graph mode, unless specified.
# TODO(rohanj): Enable prefetching in eager mode.
@@ -750,7 +748,7 @@ class PerDeviceDataset(object):
"Please use `make_one_shot_iterator` instead.")
if self._prefetch_on_device:
dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
- self._dataset, self._devices, source_device=self._source_device)
+ self._dataset, self._devices)
else:
dataset_iterator = self._dataset.make_initializable_iterator()
return PerDeviceDataIterator(
@@ -838,7 +836,6 @@ class MultiWorkerDataset(object):
self._datasets[worker] = PerDeviceDataset(
worker_input,
worker_devices,
- source_device=worker,
prefetch_on_device=prefetch_on_device)
def make_one_shot_iterator(self):