diff options
author | Avijit <Avijit.Chakraborty@intel.com> | 2018-08-15 17:00:22 -0700 |
---|---|---|
committer | Avijit <Avijit.Chakraborty@intel.com> | 2018-08-15 17:00:22 -0700 |
commit | bc6be507c71046dfc889a90e3949a903d5d1e6eb (patch) | |
tree | 84557e7bb7798e3d418a619c8452aa7baf78f255 /tensorflow/python/training/slot_creator.py | |
parent | 9523a98466d16cf01fc76a67b489f1124cf626ac (diff) | |
parent | d2875ea71373d05c645587a83dd870fa8a0ec070 (diff) |
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'tensorflow/python/training/slot_creator.py')
-rw-r--r-- | tensorflow/python/training/slot_creator.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/python/training/slot_creator.py b/tensorflow/python/training/slot_creator.py index 258a6f045d..d76b22acd8 100644 --- a/tensorflow/python/training/slot_creator.py +++ b/tensorflow/python/training/slot_creator.py @@ -45,7 +45,7 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context def _create_slot_var(primary, val, scope, validate_shape, shape, dtype): @@ -112,7 +112,8 @@ def create_slot(primary, val, name, colocate_with_primary=True): prefix = primary.op.name with variable_scope.variable_scope(None, prefix + "/" + name): if colocate_with_primary: - distribution_strategy = distribute_lib.get_distribution_strategy() + distribution_strategy = ( + distribution_strategy_context.get_distribution_strategy()) with distribution_strategy.colocate_vars_with(primary): return _create_slot_var(primary, val, "", validate_shape, None, None) else: @@ -149,7 +150,8 @@ def create_slot_with_initializer(primary, initializer, shape, dtype, name, prefix = primary.op.name with variable_scope.variable_scope(None, prefix + "/" + name): if colocate_with_primary: - distribution_strategy = distribute_lib.get_distribution_strategy() + distribution_strategy = ( + distribution_strategy_context.get_distribution_strategy()) with distribution_strategy.colocate_vars_with(primary): return _create_slot_var(primary, initializer, "", validate_shape, shape, dtype) |