author    Akshay Agrawal <akshayka@google.com>  2018-03-19 13:38:23 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-03-19 13:46:30 -0700
commit    a78c5033e005f76b83df4fd97d0074fcc990f603 (patch)
tree      06a64f64e135d32355a1c47694c20cfbff06c359
parent    eb03b44049328404eb5578efda0729ca1a4f0a11 (diff)
TFE: Fix bug encountered when using `optimizer.apply_gradients` in a defun.
Prior to this change, `Optimizer` assumed that `not context.executing_eagerly()` implied that every variable it was asked to update had been constructed in a graph. That assumption is incorrect: TensorFlow functions can mutate variables captured from, or lifted into, the eager context. As such, this change removes that assumption.

Fixes #17792

PiperOrigin-RevId: 189633630
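A minimal repro of the scenario this addresses, mirroring the new `testOptimizerInDefunWithCapturedVariable` test below. The `tfe` alias and the public `tf.contrib.eager` wrappers are assumptions about the TF 1.x API of this era, not part of the patch:

    import tensorflow as tf

    tfe = tf.contrib.eager  # TF 1.x contrib API; exact entry points may vary
    tfe.enable_eager_execution()

    v = tfe.Variable(1.0)  # a ResourceVariable created in the eager context
    opt = tf.train.GradientDescentOptimizer(learning_rate=1.0)

    @tfe.defun
    def train():
      # While the defun traces, context.executing_eagerly() is False, yet `v`
      # was built eagerly and has no graph op; before this change,
      # Optimizer.apply_gradients assumed it did and failed here.
      grads = tfe.implicit_gradients(lambda: v ** 2)()
      opt.apply_gradients(grads)

    train()
    print(v.numpy())  # -1.0 once the fix is in: 1.0 - 1.0 * 2.0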
-rw-r--r--  tensorflow/python/eager/function_test.py   32
-rw-r--r--  tensorflow/python/ops/variables.py           6
-rw-r--r--  tensorflow/python/training/optimizer.py     11
3 files changed, 48 insertions(+), 1 deletion(-)
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index b9cde16867..fd1d2c25ff 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -37,6 +37,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+from tensorflow.python.training import gradient_descent
class FunctionTest(test.TestCase):
@@ -762,6 +763,37 @@ class AutomaticControlDependenciesTest(test.TestCase):
      self.assertAllEqual(f().eval(), 4.0)

+  def testOptimizerInDefun(self):
+    def loss(v):
+      return v**2
+
+    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
+
+    @function.defun
+    def train():
+      v = resource_variable_ops.ResourceVariable(1.0)
+      grad = backprop.implicit_grad(loss)(v)
+      optimizer.apply_gradients(grad)
+      return v.read_value()
+
+    value = train()
+    self.assertEqual(value.numpy(), -1.0)
+
+  def testOptimizerInDefunWithCapturedVariable(self):
+    v = resource_variable_ops.ResourceVariable(1.0)
+    def loss():
+      return v**2
+
+    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
+
+    @function.defun
+    def train():
+      grad = backprop.implicit_grad(loss)()
+      optimizer.apply_gradients(grad)
+
+    train()
+    self.assertEqual(v.numpy(), -1.0)
+

if __name__ == '__main__':
  test.main()
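For the expected values in both new tests: the loss is v**2, so the gradient at v = 1.0 is 2.0, and one gradient-descent step with learning_rate=1.0 moves the variable to 1.0 - 1.0 * 2.0 = -1.0, which is what the assertions check.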
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index c37cdd9e27..c646f79589 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -293,6 +293,7 @@ class Variable(checkpointable.CheckpointableBase):
    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
+      RuntimeError: If lifted into the eager context.
    """
    _ = expected_shape
    if initial_value is None:
@@ -319,6 +320,11 @@ class Variable(checkpointable.CheckpointableBase):
    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    with ops.init_scope():
+      # Ensure that we weren't lifted into the eager context.
+      if context.executing_eagerly():
+        raise RuntimeError(
+            "tf.Variable not supported when eager execution is enabled. "
+            "Please use tf.contrib.eager.Variable instead")
      with ops.name_scope(name, "Variable", [] if init_from_fn else
                          [initial_value]) as name:
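The `RuntimeError` documented and raised above is what the commit message calls the "lifted" case: `ops.init_scope()` jumps to the outermost context, which is the eager context when eager execution is enabled, so a graph-style `tf.Variable` constructor reached from a defun lands there and is rejected. A hedged illustration of the behavior this check produces (inferred from the hunk above, not copied from the patch):

    import tensorflow as tf

    tf.contrib.eager.enable_eager_execution()  # enabling call varies by TF 1.x release

    try:
      w = tf.Variable(1.0)  # graph-style Variable under eager execution
    except RuntimeError as e:
      print(e)  # "tf.Variable not supported when eager execution is enabled..."

    w = tf.contrib.eager.Variable(1.0)  # ResourceVariable: works eagerly and
                                        # when captured or created in a defun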
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index af9cc3491c..bf79714f96 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -191,6 +191,10 @@ def _get_processor(v):
      return _TensorProcessor(v)
    else:
      return _DenseResourceVariableProcessor(v)
+  if isinstance(
+      v, resource_variable_ops.ResourceVariable) and not v._in_graph_mode:  # pylint: disable=protected-access
+    # True if and only if `v` was initialized eagerly.
+    return _DenseResourceVariableProcessor(v)
  if v.op.type == "VarHandleOp":
    return _DenseResourceVariableProcessor(v)
  if isinstance(v, variables.Variable):
@@ -546,7 +550,12 @@ class Optimizer(
        # We colocate all ops created in _apply_dense or _apply_sparse
        # on the same device as the variable.
        # TODO(apassos): figure out how to get the variable name here.
-        scope_name = "" if context.executing_eagerly() else var.op.name
+        if context.executing_eagerly() or isinstance(
+            var,
+            resource_variable_ops.ResourceVariable) and not var._in_graph_mode:  # pylint: disable=protected-access
+          scope_name = ""
+        else:
+          scope_name = var.op.name
        with ops.name_scope("update_" + scope_name), ops.colocate_with(var):
          update_ops.append(processor.update_op(self, grad))
      if global_step is None:
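A note on the condition used in both optimizer.py hunks, as I read it rather than as stated in the patch: `_in_graph_mode` is a private `ResourceVariable` attribute that is False for eagerly initialized variables, so the check works even while a defun is being traced (where `context.executing_eagerly()` is False), and `scope_name` falls back to the empty string for such variables because `var.op.name` is only meaningful for graph-constructed ones. A sketch of the predicate, private attribute and all:

    # Sketch only; `_in_graph_mode` is private API, shown for illustration.
    def _initialized_eagerly(v):
      return (isinstance(v, resource_variable_ops.ResourceVariable)
              and not v._in_graph_mode)  # pylint: disable=protected-access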