path: root/tensorflow/contrib/distribute/python/combinations.py
Diffstat (limited to 'tensorflow/contrib/distribute/python/combinations.py')
-rw-r--r--  tensorflow/contrib/distribute/python/combinations.py  12
1 file changed, 9 insertions, 3 deletions
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index 9a8ea4aa48..120349481f 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -46,6 +46,7 @@ import unittest
 from absl.testing import parameterized
 import six
+from tensorflow.contrib.cluster_resolver import TPUClusterResolver
 from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib
 from tensorflow.contrib.distribute.python import multi_worker_strategy
 from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib
@@ -144,7 +145,7 @@ def _augment_with_special_arguments(test_method):
"""A wrapped test method that treats some arguments in a special way."""
mode = kwargs.pop("mode", "graph")
- distribution = kwargs.pop("distribution", None)
+ distribution = kwargs.get("distribution", None)
required_tpu = kwargs.pop("required_tpu", False)
required_gpus = kwargs.pop("required_gpus", None)
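
Aside on the hunk above: kwargs.pop returns the value and removes the "distribution" key, whereas kwargs.get returns it and leaves the key in place, so the decorator can later overwrite the entry with the concrete strategy object once the right execution context is active. A minimal, standalone sketch of the difference (the dictionary contents are illustrative only):

    kwargs = {"distribution": "named_distribution", "mode": "graph"}
    value = kwargs.pop("distribution", None)   # value returned, key removed
    assert "distribution" not in kwargs

    kwargs = {"distribution": "named_distribution", "mode": "graph"}
    value = kwargs.get("distribution", None)   # value returned, key kept
    assert "distribution" in kwargs            # can be replaced with the real strategy later
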
@@ -153,7 +154,6 @@ def _augment_with_special_arguments(test_method):
"Do not use `required_gpus` and `distribution` together.")
assert required_tpu is False, (
"Do not use `required_tpu` and `distribution` together.")
- kwargs["distribution"] = distribution.strategy
required_gpus = distribution.required_gpus
required_tpu = distribution.required_tpu
@@ -189,9 +189,13 @@ def _augment_with_special_arguments(test_method):
     if mode == "eager":
       with ops.Graph().as_default(), context.eager_mode():
+        if distribution:
+          kwargs_to_pass["distribution"] = distribution.strategy
         test_method(**kwargs_to_pass)
     elif mode == "graph":
       with ops.Graph().as_default(), context.graph_mode():
+        if distribution:
+          kwargs_to_pass["distribution"] = distribution.strategy
         test_method(**kwargs_to_pass)
     else:
       raise ValueError(
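
Aside on the hunk above: the distribution.strategy lookup now happens after ops.Graph().as_default() and the eager/graph context manager have been entered. Assuming NamedDistribution.strategy evaluates its stored factory callable on access (the lambda-based definitions further down point that way), the strategy object and anything it builds during construction end up inside the test's own graph and execution mode instead of being created at decoration time. A hypothetical, simplified stand-in illustrating that lazy behavior; the real NamedDistribution lives in this file and takes more arguments:

    class LazyNamedDistribution(object):
      """Hypothetical sketch: defers strategy construction to first access."""

      def __init__(self, name, distribution_fn):
        self._name = name
        self._distribution_fn = distribution_fn  # factory, not called yet

      @property
      def strategy(self):
        # Built only here, i.e. inside whatever graph/eager context the
        # caller has already entered.
        return self._distribution_fn()

Under this sketch every access to .strategy builds a fresh object; whether the real class caches the result is not visible in this hunk.
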
@@ -321,7 +325,9 @@ default_strategy = NamedDistribution(
 one_device_strategy = NamedDistribution(
     "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"),
     required_gpus=None)
-tpu_strategy = NamedDistribution("TPU", tpu_lib.TPUStrategy, required_tpu=True)
+tpu_strategy = NamedDistribution(
+    "TPU", lambda: tpu_lib.TPUStrategy(TPUClusterResolver("")),
+    required_tpu=True)
 # Note that we disable prefetching for testing since prefetching makes
 # the input non-deterministic.
 mirrored_strategy_with_gpu_and_cpu = NamedDistribution(
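
Aside on the last hunk: TPUStrategy is now constructed with a TPUClusterResolver (imported in the first hunk) rather than referenced as a bare class, and wrapping the call in a lambda keeps construction lazy, so the resolver is only created when a test actually requests the TPU combination. A hedged usage sketch of how a test might consume it through this module's generate/combine decorators; the test class and body are illustrative:

    from absl.testing import parameterized

    from tensorflow.contrib.distribute.python import combinations
    from tensorflow.python.platform import test


    class TPUExampleTest(test.TestCase, parameterized.TestCase):

      @combinations.generate(
          combinations.combine(
              distribution=[combinations.tpu_strategy], mode=["graph"]))
      def testRunsUnderTPUStrategy(self, distribution):
        # `distribution` arrives as the concrete TPUStrategy instance,
        # built by the lambda above inside this test's graph context.
        with distribution.scope():
          self.assertIsNotNone(distribution)


    if __name__ == "__main__":
      test.main()
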