aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/strategy_test_lib.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distribute/python/strategy_test_lib.py')
-rw-r--r--tensorflow/contrib/distribute/python/strategy_test_lib.py12
1 files changed, 7 insertions, 5 deletions
diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py
index baed0ebaae..371b97ba96 100644
--- a/tensorflow/contrib/distribute/python/strategy_test_lib.py
+++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py
@@ -28,7 +28,7 @@ from tensorflow.python.layers import core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import optimizer
@@ -45,7 +45,8 @@ def _raise_exception_fn(_=None):
# Must be the argument to a distribution.call_for_each_tower() call, calls a
# get_tower_context().merge_call() that raises an exception.
def _merge_raises_fn():
- distribute_lib.get_tower_context().merge_call(_raise_exception_fn)
+ distribution_strategy_context.get_tower_context().merge_call(
+ _raise_exception_fn)
# Must be the argument to a get_tower_context().merge_call() call, calls
@@ -58,7 +59,7 @@ def _call_raises_fn(dist):
# calls a get_tower_context().merge_call() that calls a
# call_for_each_tower() that raises an exception.
def _merge_call_raises_fn():
- distribute_lib.get_tower_context().merge_call(_call_raises_fn)
+ distribution_strategy_context.get_tower_context().merge_call(_call_raises_fn)
# Must be the argument to a get_tower_context().merge_call() call, calls
@@ -72,7 +73,8 @@ def _call_merge_raises_fn(dist):
# get_tower_context().merge_call() that calls a call_for_each_tower() that
# calls a get_tower_context().merge_call() that raises an exception.
def _merge_call_merge_raises_fn():
- distribute_lib.get_tower_context().merge_call(_call_merge_raises_fn)
+ distribution_strategy_context.get_tower_context().merge_call(
+ _call_merge_raises_fn)
class DistributionTestBase(test.TestCase):
@@ -208,7 +210,7 @@ class DistributionTestBase(test.TestCase):
expected_devices = [False] * len(d.worker_devices)
def mark_devices_fn():
- tower_id = distribute_lib.get_tower_context().tower_id
+ tower_id = distribution_strategy_context.get_tower_context().tower_id
self.assertLess(tower_id, len(d.worker_devices))
self.assertFalse(expected_devices[tower_id])
expected_devices[tower_id] = True