author		2018-03-19 13:38:23 -0700
committer	2018-03-19 13:46:30 -0700
commit		a78c5033e005f76b83df4fd97d0074fcc990f603 (patch)
tree		06a64f64e135d32355a1c47694c20cfbff06c359
parent		eb03b44049328404eb5578efda0729ca1a4f0a11 (diff)
TFE: Fix bug encountered when using `optimizer.apply_gradients` in a defun.
Prior to this change, `Optimizer` assumed that `not
context.executing_eagerly()` implied that every variable it was to update
was constructed in a graph. That assumption is incorrect: TensorFlow
functions can mutate variables captured from, or lifted into, the eager
context. This change removes that assumption.
Fixes #17792
PiperOrigin-RevId: 189633630
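
[Editor's note, not part of the commit: a minimal sketch of the usage this fix enables, assuming the TF 1.7-era contrib.eager API (`tfe.defun`, `tfe.Variable`, `tfe.implicit_grad`) and `tf.enable_eager_execution`. Before this change, `apply_gradients` inside the defun could fail because the captured eager variable has no graph `op`.]

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

v = tfe.Variable(1.0)  # created eagerly, then captured by the defun below
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

def loss():
  return v ** 2

@tfe.defun
def train_step():
  # Tracing builds a graph, but `v` still lives in the eager context;
  # the old code assumed `v.op.name` existed here.
  optimizer.apply_gradients(tfe.implicit_grad(loss)())

train_step()
print(v.numpy())  # 1.0 - 1.0 * 2.0 = -1.0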
-rw-r--r--	tensorflow/python/eager/function_test.py	32
-rw-r--r--	tensorflow/python/ops/variables.py	6
-rw-r--r--	tensorflow/python/training/optimizer.py	11
3 files changed, 48 insertions(+), 1 deletion(-)
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index b9cde16867..fd1d2c25ff 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -37,6 +37,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
+from tensorflow.python.training import gradient_descent
 
 
 class FunctionTest(test.TestCase):
@@ -762,6 +763,37 @@ class AutomaticControlDependenciesTest(test.TestCase):
 
       self.assertAllEqual(f().eval(), 4.0)
 
+  def testOptimizerInDefun(self):
+    def loss(v):
+      return v**2
+
+    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
+
+    @function.defun
+    def train():
+      v = resource_variable_ops.ResourceVariable(1.0)
+      grad = backprop.implicit_grad(loss)(v)
+      optimizer.apply_gradients(grad)
+      return v.read_value()
+
+    value = train()
+    self.assertEqual(value.numpy(), -1.0)
+
+  def testOptimizerInDefunWithCapturedVariable(self):
+    v = resource_variable_ops.ResourceVariable(1.0)
+    def loss():
+      return v**2
+
+    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
+
+    @function.defun
+    def train():
+      grad = backprop.implicit_grad(loss)()
+      optimizer.apply_gradients(grad)
+
+    train()
+    self.assertEqual(v.numpy(), -1.0)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index c37cdd9e27..c646f79589 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -293,6 +293,7 @@ class Variable(checkpointable.CheckpointableBase):
     Raises:
       ValueError: If the initial value is not specified, or does not have a
         shape and `validate_shape` is `True`.
+      RuntimeError: If lifted into the eager context.
     """
     _ = expected_shape
     if initial_value is None:
@@ -319,6 +320,11 @@ class Variable(checkpointable.CheckpointableBase):
 
     if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
       collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
     with ops.init_scope():
+      # Ensure that we weren't lifted into the eager context.
+      if context.executing_eagerly():
+        raise RuntimeError(
+            "tf.Variable not supported when eager execution is enabled. "
+            "Please use tf.contrib.eager.Variable instead")
       with ops.name_scope(name, "Variable", [] if init_from_fn else
                           [initial_value]) as name:
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index af9cc3491c..bf79714f96 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -191,6 +191,10 @@ def _get_processor(v):
       return _TensorProcessor(v)
     else:
       return _DenseResourceVariableProcessor(v)
+  if isinstance(
+      v, resource_variable_ops.ResourceVariable) and not v._in_graph_mode:  # pylint: disable=protected-access
+    # True if and only if `v` was initialized eagerly.
+    return _DenseResourceVariableProcessor(v)
   if v.op.type == "VarHandleOp":
     return _DenseResourceVariableProcessor(v)
   if isinstance(v, variables.Variable):
@@ -546,7 +550,12 @@ class Optimizer(
         # We colocate all ops created in _apply_dense or _apply_sparse
         # on the same device as the variable.
         # TODO(apassos): figure out how to get the variable name here.
-        scope_name = "" if context.executing_eagerly() else var.op.name
+        if context.executing_eagerly() or isinstance(
+            var,
+            resource_variable_ops.ResourceVariable) and not var._in_graph_mode:  # pylint: disable=protected-access
+          scope_name = ""
+        else:
+          scope_name = var.op.name
         with ops.name_scope("update_" + scope_name), ops.colocate_with(var):
           update_ops.append(processor.update_op(self, grad))
       if global_step is None:
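
[Editor's note: in the new `scope_name` condition, Python's `and` binds tighter than `or`, so it reads as "executing eagerly, OR an eagerly-created `ResourceVariable`". A standalone sketch of that guard; the class and helper names here are hypothetical, for illustration only.]

class FakeResourceVariable(object):
  """Stand-in for resource_variable_ops.ResourceVariable (illustration only)."""

  def __init__(self, in_graph_mode):
    self._in_graph_mode = in_graph_mode

def scope_name_for(var, executing_eagerly):
  # Mirrors the committed guard: an eagerly-created variable has no graph
  # op, hence no `op.name` to build the "update_" name scope from.
  if executing_eagerly or (
      isinstance(var, FakeResourceVariable) and not var._in_graph_mode):
    return ""
  return "<var.op.name>"

assert scope_name_for(FakeResourceVariable(in_graph_mode=False), False) == ""
assert scope_name_for(FakeResourceVariable(in_graph_mode=True), False) == "<var.op.name>"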