author     Yuefeng Zhou <yuefengz@google.com>  2018-08-22 16:15:37 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-08-22 16:34:16 -0700
commit     8c6623137c928e20cd2b54471b06582fa118ad9a
tree       b97806d292aea9be9b8651005cdc231f7bc92ef0 /tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
parent     05dfd7079d8e3d4487effc6848e4f323eca9ba37
Implemented the configure method and the properties needed by the distribute coordinator in MirroredStrategy.
PiperOrigin-RevId: 209848375
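
For illustration, a minimal sketch of how the configure method added by this change is driven, mirroring the new test in the diff below; it assumes the TF 1.x contrib environment this patch targets, and the two-worker task names are the placeholders used in the test rather than real worker addresses:

from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.python.eager import context
from tensorflow.python.training import server_lib

# Describe a two-worker cluster for the distribute coordinator (placeholder
# task names taken from the test added in this patch).
cluster_spec = server_lib.ClusterSpec({
    "worker": ["/job:worker/task:0", "/job:worker/task:1"]
})

# Mirror variables across the local GPUs, then hand the cluster layout to the
# strategy through the newly implemented configure() method.
strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus())
strategy.configure(cluster_spec=cluster_spec)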
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py')
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py | 18
1 file changed, 18 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 9a4cc0a897..612655a38a 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import sys
from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import strategy_test_lib
from tensorflow.contrib.distribute.python import values
from tensorflow.core.protobuf import config_pb2
@@ -41,6 +42,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import device_util
from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import server_lib
GPU_TEST = "test_gpu" in sys.argv[0]
@@ -1244,5 +1246,21 @@ class MirroredStrategyDefunTest(test.TestCase):
self._call_and_check(fn1, [factors], expected_result, [fn1])
+class MultiWorkerMirroredStrategyTest(
+    multi_worker_test_base.MultiWorkerTestBase,
+    strategy_test_lib.DistributionTestBase):
+
+  def _get_distribution_strategy(self):
+    cluster_spec = server_lib.ClusterSpec({
+        "worker": ["/job:worker/task:0", "/job:worker/task:1"]
+    })
+    strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus())
+    strategy.configure(cluster_spec=cluster_spec)
+    return strategy
+
+  def testMinimizeLossGraph(self):
+    self._test_minimize_loss_graph(self._get_distribution_strategy())
+
+
if __name__ == "__main__":
  test.main()