aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy_test.py')
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_test.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
index a066adf124..5db2fff239 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
@@ -24,7 +24,7 @@ from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variable_scope
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase):
@@ -68,7 +68,8 @@ class VariableCreatorStackTest(test.TestCase):
v = variable_scope.variable(1.0)
# This will pause the current thread, and execute the other thread.
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(
+ lambda _: _)
return v
def main_thread_creator(next_creator, *args, **kwargs):