path: root/tensorflow/python/training/optimizer.py
author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-10-23 14:40:39 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-10-23 14:47:23 -0700
commit     4cd64cac16ccd22a2d956d6957ecfda0ff67ee89 (patch)
tree       0529e0f43bf94f381e52306e5dcf13321134b459 /tensorflow/python/training/optimizer.py
parent     c319703d3669c9eec51f17ffc5e7e586b9608074 (diff)
Make Optimizer.minimize work when eager execution is enabled.
PiperOrigin-RevId: 173172604
Diffstat (limited to 'tensorflow/python/training/optimizer.py')
-rw-r--r--  tensorflow/python/training/optimizer.py  49
1 file changed, 44 insertions, 5 deletions
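
For context, a minimal usage sketch (not part of the commit) of what this change enables: with eager execution on, `loss` is passed to `minimize` as a Python callable rather than a Tensor. The `tf.contrib.eager` alias and its `enable_eager_execution`/`Variable` entry points are assumptions about the TF 1.x contrib API of this era, not something taken from this diff.

    import tensorflow as tf
    import tensorflow.contrib.eager as tfe  # assumed contrib alias

    tfe.enable_eager_execution()

    w = tfe.Variable(5.0)
    b = tfe.Variable(3.0)

    def loss_fn():
      # With var_list=None, the callable takes no arguments; gradients are
      # computed with respect to the trainable variables it uses.
      return (2.0 * w + b - 1.0) ** 2

    opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
    for _ in range(100):
      opt.minimize(loss_fn)  # routed through backprop.implicit_grad internally
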
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 86ba8e2c8e..82fc4edbcd 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -22,6 +22,7 @@ from __future__ import print_function
 import abc
+from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -335,6 +336,16 @@ class Optimizer(object):
     Raises:
       ValueError: If some of the variables are not `Variable` objects.
+
+    @compatibility(eager):
+    When eager execution is enabled, `loss` should be a Python function that
+    takes elements of `var_list` as arguments and computes the value to be
+    minimized. If `var_list` is None, `loss` should take no arguments.
+    Minimization (and gradient computation) is done with respect to the
+    elements of `var_list` if not None, else with respect to any trainable
+    variables created during the execution of the `loss` function.
+    `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and
+    `grad_loss` are ignored when eager execution is enabled.
     """
     grads_and_vars = self.compute_gradients(
         loss, var_list=var_list, gate_gradients=gate_gradients,
@@ -385,7 +396,32 @@ class Optimizer(object):
     Raises:
       TypeError: If `var_list` contains anything else than `Variable` objects.
       ValueError: If some arguments are invalid.
+
+    @compatibility(eager):
+    When eager execution is enabled, `loss` should be a Python function that
+    takes elements of `var_list` as arguments and computes the value to be
+    minimized. If `var_list` is None, `loss` should take no arguments.
+    Gradient computation is done with respect to the elements of `var_list` if
+    not None, else with respect to any trainable variables created during the
+    execution of the `loss` function.
+    `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and
+    `grad_loss` are ignored when eager execution is enabled.
     """
+    if context.in_eager_mode():
+      if grad_loss is not None:
+        raise ValueError("`grad_loss` argument to Optimizer.compute_gradients "
+                         "not supported when eager execution is enabled.")
+      if not callable(loss):
+        raise ValueError("`loss` passed to Optimizer.compute_gradients should "
+                         "be a function when eager execution is enabled.")
+      # TODO(agarwal): consider passing parameters to the `loss` function.
+      if var_list is None:
+        return backprop.implicit_grad(loss)()
+      else:
+        var_list = nest.flatten(var_list)
+        grads = backprop.gradients_function(loss)(*var_list)
+        grads_and_vars = list(zip(grads, var_list))
+        return grads_and_vars
     if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
                               Optimizer.GATE_GRAPH]:
       raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
@@ -489,11 +525,14 @@ class Optimizer(object):
       else:
         with ops.control_dependencies([self._finish(update_ops, "update")]):
           with ops.colocate_with(global_step):
-            apply_updates = state_ops.assign_add(global_step, 1, name=name).op
-
-      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
-      if apply_updates not in train_op:
-        train_op.append(apply_updates)
+            apply_updates = state_ops.assign_add(global_step, 1, name=name)
+
+      if context.in_graph_mode():
+        if isinstance(apply_updates, ops.Tensor):
+          apply_updates = apply_updates.op
+        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+        if apply_updates not in train_op:
+          train_op.append(apply_updates)
       return apply_updates
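
As a companion sketch (again not part of the commit, and assuming the same TF 1.x contrib eager entry points as above), the explicit `var_list` branch added to `compute_gradients` calls the loss with the elements of `var_list` as positional arguments and pairs the resulting gradients with those variables:

    import tensorflow as tf
    import tensorflow.contrib.eager as tfe  # assumed contrib alias

    tfe.enable_eager_execution()

    w = tfe.Variable(2.0)
    b = tfe.Variable(0.5)

    def loss(w_val, b_val):
      # Called as loss(*var_list); differentiation is with respect to these
      # positional arguments via backprop.gradients_function.
      return (3.0 * w_val + b_val - 4.0) ** 2

    opt = tf.train.GradientDescentOptimizer(learning_rate=0.05)
    grads_and_vars = opt.compute_gradients(loss, var_list=[w, b])  # [(dw, w), (db, b)]
    opt.apply_gradients(grads_and_vars)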