diff options
author | Rohan Jain <rohanj@google.com> | 2018-09-27 20:59:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-27 21:02:52 -0700 |
commit | 97cba0b88cb3ce6a3f3cc66a8c4fd414bd3ac1a8 (patch) | |
tree | f7ce9c26ebb93cbef7aacc3f725ba5a304470407 /tensorflow/contrib/distribute/python/values.py | |
parent | 370d385c3029a7972ba201c8303942b30f09521c (diff) |
Allowing source_device to be set to /cpu:0 for multi device iterator in distribution strategies. That is always the appropriate option.
In the existing code, we would set it to a partially specified "worker" name that was ambiguous and end up on the GPU.
PiperOrigin-RevId: 214882658
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 5 |
1 files changed, 1 insertions, 4 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index a0cd029f51..cce41e7717 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -708,10 +708,8 @@ class PerDeviceDataset(object): dataset, devices, prefetch_on_device=None, - source_device="/cpu:0", ): self._devices = devices - self._source_device = source_device if source_device is not None else "/cpu:0" # Default to using prefetching in graph mode, unless specified. # TODO(rohanj): Enable prefetching in eager mode. @@ -750,7 +748,7 @@ class PerDeviceDataset(object): "Please use `make_one_shot_iterator` instead.") if self._prefetch_on_device: dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator( - self._dataset, self._devices, source_device=self._source_device) + self._dataset, self._devices) else: dataset_iterator = self._dataset.make_initializable_iterator() return PerDeviceDataIterator( @@ -838,7 +836,6 @@ class MultiWorkerDataset(object): self._datasets[worker] = PerDeviceDataset( worker_input, worker_devices, - source_device=worker, prefetch_on_device=prefetch_on_device) def make_one_shot_iterator(self): |