aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/slot_creator.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/slot_creator.py')
-rw-r--r--tensorflow/python/training/slot_creator.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/python/training/slot_creator.py b/tensorflow/python/training/slot_creator.py
index 258a6f045d..d76b22acd8 100644
--- a/tensorflow/python/training/slot_creator.py
+++ b/tensorflow/python/training/slot_creator.py
@@ -45,7 +45,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
def _create_slot_var(primary, val, scope, validate_shape, shape, dtype):
@@ -112,7 +112,8 @@ def create_slot(primary, val, name, colocate_with_primary=True):
prefix = primary.op.name
with variable_scope.variable_scope(None, prefix + "/" + name):
if colocate_with_primary:
- distribution_strategy = distribute_lib.get_distribution_strategy()
+ distribution_strategy = (
+ distribution_strategy_context.get_distribution_strategy())
with distribution_strategy.colocate_vars_with(primary):
return _create_slot_var(primary, val, "", validate_shape, None, None)
else:
@@ -149,7 +150,8 @@ def create_slot_with_initializer(primary, initializer, shape, dtype, name,
prefix = primary.op.name
with variable_scope.variable_scope(None, prefix + "/" + name):
if colocate_with_primary:
- distribution_strategy = distribute_lib.get_distribution_strategy()
+ distribution_strategy = (
+ distribution_strategy_context.get_distribution_strategy())
with distribution_strategy.colocate_vars_with(primary):
return _create_slot_var(primary, initializer, "", validate_shape, shape,
dtype)