diff options
Diffstat (limited to 'tensorflow/python/training/distribute_test.py')
-rw-r--r-- | tensorflow/python/training/distribute_test.py | 53 |
1 files changed, 32 insertions, 21 deletions
diff --git a/tensorflow/python/training/distribute_test.py b/tensorflow/python/training/distribute_test.py index 694145ede7..f03bd39100 100644 --- a/tensorflow/python/training/distribute_test.py +++ b/tensorflow/python/training/distribute_test.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.training import distribute +from tensorflow.python.training import distribution_strategy_context class _TestTowerContext(distribute.TowerContext): @@ -49,12 +50,12 @@ class _TestStrategy(distribute.DistributionStrategy): def _assert_in_default_state(t): - t.assertIs(distribute._default_tower_context, - distribute.get_tower_context()) - t.assertIs(None, distribute.get_cross_tower_context()) - t.assertIs(distribute._default_distribution_strategy, - distribute.get_distribution_strategy()) - t.assertFalse(distribute.has_distribution_strategy()) + t.assertIs(distribution_strategy_context._get_default_tower_context(), + distribution_strategy_context.get_tower_context()) + t.assertIs(None, distribution_strategy_context.get_cross_tower_context()) + t.assertIs(distribution_strategy_context._get_default_distribution_strategy(), + distribution_strategy_context.get_distribution_strategy()) + t.assertFalse(distribution_strategy_context.has_distribution_strategy()) class TestStrategyTest(test.TestCase): @@ -64,11 +65,13 @@ class TestStrategyTest(test.TestCase): dist = _TestStrategy() def run_fn(): - tower_context = distribute.get_tower_context() + tower_context = distribution_strategy_context.get_tower_context() self.assertTrue(tower_context is not None) - self.assertIs(None, distribute.get_cross_tower_context()) - self.assertTrue(distribute.has_distribution_strategy()) - self.assertIs(dist, distribute.get_distribution_strategy()) + self.assertIs(None, + distribution_strategy_context.get_cross_tower_context()) + self.assertTrue(distribution_strategy_context.has_distribution_strategy()) + self.assertIs(dist, + distribution_strategy_context.get_distribution_strategy()) self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo")) expected_value = _get_test_variable( "bar", variable_scope.VariableSynchronization.AUTO, @@ -86,10 +89,12 @@ class TestStrategyTest(test.TestCase): _assert_in_default_state(self) dist = _TestStrategy() with dist.scope(): - self.assertIs(None, distribute.get_tower_context()) - self.assertIs(dist, distribute.get_cross_tower_context()) - self.assertTrue(distribute.has_distribution_strategy()) - self.assertIs(dist, distribute.get_distribution_strategy()) + self.assertIs(None, distribution_strategy_context.get_tower_context()) + self.assertIs(dist, + distribution_strategy_context.get_cross_tower_context()) + self.assertTrue(distribution_strategy_context.has_distribution_strategy()) + self.assertIs(dist, + distribution_strategy_context.get_distribution_strategy()) expected_value = _get_test_variable( "baz", variable_scope.VariableSynchronization.AUTO, variable_scope.VariableAggregation.NONE) @@ -120,15 +125,21 @@ class DefaultDistributionStrategyTest(test.TestCase): _assert_in_default_state(self) def merge_fn(dist, s): - self.assertIs(distribute._default_distribution_strategy, dist) - self.assertIs(None, distribute.get_tower_context()) - self.assertIs(dist, distribute.get_cross_tower_context()) - self.assertIs(dist, distribute.get_distribution_strategy()) - self.assertFalse(distribute.has_distribution_strategy()) + self.assertIs( + distribution_strategy_context._get_default_distribution_strategy(), + dist) + self.assertIs(None, distribution_strategy_context.get_tower_context()) + self.assertIs(dist, + distribution_strategy_context.get_cross_tower_context()) + self.assertIs(dist, + distribution_strategy_context.get_distribution_strategy()) + self.assertFalse( + distribution_strategy_context.has_distribution_strategy()) return "foo_" + s - tower_ctx = distribute.get_tower_context() - self.assertIs(distribute._default_tower_context, tower_ctx) + tower_ctx = distribution_strategy_context.get_tower_context() + self.assertIs(distribution_strategy_context._get_default_tower_context(), + tower_ctx) self.assertEqual("foo_bar", tower_ctx.merge_call(merge_fn, "bar")) _assert_in_default_state(self) |