diff options
Diffstat (limited to 'tensorflow/contrib/distribute/python/parameter_server_strategy_test.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/parameter_server_strategy_test.py | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index cf29c0ed91..02eb68227d 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -37,7 +37,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import device_util -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, @@ -101,7 +101,8 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, last_part_device = 'device:CPU:0' else: last_part_device = ( - 'device:GPU:%d' % distribute_lib.get_tower_context().tower_id) + 'device:GPU:%d' % + distribution_strategy_context.get_tower_context().tower_id) a = constant_op.constant(1.0) b = constant_op.constant(2.0) @@ -192,14 +193,16 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, tower_compute_device = '/device:CPU:0' else: tower_compute_device = ( - '/device:GPU:%d' % distribute_lib.get_tower_context().tower_id) + '/device:GPU:%d' % + distribution_strategy_context.get_tower_context().tower_id) tower_compute_device = device_util.canonicalize(tower_compute_device) if 'CPU' in variable_device: tower_variable_device = '/device:CPU:0' else: tower_variable_device = ( - '/device:GPU:%d' % distribute_lib.get_tower_context().tower_id) + '/device:GPU:%d' % + distribution_strategy_context.get_tower_context().tower_id) tower_variable_device = device_util.canonicalize(tower_variable_device) a = constant_op.constant(1.0) |