about summary refs log tree commit diff homepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-04 21:37:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-04 21:40:04 -0700
commitfedfc47ca6713adbbf82e10d4803c5fe94234bbd (patch)
treee850bb2086549706092770a683f5e5667db407b5 /tensorflow
parent76801dda9b4766d729ab88267ee47f48d05eafb7 (diff)
Resolve device names when passed into DistributionStrategy methods.
PiperOrigin-RevId: 199241723
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/distribute/python/combinations.py26
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py9
-rw-r--r--tensorflow/contrib/distribute/python/values.py7
3 files changed, 22 insertions, 20 deletions
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index e400fa5be2..98e7228f24 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -46,9 +46,9 @@ import unittest
from absl.testing import parameterized
import six
-from tensorflow.contrib.distribute.python import mirrored_strategy
-from tensorflow.contrib.distribute.python import one_device_strategy
-from tensorflow.contrib.distribute.python import tpu_strategy
+from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib
+from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib
+from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib
from tensorflow.contrib.optimizer_v2 import adam as adam_v2
from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2
from tensorflow.python.eager import context
@@ -289,9 +289,9 @@ class NamedObject(object):
class NamedDistribution(object):
"""Translates DistributionStrategy and its data into a good name."""
- def __init__(self, name, distribution, required_gpus=None,
+ def __init__(self, name, distribution_fn, required_gpus=None,
required_tpu=False):
- self._distribution = distribution
+ self._distribution_fn = distribution_fn
self._name = name
self._required_gpus = required_gpus
self._required_tpu = required_tpu
@@ -301,7 +301,7 @@ class NamedDistribution(object):
@property
def strategy(self):
- return self._distribution
+ return self._distribution_fn()
@property
def required_gpus(self):
@@ -312,29 +312,29 @@ class NamedDistribution(object):
return self._required_tpu
+# pylint: disable=g-long-lambda
default_strategy = NamedDistribution(
"Default",
- distribute_lib._default_distribution_strategy, # pylint: disable=protected-access
+ lambda: distribute_lib._default_distribution_strategy, # pylint: disable=protected-access
required_gpus=None)
one_device_strategy = NamedDistribution(
- "OneDeviceCPU", one_device_strategy.OneDeviceStrategy("/cpu:0"),
+ "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"),
required_gpus=None)
tpu_strategy_single_iteration = NamedDistribution(
"TPUSingleIteration",
- tpu_strategy.TPUStrategy(iterations_per_step=1),
+ lambda: tpu_lib.TPUStrategy(iterations_per_step=1),
required_tpu=True)
-tpu_strategy = NamedDistribution(
- "TPU", tpu_strategy.TPUStrategy(), required_tpu=True)
+tpu_strategy = NamedDistribution("TPU", tpu_lib.TPUStrategy, required_tpu=True)
# Note that we disable prefetching for testing since prefetching makes
# the input non-deterministic.
mirrored_strategy_with_gpu_and_cpu = NamedDistribution(
"MirroredCPUAndGPU",
- mirrored_strategy.MirroredStrategy(
+ lambda: mirrored_lib.MirroredStrategy(
["/gpu:0", "/cpu:0"], prefetch_on_device=False),
required_gpus=1)
mirrored_strategy_with_two_gpus = NamedDistribution(
"Mirrored2GPUs",
- mirrored_strategy.MirroredStrategy(
+ lambda: mirrored_lib.MirroredStrategy(
["/gpu:0", "/gpu:1"], prefetch_on_device=False),
required_gpus=2)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 14dbbd6e27..6eadba976b 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -84,9 +84,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
assert len(set(devices)) == len(devices), (
"No duplicates allowed in `devices` argument.")
# TODO(josh11b): Require at least 2 devices?
- self._devices = devices
- self._canonical_device_set = set(
- [device_util.canonicalize(d) for d in devices])
+ self._devices = [device_util.resolve(d) for d in devices]
+ self._canonical_device_set = set(self._devices)
self._device_index = values.PerDevice(
dict((d, i) for i, d in enumerate(devices)))
self._cross_tower_ops = cross_tower_ops
@@ -400,7 +399,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
# pylint: disable=protected-access
return list(colocate_with._index.keys())
elif isinstance(colocate_with, six.string_types):
- return [colocate_with]
+ return [device_util.resolve(colocate_with)]
+ elif isinstance(colocate_with, list):
+ return [device_util.resolve(d) for d in colocate_with]
else:
return colocate_with
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 49b4e24daa..9572ade8e4 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -65,9 +65,10 @@ class DistributedValues(object):
device = device_util.canonicalize(device)
try:
return self._index[device]
- except KeyError:
- raise ValueError("Device %s not found in %s (current device %s)" %
- (device, self._index.keys(), device_util.current()))
+ except KeyError as e:
+ six.raise_from(
+ ValueError("Device %s not found in %s (current device %s)" %
+ (device, self._index.keys(), device_util.current())), e)
def on_device(self, device):
device = device_util.canonicalize(device)