author | Priya Gupta <priyag@google.com> | 2018-08-14 11:22:19 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-14 11:31:36 -0700 |
commit | 77fabbeabb5b9061d8c606050c1ea79aec990c03 | |
tree | 1495d6acb396eebd40c703b891a4f2e7437a8532 /tensorflow/contrib/optimizer_v2 | |
parent | cea262e16a004d73295259c42f21e2655da3df13 | |
1. Move the distribution strategy context utility methods into a separate file with few dependencies. This lets other modules import them without creating circular dependencies, since the original file imported many things (a minimal sketch of this pattern appears below, after the commit trailer).
2. Move the stack used by the distribution strategy context onto the graph. This allows different strategies to be used in different graphs, e.g. one for train and one for eval (see the second sketch below).
This fixes #21412 and #21180.
PiperOrigin-RevId: 208680454
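
To make point 1 concrete, here is a minimal sketch of the dependency-breaking pattern (module and function names here are illustrative, not TensorFlow's actual layout): the stack accessors live in a tiny module that imports nothing heavyweight, so both the strategy definitions and their many consumers can import it without forming a cycle.

```python
# strategy_context.py -- hypothetical module with few dependencies.
# Because it imports nothing heavyweight, any module may import it freely.
_strategy_stack = []

def get_distribution_strategy():
  """Returns the innermost active strategy, or None if none is active."""
  return _strategy_stack[-1] if _strategy_stack else None

def push_strategy(strategy):
  _strategy_stack.append(strategy)

def pop_strategy():
  return _strategy_stack.pop()

# distribute.py (heavyweight) would import strategy_context to push/pop
# strategies, and optimizer.py (also heavyweight) would import it to read
# the current strategy -- neither needs to import the other anymore.
```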
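Point 2 can be pictured with a second sketch (a hypothetical helper, assuming TF 1.x graph-mode APIs such as `tf.get_default_graph()`): keying the stack by the current graph gives each graph, such as a train graph and an eval graph, its own independent strategy stack.

```python
import tensorflow as tf  # assumes TensorFlow 1.x graph mode

def _get_strategy_stack():
  """Returns the strategy stack owned by the current default graph."""
  graph = tf.get_default_graph()
  if not hasattr(graph, "_strategy_stack"):  # lazily attach one per graph
    graph._strategy_stack = []
  return graph._strategy_stack

train_graph, eval_graph = tf.Graph(), tf.Graph()
with train_graph.as_default():
  _get_strategy_stack().append("mirrored_strategy")    # stand-in object
with eval_graph.as_default():
  _get_strategy_stack().append("one_device_strategy")  # a different strategy
with train_graph.as_default():
  # Each graph sees only its own stack, so train and eval no longer collide.
  assert _get_strategy_stack() == ["mirrored_strategy"]
```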
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r-- | tensorflow/contrib/optimizer_v2/optimizer_v2.py | 11 |
1 file changed, 7 insertions(+), 4 deletions(-)
```diff
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index 8c11d8bcfd..f6ecaba834 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -34,6 +34,7 @@ from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
 from tensorflow.python.training import optimizer as optimizer_v1
 from tensorflow.python.training import slot_creator
 from tensorflow.python.training.checkpointable import base as checkpointable
@@ -620,7 +621,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
     # Map from graph_key to state for that graph. We use the graph_key
     # since it works in both eager and graph mode, and gives the outer
     # graph inside functions.
-    tower_context = distribute_lib.get_tower_context()
+    tower_context = distribution_strategy_context.get_tower_context()
     if tower_context is None:
       # In a cross-tower context for a DistributionStrategy, which means
       # only one Optimizer will be created, not one per tower.
@@ -769,7 +770,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
               distribute_lib.get_loss_reduction() ==
               variable_scope.VariableAggregation.MEAN)
         if scale_loss_by_num_towers:
-          num_towers = distribute_lib.get_distribution_strategy().num_towers
+          num_towers = distribution_strategy_context.get_distribution_strategy(
+          ).num_towers
           if num_towers > 1:
             loss_value *= 1. / num_towers
@@ -788,7 +790,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
             distribute_lib.get_loss_reduction() ==
             variable_scope.VariableAggregation.MEAN)
       if scale_loss_by_num_towers:
-        num_towers = distribute_lib.get_distribution_strategy().num_towers
+        num_towers = distribution_strategy_context.get_distribution_strategy(
+        ).num_towers
         if num_towers > 1:
           loss *= 1. / num_towers
@@ -862,7 +865,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
     if not filtered:
       raise ValueError("No gradients provided for any variable: %s." %
                        ([str(v) for _, v in grads_and_vars],))
-    return distribute_lib.get_tower_context().merge_call(
+    return distribution_strategy_context.get_tower_context().merge_call(
         self._distributed_apply, filtered, global_step=global_step, name=name)

   def _get_or_create_state(self, var_list=None):
```
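
For reference, a call site after this change reads roughly as follows (a sketch assuming the contrib-era TF 1.x API shown in the diff; the `strategy` variable name is illustrative):

```python
from tensorflow.python.training import distribution_strategy_context

tower_context = distribution_strategy_context.get_tower_context()
if tower_context is None:
  # Cross-tower context: a single optimizer is created for all towers.
  strategy = distribution_strategy_context.get_distribution_strategy()
  num_towers = strategy.num_towers  # used to rescale a "mean"-reduced loss
```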