diff options
Diffstat (limited to 'tensorflow/contrib/distribute/python/strategy_test_lib.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/strategy_test_lib.py | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index baed0ebaae..371b97ba96 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -28,7 +28,7 @@ from tensorflow.python.layers import core from tensorflow.python.ops import array_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import optimizer @@ -45,7 +45,8 @@ def _raise_exception_fn(_=None): # Must be the argument to a distribution.call_for_each_tower() call, calls a # get_tower_context().merge_call() that raises an exception. def _merge_raises_fn(): - distribute_lib.get_tower_context().merge_call(_raise_exception_fn) + distribution_strategy_context.get_tower_context().merge_call( + _raise_exception_fn) # Must be the argument to a get_tower_context().merge_call() call, calls @@ -58,7 +59,7 @@ def _call_raises_fn(dist): # calls a get_tower_context().merge_call() that calls a # call_for_each_tower() that raises an exception. def _merge_call_raises_fn(): - distribute_lib.get_tower_context().merge_call(_call_raises_fn) + distribution_strategy_context.get_tower_context().merge_call(_call_raises_fn) # Must be the argument to a get_tower_context().merge_call() call, calls @@ -72,7 +73,8 @@ def _call_merge_raises_fn(dist): # get_tower_context().merge_call() that calls a call_for_each_tower() that # calls a get_tower_context().merge_call() that raises an exception. def _merge_call_merge_raises_fn(): - distribute_lib.get_tower_context().merge_call(_call_merge_raises_fn) + distribution_strategy_context.get_tower_context().merge_call( + _call_merge_raises_fn) class DistributionTestBase(test.TestCase): @@ -208,7 +210,7 @@ class DistributionTestBase(test.TestCase): expected_devices = [False] * len(d.worker_devices) def mark_devices_fn(): - tower_id = distribute_lib.get_tower_context().tower_id + tower_id = distribution_strategy_context.get_tower_context().tower_id self.assertLess(tower_id, len(d.worker_devices)) self.assertFalse(expected_devices[tower_id]) expected_devices[tower_id] = True |