author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-10-23 14:40:39 -0700
---|---|---
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-23 14:47:23 -0700
commit | 4cd64cac16ccd22a2d956d6957ecfda0ff67ee89 | (patch)
tree | 0529e0f43bf94f381e52306e5dcf13321134b459 | /tensorflow/python/training/optimizer.py
parent | c319703d3669c9eec51f17ffc5e7e586b9608074 | (diff)
Make Optimizer.minimize work when eager execution is enabled.
PiperOrigin-RevId: 173172604
Diffstat (limited to 'tensorflow/python/training/optimizer.py')
-rw-r--r-- | tensorflow/python/training/optimizer.py | 49 |
1 file changed, 44 insertions(+), 5 deletions(-)
```diff
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 86ba8e2c8e..82fc4edbcd 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -22,6 +22,7 @@ from __future__ import print_function
 
 import abc
 
+from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -335,6 +336,16 @@ class Optimizer(object):
 
     Raises:
       ValueError: If some of the variables are not `Variable` objects.
+
+    @compatibility(eager):
+    When eager execution is enabled, `loss` should be a Python function that
+    takes elements of `var_list` as arguments and computes the value to be
+    minimized. If `var_list` is None, `loss` should take no arguments.
+    Minimization (and gradient computation) is done with respect to the
+    elements of `var_list` if not None, else with respect to any trainable
+    variables created during the execution of the `loss` function.
+    `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and
+    `grad_loss` are ignored when eager execution is enabled.
     """
     grads_and_vars = self.compute_gradients(
         loss, var_list=var_list, gate_gradients=gate_gradients,
@@ -385,7 +396,32 @@ class Optimizer(object):
     Raises:
       TypeError: If `var_list` contains anything else than `Variable` objects.
       ValueError: If some arguments are invalid.
+
+    @compatibility(eager):
+    When eager execution is enabled, `loss` should be a Python function that
+    takes elements of `var_list` as arguments and computes the value to be
+    minimized. If `var_list` is None, `loss` should take no arguments.
+    Gradient computation is done with respect to the elements of `var_list` if
+    not None, else with respect to any trainable variables created during the
+    execution of the `loss` function.
+    `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and
+    `grad_loss` are ignored when eager execution is enabled.
     """
+    if context.in_eager_mode():
+      if grad_loss is not None:
+        raise ValueError("`grad_loss` argument to Optimizer.compute_gradients "
+                         "not supported when eager execution is enabled.")
+      if not callable(loss):
+        raise ValueError("`loss` passed to Optimizer.compute_gradients should "
+                         "be a function when eager execution is enabled.")
+      # TODO(agarwal): consider passing parameters to the `loss` function.
+      if var_list is None:
+        return backprop.implicit_grad(loss)()
+      else:
+        var_list = nest.flatten(var_list)
+        grads = backprop.gradients_function(loss)(*var_list)
+        grads_and_vars = list(zip(grads, var_list))
+        return grads_and_vars
     if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
                               Optimizer.GATE_GRAPH]:
       raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
@@ -489,11 +525,14 @@ class Optimizer(object):
     else:
       with ops.control_dependencies([self._finish(update_ops, "update")]):
         with ops.colocate_with(global_step):
-          apply_updates = state_ops.assign_add(global_step, 1, name=name).op
-
-    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
-    if apply_updates not in train_op:
-      train_op.append(apply_updates)
+          apply_updates = state_ops.assign_add(global_step, 1, name=name)
+
+    if context.in_graph_mode():
+      if isinstance(apply_updates, ops.Tensor):
+        apply_updates = apply_updates.op
+      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+      if apply_updates not in train_op:
+        train_op.append(apply_updates)
     return apply_updates
```
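For readers trying out the behavior this commit enables, a minimal usage sketch follows. It is not part of the commit: it assumes a TF 1.x build of roughly this vintage where eager execution is enabled through `tf.contrib.eager` (the call later became `tf.enable_eager_execution()`), and it uses `tfe.Variable` as the eager-friendly variable class; treat both names as assumptions about the surrounding release rather than APIs fixed by this change.

```python
# Hypothetical usage sketch (not from the commit). Assumes a TF 1.x build
# where eager execution and an eager-compatible Variable class are exposed
# under tf.contrib.eager.
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()

# Toy linear-regression data.
x = tf.constant([[1.0], [2.0], [3.0]])
y = tf.constant([[2.0], [4.0], [6.0]])

w = tfe.Variable(0.0, name="w")
b = tfe.Variable(0.0, name="b")

def loss_fn():
  # With var_list=None, `loss` takes no arguments; gradients are computed
  # with respect to the trainable variables it uses (the
  # backprop.implicit_grad path in the diff above).
  return tf.reduce_mean(tf.square(w * x + b - y))

opt = tf.train.GradientDescentOptimizer(learning_rate=0.05)
for _ in range(200):
  opt.minimize(loss_fn)

# With an explicit var_list, `loss` takes the listed variables as
# positional arguments (the backprop.gradients_function path above).
def loss_with_args(w_arg, b_arg):
  return tf.reduce_mean(tf.square(w_arg * x + b_arg - y))

grads_and_vars = opt.compute_gradients(loss_with_args, var_list=[w, b])
opt.apply_gradients(grads_and_vars)
```

Graph-mode callers are unaffected: there `loss` remains a `Tensor` and `minimize` still returns a training op, which is what the `context.in_graph_mode()` branch in `apply_gradients` preserves.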