diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2018-08-22 16:15:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-22 16:34:16 -0700 |
commit | 8c6623137c928e20cd2b54471b06582fa118ad9a (patch) | |
tree | b97806d292aea9be9b8651005cdc231f7bc92ef0 /tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py | |
parent | 05dfd7079d8e3d4487effc6848e4f323eca9ba37 (diff) |
Implemented the configure method and properties needed by distribute coordinator in MirroredStrategy.
PiperOrigin-RevId: 209848375
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 9a4cc0a897..612655a38a 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import sys from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.contrib.distribute.python import values from tensorflow.core.protobuf import config_pb2 @@ -41,6 +42,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import device_util from tensorflow.python.training import distribution_strategy_context +from tensorflow.python.training import server_lib GPU_TEST = "test_gpu" in sys.argv[0] @@ -1244,5 +1246,21 @@ class MirroredStrategyDefunTest(test.TestCase): self._call_and_check(fn1, [factors], expected_result, [fn1]) +class MultiWorkerMirroredStrategyTest( + multi_worker_test_base.MultiWorkerTestBase, + strategy_test_lib.DistributionTestBase): + + def _get_distribution_strategy(self): + cluster_spec = server_lib.ClusterSpec({ + "worker": ["/job:worker/task:0", "/job:worker/task:1"] + }) + strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus()) + strategy.configure(cluster_spec=cluster_spec) + return strategy + + def testMinimizeLossGraph(self): + self._test_minimize_loss_graph(self._get_distribution_strategy()) + + if __name__ == "__main__": test.main() |