Diffstat (limited to 'tensorflow/python/training/optimizer.py')
-rw-r--r--  tensorflow/python/training/optimizer.py  16
1 file changed, 10 insertions(+), 6 deletions(-)
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 6d95b144d5..1b6bce2865 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import slot_creator
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util import nest
@@ -464,7 +465,8 @@ class Optimizer(
# TODO(josh11b): Test that we handle weight decay in a reasonable way.
if (distribute_lib.get_loss_reduction() ==
variable_scope.VariableAggregation.MEAN):
- num_towers = distribute_lib.get_distribution_strategy().num_towers
+ num_towers = distribution_strategy_context.get_distribution_strategy(
+ ).num_towers
if num_towers > 1:
loss_value *= (1. / num_towers)
@@ -482,7 +484,8 @@ class Optimizer(
# Scale loss if using a "mean" loss reduction and multiple towers.
if (distribute_lib.get_loss_reduction() ==
variable_scope.VariableAggregation.MEAN):
- num_towers = distribute_lib.get_distribution_strategy().num_towers
+ num_towers = distribution_strategy_context.get_distribution_strategy(
+ ).num_towers
if num_towers > 1:
loss *= (1. / num_towers)
@@ -548,15 +551,15 @@ class Optimizer(
# methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().
# Handle DistributionStrategy case.
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
raise RuntimeError("Use `_distributed_apply()` instead of "
"`apply_gradients()` in a cross-tower context.")
# TODO(isaprykin): Get rid of `has_distribution_strategy()` check by
# always calling _distributed_apply(), using the default distribution
# as needed.
- if distribute_lib.has_distribution_strategy():
+ if distribution_strategy_context.has_distribution_strategy():
grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
- return distribute_lib.get_tower_context().merge_call(
+ return distribution_strategy_context.get_tower_context().merge_call(
self._distributed_apply, grads_and_vars, global_step, name)
# No DistributionStrategy case.
@@ -799,7 +802,8 @@ class Optimizer(
v = self._non_slot_dict.get(key, None)
if v is None:
self._maybe_initialize_checkpointable()
- distribution_strategy = distribute_lib.get_distribution_strategy()
+ distribution_strategy = (
+ distribution_strategy_context.get_distribution_strategy())
with distribution_strategy.colocate_vars_with(colocate_with):
if eager:
restored_initial_value = self._preload_simple_restoration(
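
In short, every strategy/context lookup that previously went through distribute_lib now goes through the new distribution_strategy_context module; the function names themselves are unchanged. A minimal sketch of the call pattern after this change, illustrative only and based solely on the calls visible in the diff above:

from tensorflow.python.training import distribution_strategy_context

# Strategy lookup now routed through distribution_strategy_context.
strategy = distribution_strategy_context.get_distribution_strategy()

# Same pattern for the context accessors used in apply_gradients().
if distribution_strategy_context.get_cross_tower_context():
  pass  # cross-tower context: the optimizer expects _distributed_apply() here
elif distribution_strategy_context.has_distribution_strategy():
  tower_context = distribution_strategy_context.get_tower_context()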