aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/optimizer_v2
diff options
context:
space:
mode:
authorGravatar Priya Gupta <priyag@google.com>2018-08-14 11:22:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-14 11:31:36 -0700
commit77fabbeabb5b9061d8c606050c1ea79aec990c03 (patch)
tree1495d6acb396eebd40c703b891a4f2e7437a8532 /tensorflow/contrib/optimizer_v2
parentcea262e16a004d73295259c42f21e2655da3df13 (diff)
1. Move the distribution strategy context utility methods to a separate file with few dependencies. This lets us import the new file in places where importing the original file, which itself imported many things, would create circular dependencies.
2. Move the stack used in distribution strategy context to the graph. This allows us to use different strategies in different graphs (e.g., in train and eval). This fixes #21412 and #21180. PiperOrigin-RevId: 208680454
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r--tensorflow/contrib/optimizer_v2/optimizer_v2.py11
1 files changed, 7 insertions, 4 deletions
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index 8c11d8bcfd..f6ecaba834 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -34,6 +34,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import optimizer as optimizer_v1
from tensorflow.python.training import slot_creator
from tensorflow.python.training.checkpointable import base as checkpointable
@@ -620,7 +621,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
# Map from graph_key to state for that graph. We use the graph_key
# since it works in both eager and graph mode, and gives the outer
# graph inside functions.
- tower_context = distribute_lib.get_tower_context()
+ tower_context = distribution_strategy_context.get_tower_context()
if tower_context is None:
# In a cross-tower context for a DistributionStrategy, which means
# only one Optimizer will be created, not one per tower.
@@ -769,7 +770,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
distribute_lib.get_loss_reduction() ==
variable_scope.VariableAggregation.MEAN)
if scale_loss_by_num_towers:
- num_towers = distribute_lib.get_distribution_strategy().num_towers
+ num_towers = distribution_strategy_context.get_distribution_strategy(
+ ).num_towers
if num_towers > 1:
loss_value *= 1. / num_towers
@@ -788,7 +790,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
distribute_lib.get_loss_reduction() ==
variable_scope.VariableAggregation.MEAN)
if scale_loss_by_num_towers:
- num_towers = distribute_lib.get_distribution_strategy().num_towers
+ num_towers = distribution_strategy_context.get_distribution_strategy(
+ ).num_towers
if num_towers > 1:
loss *= 1. / num_towers
@@ -862,7 +865,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
if not filtered:
raise ValueError("No gradients provided for any variable: %s." %
([str(v) for _, v in grads_and_vars],))
- return distribute_lib.get_tower_context().merge_call(
+ return distribution_strategy_context.get_tower_context().merge_call(
self._distributed_apply, filtered, global_step=global_step, name=name)
def _get_or_create_state(self, var_list=None):