diff options
Diffstat (limited to 'tensorflow/python/training/slot_creator.py')
-rw-r--r-- | tensorflow/python/training/slot_creator.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/python/training/slot_creator.py b/tensorflow/python/training/slot_creator.py index 258a6f045d..d76b22acd8 100644 --- a/tensorflow/python/training/slot_creator.py +++ b/tensorflow/python/training/slot_creator.py @@ -45,7 +45,7 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context def _create_slot_var(primary, val, scope, validate_shape, shape, dtype): @@ -112,7 +112,8 @@ def create_slot(primary, val, name, colocate_with_primary=True): prefix = primary.op.name with variable_scope.variable_scope(None, prefix + "/" + name): if colocate_with_primary: - distribution_strategy = distribute_lib.get_distribution_strategy() + distribution_strategy = ( + distribution_strategy_context.get_distribution_strategy()) with distribution_strategy.colocate_vars_with(primary): return _create_slot_var(primary, val, "", validate_shape, None, None) else: @@ -149,7 +150,8 @@ def create_slot_with_initializer(primary, initializer, shape, dtype, name, prefix = primary.op.name with variable_scope.variable_scope(None, prefix + "/" + name): if colocate_with_primary: - distribution_strategy = distribute_lib.get_distribution_strategy() + distribution_strategy = ( + distribution_strategy_context.get_distribution_strategy()) with distribution_strategy.colocate_vars_with(primary): return _create_slot_var(primary, initializer, "", validate_shape, shape, dtype) |