aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/combinations.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distribute/python/combinations.py')
-rw-r--r--tensorflow/contrib/distribute/python/combinations.py26
1 files changed, 13 insertions, 13 deletions
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index e400fa5be2..98e7228f24 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -46,9 +46,9 @@ import unittest
from absl.testing import parameterized
import six
-from tensorflow.contrib.distribute.python import mirrored_strategy
-from tensorflow.contrib.distribute.python import one_device_strategy
-from tensorflow.contrib.distribute.python import tpu_strategy
+from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib
+from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib
+from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib
from tensorflow.contrib.optimizer_v2 import adam as adam_v2
from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2
from tensorflow.python.eager import context
@@ -289,9 +289,9 @@ class NamedObject(object):
class NamedDistribution(object):
"""Translates DistributionStrategy and its data into a good name."""
- def __init__(self, name, distribution, required_gpus=None,
+ def __init__(self, name, distribution_fn, required_gpus=None,
required_tpu=False):
- self._distribution = distribution
+ self._distribution_fn = distribution_fn
self._name = name
self._required_gpus = required_gpus
self._required_tpu = required_tpu
@@ -301,7 +301,7 @@ class NamedDistribution(object):
@property
def strategy(self):
- return self._distribution
+ return self._distribution_fn()
@property
def required_gpus(self):
@@ -312,29 +312,29 @@ class NamedDistribution(object):
return self._required_tpu
+# pylint: disable=g-long-lambda
default_strategy = NamedDistribution(
"Default",
- distribute_lib._default_distribution_strategy, # pylint: disable=protected-access
+ lambda: distribute_lib._default_distribution_strategy, # pylint: disable=protected-access
required_gpus=None)
one_device_strategy = NamedDistribution(
- "OneDeviceCPU", one_device_strategy.OneDeviceStrategy("/cpu:0"),
+ "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"),
required_gpus=None)
tpu_strategy_single_iteration = NamedDistribution(
"TPUSingleIteration",
- tpu_strategy.TPUStrategy(iterations_per_step=1),
+ lambda: tpu_lib.TPUStrategy(iterations_per_step=1),
required_tpu=True)
-tpu_strategy = NamedDistribution(
- "TPU", tpu_strategy.TPUStrategy(), required_tpu=True)
+tpu_strategy = NamedDistribution("TPU", tpu_lib.TPUStrategy, required_tpu=True)
# Note that we disable prefetching for testing since prefetching makes
# the input non-deterministic.
mirrored_strategy_with_gpu_and_cpu = NamedDistribution(
"MirroredCPUAndGPU",
- mirrored_strategy.MirroredStrategy(
+ lambda: mirrored_lib.MirroredStrategy(
["/gpu:0", "/cpu:0"], prefetch_on_device=False),
required_gpus=1)
mirrored_strategy_with_two_gpus = NamedDistribution(
"Mirrored2GPUs",
- mirrored_strategy.MirroredStrategy(
+ lambda: mirrored_lib.MirroredStrategy(
["/gpu:0", "/gpu:1"], prefetch_on_device=False),
required_gpus=2)