aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/slot_creator.py
diff options
context:
space:
mode:
authorGravatar Avijit <Avijit.Chakraborty@intel.com>2018-08-15 17:00:22 -0700
committerGravatar Avijit <Avijit.Chakraborty@intel.com>2018-08-15 17:00:22 -0700
commitbc6be507c71046dfc889a90e3949a903d5d1e6eb (patch)
tree84557e7bb7798e3d418a619c8452aa7baf78f255 /tensorflow/python/training/slot_creator.py
parent9523a98466d16cf01fc76a67b489f1124cf626ac (diff)
parentd2875ea71373d05c645587a83dd870fa8a0ec070 (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'tensorflow/python/training/slot_creator.py')
-rw-r--r--tensorflow/python/training/slot_creator.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/python/training/slot_creator.py b/tensorflow/python/training/slot_creator.py
index 258a6f045d..d76b22acd8 100644
--- a/tensorflow/python/training/slot_creator.py
+++ b/tensorflow/python/training/slot_creator.py
@@ -45,7 +45,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
def _create_slot_var(primary, val, scope, validate_shape, shape, dtype):
@@ -112,7 +112,8 @@ def create_slot(primary, val, name, colocate_with_primary=True):
prefix = primary.op.name
with variable_scope.variable_scope(None, prefix + "/" + name):
if colocate_with_primary:
- distribution_strategy = distribute_lib.get_distribution_strategy()
+ distribution_strategy = (
+ distribution_strategy_context.get_distribution_strategy())
with distribution_strategy.colocate_vars_with(primary):
return _create_slot_var(primary, val, "", validate_shape, None, None)
else:
@@ -149,7 +150,8 @@ def create_slot_with_initializer(primary, initializer, shape, dtype, name,
prefix = primary.op.name
with variable_scope.variable_scope(None, prefix + "/" + name):
if colocate_with_primary:
- distribution_strategy = distribute_lib.get_distribution_strategy()
+ distribution_strategy = (
+ distribution_strategy_context.get_distribution_strategy())
with distribution_strategy.colocate_vars_with(primary):
return _create_slot_var(primary, initializer, "", validate_shape, shape,
dtype)