Diffstat (limited to 'tensorflow/contrib/optimizer_v2/optimizer_v2.py')
-rw-r--r--  tensorflow/contrib/optimizer_v2/optimizer_v2.py | 11
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index 8c11d8bcfd..f6ecaba834 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -34,6 +34,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import optimizer as optimizer_v1
from tensorflow.python.training import slot_creator
from tensorflow.python.training.checkpointable import base as checkpointable
@@ -620,7 +621,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
# Map from graph_key to state for that graph. We use the graph_key
# since it works in both eager and graph mode, and gives the outer
# graph inside functions.
- tower_context = distribute_lib.get_tower_context()
+ tower_context = distribution_strategy_context.get_tower_context()
if tower_context is None:
# In a cross-tower context for a DistributionStrategy, which means
# only one Optimizer will be created, not one per tower.
@@ -769,7 +770,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
distribute_lib.get_loss_reduction() ==
variable_scope.VariableAggregation.MEAN)
if scale_loss_by_num_towers:
- num_towers = distribute_lib.get_distribution_strategy().num_towers
+ num_towers = distribution_strategy_context.get_distribution_strategy(
+ ).num_towers
if num_towers > 1:
loss_value *= 1. / num_towers
@@ -788,7 +790,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
distribute_lib.get_loss_reduction() ==
variable_scope.VariableAggregation.MEAN)
if scale_loss_by_num_towers:
- num_towers = distribute_lib.get_distribution_strategy().num_towers
+ num_towers = distribution_strategy_context.get_distribution_strategy(
+ ).num_towers
if num_towers > 1:
loss *= 1. / num_towers
@@ -862,7 +865,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
if not filtered:
raise ValueError("No gradients provided for any variable: %s." %
([str(v) for _, v in grads_and_vars],))
- return distribute_lib.get_tower_context().merge_call(
+ return distribution_strategy_context.get_tower_context().merge_call(
self._distributed_apply, filtered, global_step=global_step, name=name)

def _get_or_create_state(self, var_list=None):
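
Migration pattern, as a minimal sketch: this diff moves three call sites from distribute_lib to the newly imported distribution_strategy_context module, which (per the hunks above) provides the same get_tower_context() and get_distribution_strategy() helpers; distribute_lib remains imported, since get_loss_reduction() is still resolved through it.

from tensorflow.python.training import distribution_strategy_context

# Before this change, both helpers were reached through distribute_lib:
#   tower_context = distribute_lib.get_tower_context()
#   num_towers = distribute_lib.get_distribution_strategy().num_towers

# After this change, the same lookups go through the context module:
tower_context = distribution_strategy_context.get_tower_context()
num_towers = distribution_strategy_context.get_distribution_strategy().num_towers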