diff options
| field | value |
|---|---|
| author | 2018-06-04 21:37:43 -0700 |
| committer | 2018-06-04 21:40:04 -0700 |
| commit | fedfc47ca6713adbbf82e10d4803c5fe94234bbd (patch) |
| tree | e850bb2086549706092770a683f5e5667db407b5 /tensorflow |
| parent | 76801dda9b4766d729ab88267ee47f48d05eafb7 (diff) |
Resolve device names when passed into DistributionStrategy methods.
PiperOrigin-RevId: 199241723
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/contrib/distribute/python/combinations.py | 26 |
-rw-r--r-- | tensorflow/contrib/distribute/python/mirrored_strategy.py | 9 |
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 7 |
3 files changed, 22 insertions, 20 deletions
```diff
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index e400fa5be2..98e7228f24 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -46,9 +46,9 @@ import unittest
 from absl.testing import parameterized
 import six
 
-from tensorflow.contrib.distribute.python import mirrored_strategy
-from tensorflow.contrib.distribute.python import one_device_strategy
-from tensorflow.contrib.distribute.python import tpu_strategy
+from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib
+from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib
+from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib
 from tensorflow.contrib.optimizer_v2 import adam as adam_v2
 from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2
 from tensorflow.python.eager import context
@@ -289,9 +289,9 @@ class NamedObject(object):
 class NamedDistribution(object):
   """Translates DistributionStrategy and its data into a good name."""
 
-  def __init__(self, name, distribution, required_gpus=None,
+  def __init__(self, name, distribution_fn, required_gpus=None,
                required_tpu=False):
-    self._distribution = distribution
+    self._distribution_fn = distribution_fn
     self._name = name
     self._required_gpus = required_gpus
     self._required_tpu = required_tpu
@@ -301,7 +301,7 @@ class NamedDistribution(object):
 
   @property
   def strategy(self):
-    return self._distribution
+    return self._distribution_fn()
 
   @property
   def required_gpus(self):
@@ -312,29 +312,29 @@ class NamedDistribution(object):
     return self._required_tpu
 
 
+# pylint: disable=g-long-lambda
 default_strategy = NamedDistribution(
     "Default",
-    distribute_lib._default_distribution_strategy,  # pylint: disable=protected-access
+    lambda: distribute_lib._default_distribution_strategy,  # pylint: disable=protected-access
     required_gpus=None)
 one_device_strategy = NamedDistribution(
-    "OneDeviceCPU", one_device_strategy.OneDeviceStrategy("/cpu:0"),
+    "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"),
     required_gpus=None)
 tpu_strategy_single_iteration = NamedDistribution(
     "TPUSingleIteration",
-    tpu_strategy.TPUStrategy(iterations_per_step=1),
+    lambda: tpu_lib.TPUStrategy(iterations_per_step=1),
     required_tpu=True)
-tpu_strategy = NamedDistribution(
-    "TPU", tpu_strategy.TPUStrategy(), required_tpu=True)
+tpu_strategy = NamedDistribution("TPU", tpu_lib.TPUStrategy, required_tpu=True)
 # Note that we disable prefetching for testing since prefetching makes
 # the input non-deterministic.
 mirrored_strategy_with_gpu_and_cpu = NamedDistribution(
     "MirroredCPUAndGPU",
-    mirrored_strategy.MirroredStrategy(
+    lambda: mirrored_lib.MirroredStrategy(
         ["/gpu:0", "/cpu:0"], prefetch_on_device=False),
     required_gpus=1)
 mirrored_strategy_with_two_gpus = NamedDistribution(
     "Mirrored2GPUs",
-    mirrored_strategy.MirroredStrategy(
+    lambda: mirrored_lib.MirroredStrategy(
         ["/gpu:0", "/gpu:1"], prefetch_on_device=False),
     required_gpus=2)
 
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 14dbbd6e27..6eadba976b 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -84,9 +84,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
     assert len(set(devices)) == len(devices), (
         "No duplicates allowed in `devices` argument.")
     # TODO(josh11b): Require at least 2 devices?
-    self._devices = devices
-    self._canonical_device_set = set(
-        [device_util.canonicalize(d) for d in devices])
+    self._devices = [device_util.resolve(d) for d in devices]
+    self._canonical_device_set = set(self._devices)
     self._device_index = values.PerDevice(
         dict((d, i) for i, d in enumerate(devices)))
     self._cross_tower_ops = cross_tower_ops
@@ -400,7 +399,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
       # pylint: disable=protected-access
       return list(colocate_with._index.keys())
     elif isinstance(colocate_with, six.string_types):
-      return [colocate_with]
+      return [device_util.resolve(colocate_with)]
+    elif isinstance(colocate_with, list):
+      return [device_util.resolve(d) for d in colocate_with]
     else:
       return colocate_with
 
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 49b4e24daa..9572ade8e4 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -65,9 +65,10 @@ class DistributedValues(object):
     device = device_util.canonicalize(device)
     try:
       return self._index[device]
-    except KeyError:
-      raise ValueError("Device %s not found in %s (current device %s)" %
-                       (device, self._index.keys(), device_util.current()))
+    except KeyError as e:
+      six.raise_from(
+          ValueError("Device %s not found in %s (current device %s)" %
+                     (device, self._index.keys(), device_util.current())), e)
 
   def on_device(self, device):
     device = device_util.canonicalize(device)
```