aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/distribute_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/distribute_test.py')
-rw-r--r--tensorflow/python/training/distribute_test.py53
1 files changed, 32 insertions, 21 deletions
diff --git a/tensorflow/python/training/distribute_test.py b/tensorflow/python/training/distribute_test.py
index 694145ede7..f03bd39100 100644
--- a/tensorflow/python/training/distribute_test.py
+++ b/tensorflow/python/training/distribute_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.training import distribute
+from tensorflow.python.training import distribution_strategy_context
class _TestTowerContext(distribute.TowerContext):
@@ -49,12 +50,12 @@ class _TestStrategy(distribute.DistributionStrategy):
def _assert_in_default_state(t):
- t.assertIs(distribute._default_tower_context,
- distribute.get_tower_context())
- t.assertIs(None, distribute.get_cross_tower_context())
- t.assertIs(distribute._default_distribution_strategy,
- distribute.get_distribution_strategy())
- t.assertFalse(distribute.has_distribution_strategy())
+ t.assertIs(distribution_strategy_context._get_default_tower_context(),
+ distribution_strategy_context.get_tower_context())
+ t.assertIs(None, distribution_strategy_context.get_cross_tower_context())
+ t.assertIs(distribution_strategy_context._get_default_distribution_strategy(),
+ distribution_strategy_context.get_distribution_strategy())
+ t.assertFalse(distribution_strategy_context.has_distribution_strategy())
class TestStrategyTest(test.TestCase):
@@ -64,11 +65,13 @@ class TestStrategyTest(test.TestCase):
dist = _TestStrategy()
def run_fn():
- tower_context = distribute.get_tower_context()
+ tower_context = distribution_strategy_context.get_tower_context()
self.assertTrue(tower_context is not None)
- self.assertIs(None, distribute.get_cross_tower_context())
- self.assertTrue(distribute.has_distribution_strategy())
- self.assertIs(dist, distribute.get_distribution_strategy())
+ self.assertIs(None,
+ distribution_strategy_context.get_cross_tower_context())
+ self.assertTrue(distribution_strategy_context.has_distribution_strategy())
+ self.assertIs(dist,
+ distribution_strategy_context.get_distribution_strategy())
self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo"))
expected_value = _get_test_variable(
"bar", variable_scope.VariableSynchronization.AUTO,
@@ -86,10 +89,12 @@ class TestStrategyTest(test.TestCase):
_assert_in_default_state(self)
dist = _TestStrategy()
with dist.scope():
- self.assertIs(None, distribute.get_tower_context())
- self.assertIs(dist, distribute.get_cross_tower_context())
- self.assertTrue(distribute.has_distribution_strategy())
- self.assertIs(dist, distribute.get_distribution_strategy())
+ self.assertIs(None, distribution_strategy_context.get_tower_context())
+ self.assertIs(dist,
+ distribution_strategy_context.get_cross_tower_context())
+ self.assertTrue(distribution_strategy_context.has_distribution_strategy())
+ self.assertIs(dist,
+ distribution_strategy_context.get_distribution_strategy())
expected_value = _get_test_variable(
"baz", variable_scope.VariableSynchronization.AUTO,
variable_scope.VariableAggregation.NONE)
@@ -120,15 +125,21 @@ class DefaultDistributionStrategyTest(test.TestCase):
_assert_in_default_state(self)
def merge_fn(dist, s):
- self.assertIs(distribute._default_distribution_strategy, dist)
- self.assertIs(None, distribute.get_tower_context())
- self.assertIs(dist, distribute.get_cross_tower_context())
- self.assertIs(dist, distribute.get_distribution_strategy())
- self.assertFalse(distribute.has_distribution_strategy())
+ self.assertIs(
+ distribution_strategy_context._get_default_distribution_strategy(),
+ dist)
+ self.assertIs(None, distribution_strategy_context.get_tower_context())
+ self.assertIs(dist,
+ distribution_strategy_context.get_cross_tower_context())
+ self.assertIs(dist,
+ distribution_strategy_context.get_distribution_strategy())
+ self.assertFalse(
+ distribution_strategy_context.has_distribution_strategy())
return "foo_" + s
- tower_ctx = distribute.get_tower_context()
- self.assertIs(distribute._default_tower_context, tower_ctx)
+ tower_ctx = distribution_strategy_context.get_tower_context()
+ self.assertIs(distribution_strategy_context._get_default_tower_context(),
+ tower_ctx)
self.assertEqual("foo_bar", tower_ctx.merge_call(merge_fn, "bar"))
_assert_in_default_state(self)