diff options
author | Rohan Jain <rohanj@google.com> | 2018-09-27 20:59:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-27 21:02:52 -0700 |
commit | 97cba0b88cb3ce6a3f3cc66a8c4fd414bd3ac1a8 (patch) | |
tree | f7ce9c26ebb93cbef7aacc3f725ba5a304470407 | |
parent | 370d385c3029a7972ba201c8303942b30f09521c (diff) |
Allowing source_device to be set to /cpu:0 for multi device iterator in distribution strategies. That is always the appropriate option.
In the existing code, we would set it to a partially specified "worker" name that was ambiguous and end up on the GPU.
PiperOrigin-RevId: 214882658
-rw-r--r-- | tensorflow/contrib/distribute/python/mirrored_strategy.py | 3 | ||||
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 5 |
2 files changed, 2 insertions, 6 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 945f450387..504f45a695 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -482,8 +482,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): return values.PerDeviceDataset( self._call_dataset_fn(dataset_fn), self._devices, - self._prefetch_on_device, - source_device=device_util.resolve("/device:CPU:0")) + self._prefetch_on_device) # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. def _run_steps_on_dataset(self, fn, iterator, iterations, diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index a0cd029f51..cce41e7717 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -708,10 +708,8 @@ class PerDeviceDataset(object): dataset, devices, prefetch_on_device=None, - source_device="/cpu:0", ): self._devices = devices - self._source_device = source_device if source_device is not None else "/cpu:0" # Default to using prefetching in graph mode, unless specified. # TODO(rohanj): Enable prefetching in eager mode. @@ -750,7 +748,7 @@ class PerDeviceDataset(object): "Please use `make_one_shot_iterator` instead.") if self._prefetch_on_device: dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator( - self._dataset, self._devices, source_device=self._source_device) + self._dataset, self._devices) else: dataset_iterator = self._dataset.make_initializable_iterator() return PerDeviceDataIterator( @@ -838,7 +836,6 @@ class MultiWorkerDataset(object): self._datasets[worker] = PerDeviceDataset( worker_input, worker_devices, - source_device=worker, prefetch_on_device=prefetch_on_device) def make_one_shot_iterator(self): |