aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-04 21:37:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-04 21:40:04 -0700
commitfedfc47ca6713adbbf82e10d4803c5fe94234bbd (patch)
treee850bb2086549706092770a683f5e5667db407b5 /tensorflow/contrib/distribute/python/mirrored_strategy.py
parent76801dda9b4766d729ab88267ee47f48d05eafb7 (diff)
Resolve device names when passed into DistributionStrategy methods.
PiperOrigin-RevId: 199241723
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 14dbbd6e27..6eadba976b 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -84,9 +84,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
assert len(set(devices)) == len(devices), (
"No duplicates allowed in `devices` argument.")
# TODO(josh11b): Require at least 2 devices?
- self._devices = devices
- self._canonical_device_set = set(
- [device_util.canonicalize(d) for d in devices])
+ self._devices = [device_util.resolve(d) for d in devices]
+ self._canonical_device_set = set(self._devices)
self._device_index = values.PerDevice(
dict((d, i) for i, d in enumerate(devices)))
self._cross_tower_ops = cross_tower_ops
@@ -400,7 +399,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
# pylint: disable=protected-access
return list(colocate_with._index.keys())
elif isinstance(colocate_with, six.string_types):
- return [colocate_with]
+ return [device_util.resolve(colocate_with)]
+ elif isinstance(colocate_with, list):
+ return [device_util.resolve(d) for d in colocate_with]
else:
return colocate_with