Diffstat (limited to 'tensorflow/python/training/optimizer.py')
-rw-r--r-- | tensorflow/python/training/optimizer.py | 16 | ++++++++++------
1 file changed, 10 insertions(+), 6 deletions(-)
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index fe9ffde11c..f75db08059 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -77,9 +77,10 @@ def _deduplicate_indexed_slices(values, indices):
 
 
 def _var_key(var):
-  if context.executing_eagerly():
-    return var._unique_id  # pylint: disable=protected-access
-  return (var.op.graph, var.op.name)
+  # TODO(ashankar): Consolidate handling for eager and graph
+  if hasattr(var, "op"):
+    return (var.op.graph, var.op.name)
+  return var._unique_id  # pylint: disable=protected-access
 
 
 class _OptimizableVariable(object):
@@ -461,7 +462,8 @@ class Optimizer(
         # Have to be careful to call distribute_lib.get_loss_reduction()
         # *after* loss() is evaluated, so we know what loss reduction it uses.
         # TODO(josh11b): Test that we handle weight decay in a reasonable way.
-        if distribute_lib.get_loss_reduction() == "mean":
+        if (distribute_lib.get_loss_reduction() ==
+            variable_scope.VariableAggregation.MEAN):
           num_towers = distribute_lib.get_distribution_strategy().num_towers
           if num_towers > 1:
             loss_value *= (1. / num_towers)
@@ -478,7 +480,8 @@ class Optimizer(
           "be a function when eager execution is enabled.")
 
     # Scale loss if using a "mean" loss reduction and multiple towers.
-    if distribute_lib.get_loss_reduction() == "mean":
+    if (distribute_lib.get_loss_reduction() ==
+        variable_scope.VariableAggregation.MEAN):
      num_towers = distribute_lib.get_distribution_strategy().num_towers
      if num_towers > 1:
        loss *= (1. / num_towers)
@@ -649,7 +652,8 @@ class Optimizer(
       towers. If `global_step` was not None, that operation also
       increments `global_step`.
     """
-    reduced_grads = distribution.batch_reduce("sum", grads_and_vars)
+    reduced_grads = distribution.batch_reduce(
+        variable_scope.VariableAggregation.SUM, grads_and_vars)
     var_list = [v for _, v in grads_and_vars]
     grads_and_vars = zip(reduced_grads, var_list)
     # Note that this is called in a cross-tower context.
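
For readers skimming the patch, the gist of the `_var_key` change is that it now dispatches on the variable itself rather than on the execution mode: a graph-backed variable exposes an `op` attribute and is keyed by `(graph, name)`, while an eager variable falls through to its unique id. Below is a minimal runnable sketch of that dispatch; `_StubOp` and `_StubVar` are hypothetical stand-ins for illustration, not real TensorFlow classes.

class _StubOp(object):
  # Hypothetical stand-in for a tf.Operation: just carries graph and name.
  def __init__(self, graph, name):
    self.graph = graph
    self.name = name


class _StubVar(object):
  # Hypothetical stand-in for a variable; graph-backed ones get an `op`.
  def __init__(self, unique_id, op=None):
    self._unique_id = unique_id
    if op is not None:
      self.op = op


def _var_key(var):
  # Same dispatch as the patched function: key graph variables by
  # (graph, name); key eager variables by their unique id.
  if hasattr(var, "op"):
    return (var.op.graph, var.op.name)
  return var._unique_id  # pylint: disable=protected-access


graph_var = _StubVar("uid-0", op=_StubOp(graph="g", name="w:0"))
eager_var = _StubVar("uid-1")
assert _var_key(graph_var) == ("g", "w:0")
assert _var_key(eager_var) == "uid-1"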