aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distribute/python/parameter_server_strategy_test.py')
-rw-r--r--tensorflow/contrib/distribute/python/parameter_server_strategy_test.py11
1 files changed, 7 insertions, 4 deletions
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
index cf29c0ed91..02eb68227d 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
@@ -37,7 +37,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import device_util
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
@@ -101,7 +101,8 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
last_part_device = 'device:CPU:0'
else:
last_part_device = (
- 'device:GPU:%d' % distribute_lib.get_tower_context().tower_id)
+ 'device:GPU:%d' %
+ distribution_strategy_context.get_tower_context().tower_id)
a = constant_op.constant(1.0)
b = constant_op.constant(2.0)
@@ -192,14 +193,16 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
tower_compute_device = '/device:CPU:0'
else:
tower_compute_device = (
- '/device:GPU:%d' % distribute_lib.get_tower_context().tower_id)
+ '/device:GPU:%d' %
+ distribution_strategy_context.get_tower_context().tower_id)
tower_compute_device = device_util.canonicalize(tower_compute_device)
if 'CPU' in variable_device:
tower_variable_device = '/device:CPU:0'
else:
tower_variable_device = (
- '/device:GPU:%d' % distribute_lib.get_tower_context().tower_id)
+ '/device:GPU:%d' %
+ distribution_strategy_context.get_tower_context().tower_id)
tower_variable_device = device_util.canonicalize(tower_variable_device)
a = constant_op.constant(1.0)