Diffstat (limited to 'tensorflow/python/training/optimizer.py')
-rw-r--r--  tensorflow/python/training/optimizer.py | 16 ++++++++++------
1 file changed, 10 insertions(+), 6 deletions(-)
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 6d95b144d5..1b6bce2865 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
 from tensorflow.python.training import slot_creator
 from tensorflow.python.training.checkpointable import base as checkpointable
 from tensorflow.python.util import nest
@@ -464,7 +465,8 @@ class Optimizer(
         # TODO(josh11b): Test that we handle weight decay in a reasonable way.
         if (distribute_lib.get_loss_reduction() ==
             variable_scope.VariableAggregation.MEAN):
-          num_towers = distribute_lib.get_distribution_strategy().num_towers
+          num_towers = distribution_strategy_context.get_distribution_strategy(
+          ).num_towers
           if num_towers > 1:
             loss_value *= (1. / num_towers)
 
@@ -482,7 +484,8 @@ class Optimizer(
     # Scale loss if using a "mean" loss reduction and multiple towers.
     if (distribute_lib.get_loss_reduction() ==
         variable_scope.VariableAggregation.MEAN):
-      num_towers = distribute_lib.get_distribution_strategy().num_towers
+      num_towers = distribution_strategy_context.get_distribution_strategy(
+      ).num_towers
       if num_towers > 1:
         loss *= (1. / num_towers)
 
@@ -548,15 +551,15 @@ class Optimizer(
     # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().
 
     # Handle DistributionStrategy case.
-    if distribute_lib.get_cross_tower_context():
+    if distribution_strategy_context.get_cross_tower_context():
       raise RuntimeError("Use `_distributed_apply()` instead of "
                          "`apply_gradients()` in a cross-tower context.")
     # TODO(isaprykin): Get rid of `has_distribution_strategy()` check by
     # always calling _distributed_apply(), using the default distribution
     # as needed.
-    if distribute_lib.has_distribution_strategy():
+    if distribution_strategy_context.has_distribution_strategy():
       grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
-      return distribute_lib.get_tower_context().merge_call(
+      return distribution_strategy_context.get_tower_context().merge_call(
           self._distributed_apply, grads_and_vars, global_step, name)
 
     # No DistributionStrategy case.
@@ -799,7 +802,8 @@ class Optimizer(
     v = self._non_slot_dict.get(key, None)
     if v is None:
       self._maybe_initialize_checkpointable()
-      distribution_strategy = distribute_lib.get_distribution_strategy()
+      distribution_strategy = (
+          distribution_strategy_context.get_distribution_strategy())
       with distribution_strategy.colocate_vars_with(colocate_with):
         if eager:
           restored_initial_value = self._preload_simple_restoration(
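
For context, the pattern these hunks migrate looks like the following when pulled out on its own. This is a minimal sketch based solely on the code visible in the diff (TF 1.x internal modules); _scale_loss is a hypothetical helper name introduced here for illustration, not part of the file:

# Minimal sketch of the loss-scaling logic touched above: with a "mean"
# loss reduction and more than one tower, the loss is rescaled by
# 1/num_towers. Imports follow the module paths shown in the diff.
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import distribution_strategy_context

def _scale_loss(loss):  # hypothetical helper, for illustration only
  if (distribute_lib.get_loss_reduction() ==
      variable_scope.VariableAggregation.MEAN):
    num_towers = distribution_strategy_context.get_distribution_strategy(
        ).num_towers
    if num_towers > 1:
      loss *= (1. / num_towers)
  return loss

Note that the change is purely a module move: the accessor functions (get_distribution_strategy(), get_cross_tower_context(), has_distribution_strategy(), get_tower_context()) keep their names and call sites but are now imported from distribution_strategy_context rather than distribute_lib, while get_loss_reduction() remains in distribute_lib.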