Diffstat (limited to 'tensorflow/python/training/optimizer.py')
-rw-r--r--  tensorflow/python/training/optimizer.py  16
1 file changed, 10 insertions(+), 6 deletions(-)
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index fe9ffde11c..f75db08059 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -77,9 +77,10 @@ def _deduplicate_indexed_slices(values, indices):
def _var_key(var):
- if context.executing_eagerly():
- return var._unique_id # pylint: disable=protected-access
- return (var.op.graph, var.op.name)
+ # TODO(ashankar): Consolidate handling for eager and graph
+ if hasattr(var, "op"):
+ return (var.op.graph, var.op.name)
+ return var._unique_id # pylint: disable=protected-access
class _OptimizableVariable(object):
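The new _var_key dispatches on whether the variable exposes an .op attribute (graph variables do, eager variables do not) instead of checking the global eager-execution flag. A minimal sketch of that dispatch, using hypothetical stand-in classes rather than real TensorFlow variables:

# Sketch only: GraphVar/EagerVar are made-up stand-ins, not TF classes.
class GraphVar(object):
    class _Op(object):
        graph, name = "graph_1", "dense/kernel"
    op = _Op()

class EagerVar(object):
    _unique_id = "dense/kernel_1"

def _var_key(var):
    # Graph-mode variables carry an `op`; key on (graph, name).
    if hasattr(var, "op"):
        return (var.op.graph, var.op.name)
    # Eager variables have no op; fall back to their unique id.
    return var._unique_id

print(_var_key(GraphVar()))  # ('graph_1', 'dense/kernel')
print(_var_key(EagerVar()))  # 'dense/kernel_1'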
@@ -461,7 +462,8 @@ class Optimizer(
# Have to be careful to call distribute_lib.get_loss_reduction()
# *after* loss() is evaluated, so we know what loss reduction it uses.
# TODO(josh11b): Test that we handle weight decay in a reasonable way.
- if distribute_lib.get_loss_reduction() == "mean":
+ if (distribute_lib.get_loss_reduction() ==
+ variable_scope.VariableAggregation.MEAN):
num_towers = distribute_lib.get_distribution_strategy().num_towers
if num_towers > 1:
loss_value *= (1. / num_towers)
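With a MEAN loss reduction, each tower's loss is pre-scaled by 1/num_towers so that the later sum over towers yields the average. A small numeric sketch of that reasoning (the per-tower losses are made up):

# Hypothetical per-tower losses; num_towers > 1 triggers the scaling.
tower_losses = [4.0, 2.0, 6.0]
num_towers = len(tower_losses)

# Scale each tower's loss, as in the branch above.
scaled = [l * (1.0 / num_towers) for l in tower_losses]

# Summing the scaled losses across towers equals the mean of the originals.
assert abs(sum(scaled) - sum(tower_losses) / num_towers) < 1e-12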
@@ -478,7 +480,8 @@ class Optimizer(
"be a function when eager execution is enabled.")
# Scale loss if using a "mean" loss reduction and multiple towers.
- if distribute_lib.get_loss_reduction() == "mean":
+ if (distribute_lib.get_loss_reduction() ==
+ variable_scope.VariableAggregation.MEAN):
num_towers = distribute_lib.get_distribution_strategy().num_towers
if num_towers > 1:
loss *= (1. / num_towers)
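Both hunks change the comparison target from the raw string "mean" to the VariableAggregation.MEAN enum member, so the check no longer compares a string against an enum and silently fails. A standalone sketch with a stand-in enum (the real one is variable_scope.VariableAggregation / tf.VariableAggregation):

import enum

# Stand-in for variable_scope.VariableAggregation; member names match.
class VariableAggregation(enum.Enum):
    NONE = 0
    SUM = 1
    MEAN = 2

loss_reduction = VariableAggregation.MEAN  # e.g. what get_loss_reduction() reports

# The old string comparison never matches an enum value...
print(loss_reduction == "mean")                    # False
# ...while comparing against the enum member does.
print(loss_reduction == VariableAggregation.MEAN)  # True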
@@ -649,7 +652,8 @@ class Optimizer(
towers. If `global_step` was not None, that operation also
increments `global_step`.
"""
- reduced_grads = distribution.batch_reduce("sum", grads_and_vars)
+ reduced_grads = distribution.batch_reduce(
+ variable_scope.VariableAggregation.SUM, grads_and_vars)
var_list = [v for _, v in grads_and_vars]
grads_and_vars = zip(reduced_grads, var_list)
# Note that this is called in a cross-tower context.
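Here the distribution strategy is asked to batch-reduce the gradients with the SUM aggregation enum instead of the string "sum", and the reduced gradients are then re-paired with their variables. A plain-Python sketch of that reduce-and-rezip step, with made-up per-tower (gradient, variable) pairs standing in for distribution.batch_reduce():

# Hypothetical per-tower (gradient, variable) pairs; variables named by string.
tower_grads = [
    [(1.0, "w"), (0.5, "b")],   # tower 0
    [(3.0, "w"), (1.5, "b")],   # tower 1
]

# SUM-reduce each variable's gradient across towers, keeping variable order.
var_list = [v for _, v in tower_grads[0]]
reduced_grads = [sum(g for g, _ in pairs) for pairs in zip(*tower_grads)]

# Re-pair reduced gradients with their variables, as in the diff.
grads_and_vars = list(zip(reduced_grads, var_list))
print(grads_and_vars)  # [(4.0, 'w'), (2.0, 'b')]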