author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-05-16 12:04:13 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-05-16 12:07:19 -0700
commit    6f76e5c985246484c95ce48776685ebe3d40f74f (patch)
tree      d16880435c286c87694c7019443e786e12794d6c
parent    01ed446a17f4b6cb3a9e4ffd2c39641b9e96ed96 (diff)
Turn off MirroredStrategy Dataset prefetching in tests when using the
combinations library. Prefetching adds some small non-determinism to the
input batches, which can make tests flaky. Also add a default
DistributionStrategy combination.

PiperOrigin-RevId: 196866569
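For orientation, a hedged sketch of how a test might opt into the new default
combination through this library. The test class and body below are
illustrative, not from this commit; combinations.generate, combinations.combine,
and the named strategy combinations are the entry points the patched module
exposes:

    # Illustrative only: run one test body under several strategies,
    # including the default_strategy combination added by this commit.
    from absl.testing import parameterized
    from tensorflow.contrib.distribute.python import combinations
    from tensorflow.python.platform import test

    class StepTest(test.TestCase, parameterized.TestCase):

      @combinations.generate(
          combinations.combine(
              distribution=[
                  combinations.default_strategy,
                  combinations.one_device_strategy,
                  combinations.mirrored_strategy_with_gpu_and_cpu,
              ],
              mode=["graph"]))
      def testStep(self, distribution):
        with distribution.scope():
          pass  # build and run the replicated computation here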
-rw-r--r--  tensorflow/contrib/distribute/python/BUILD                  |  1
-rw-r--r--  tensorflow/contrib/distribute/python/combinations.py        | 19
-rw-r--r--  tensorflow/contrib/distribute/python/minimize_loss_test.py  |  8
3 files changed, 18 insertions(+), 10 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 340ffaee58..6c5c49d777 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -151,6 +151,7 @@ py_library(
":one_device_strategy",
":tpu_strategy",
"//tensorflow/contrib/optimizer_v2:training",
+ "//tensorflow/python:distribute",
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index 45d191127e..d719234cf6 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -51,6 +51,7 @@ from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.training import adam
+from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import gradient_descent
from tensorflow.python.util import tf_inspect
@@ -262,25 +263,31 @@ class NamedDistribution(object):
return self._required_tpu
+default_strategy = NamedDistribution(
+ "Default",
+ distribute_lib._default_distribution_strategy, # pylint: disable=protected-access
+ required_gpus=None)
one_device_strategy = NamedDistribution(
"OneDeviceCPU", one_device_strategy.OneDeviceStrategy("/cpu:0"),
- None)
+ required_gpus=None)
tpu_strategy_single_iteration = NamedDistribution(
"TPUSingleIteration",
tpu_strategy.TPUStrategy(iterations_per_step=1),
required_tpu=True)
tpu_strategy = NamedDistribution(
"TPU", tpu_strategy.TPUStrategy(), required_tpu=True)
+# Note that we disable prefetching for testing since prefetching makes
+# the input non-deterministic.
mirrored_strategy_with_gpu_and_cpu = NamedDistribution(
"MirroredCPUAndGPU",
- mirrored_strategy.MirroredStrategy(["/gpu:0", "/cpu:0"]), 1)
-mirrored_strategy_without_prefetch = NamedDistribution(
- "MirroredCPUAndGPUNoPrefetch",
mirrored_strategy.MirroredStrategy(
- ["/gpu:0", "/cpu:0"], prefetch_on_device=False), 1)
+ ["/gpu:0", "/cpu:0"], prefetch_on_device=False),
+ required_gpus=1)
mirrored_strategy_with_two_gpus = NamedDistribution(
"Mirrored2GPUs",
- mirrored_strategy.MirroredStrategy(["/gpu:0", "/gpu:1"]), 2)
+ mirrored_strategy.MirroredStrategy(
+ ["/gpu:0", "/gpu:1"], prefetch_on_device=False),
+ required_gpus=2)
adam_optimizer_v1_fn = NamedObject(
"AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1))
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index d2054715f1..5c056a7c73 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -207,11 +207,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
renorm=renorm,
update_ops_in_tower_mode=not update_ops_in_cross_tower_mode)
- # Disable prefetching since that makes the specific input on each device
- # to be non deterministic, and this test relies on specific input being
- # on each device.
+ # Make sure prefetching is disabled, since prefetching makes the
+ # specific input on each device non-deterministic, and this test
+ # relies on specific inputs being on each device.
if isinstance(distribution, mirrored_strategy.MirroredStrategy):
- distribution._prefetch_on_device = False
+ self.assertFalse(distribution._prefetch_on_device)
iterator = distribution.distribute_dataset(
dataset_fn).make_one_shot_iterator()
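A hedged sketch of the input pattern this hunk exercises, assuming a
distribution strategy obtained from one of the combinations above (the
dataset contents are invented for illustration). With prefetch_on_device=False,
every run sees the same per-device batch order, which is what lets the test
assert on specific inputs:

    import tensorflow as tf

    def dataset_fn():
      # Small, fixed dataset so the expected per-device batches are known.
      return tf.data.Dataset.from_tensor_slices(
          [[1.], [2.], [3.], [4.]]).batch(2)

    # Same call pattern as the test above: distribute the dataset across
    # devices and pull deterministic per-device inputs from the iterator.
    iterator = distribution.distribute_dataset(
        dataset_fn).make_one_shot_iterator()
    batch = iterator.get_next()  # identical per-device inputs on every run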