author      Alexandre Passos <apassos@google.com>            2017-10-23 14:26:02 -0700
committer   TensorFlower Gardener <gardener@tensorflow.org>  2017-10-23 14:29:33 -0700
commit      3f30e6424fa3b8e890f9360d1661e61c2d1625a5 (patch)
tree        1f93f484b27f1038bf4c697456f2933681bc419d
parent      a0ee701f73cc80a56b41c2452006e166e0b835e6 (diff)
Makes gradients_function exception-safe.
PiperOrigin-RevId: 173170394
-rw-r--r--  tensorflow/python/eager/backprop.py       | 57
-rw-r--r--  tensorflow/python/eager/backprop_test.py  | 15
2 files changed, 40 insertions(+), 32 deletions(-)
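
In caller terms, the guarantee this change adds is: if the function handed to gradients_function raises, the tape pushed on its behalf is still popped, so the next gradient computation starts from a clean tape stack. A minimal sketch of that behavior, mirroring the new test below; it assumes the eager preview API of this era (tf.contrib.eager) for turning on eager execution, and the function names failing and square are illustrative only:

    import tensorflow.contrib.eager as tfe
    from tensorflow.python.eager import backprop

    tfe.enable_eager_execution()  # eager mode is assumed throughout

    def failing(unused_x):
      raise ValueError("simulated bug in user code")

    def square(x):
      return x * x

    try:
      backprop.gradients_function(failing)(1.0)  # pushes a tape, then raises
    except ValueError:
      pass

    # Before this change the aborted call could leave a stale tape behind;
    # with the try/finally added below, the gradient of square is still 2.0.
    print(backprop.gradients_function(square)(1.0)[0])
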
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 9580e84847..9d86ac77f8 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -335,15 +335,18 @@ def implicit_val_and_grad(f):
   def grad_fn(*args):
     """Computes the gradient of the wrapped function."""
     tape.push_new_tape()
-    end_node = f(*args)
-    variables = tape.top_tape_watched_variables()
+    try:
+      end_node = f(*args)
+      variables = tape.top_tape_watched_variables()
+    finally:
+      popped_tape = tape.pop_tape()
     sources = [x.handle for x in variables]
     if not sources:
       raise ValueError("no trainable variables were accessed while the "
                        "function was being computed.")
     grad = imperative_grad.imperative_grad(_default_vspace,
-                                           tape.pop_tape(),
+                                           popped_tape,
                                            nest.flatten(end_node),
                                            sources)
     return end_node, list(zip(grad, variables))
@@ -561,25 +564,12 @@ def val_and_grad_function(f, params=None):
   def decorated(*args, **kwds):
     """Computes the value and gradient of the decorated function."""
-    parameter_positions = _get_arg_spec(f, params, args)
     dy = kwds.pop("dy", None)
-    if dy is not None:
-      dy = ops.convert_to_tensor(dy)
-    assert not kwds, "The gradient function can't take keyword arguments."
-    tape.push_new_tape()
-    sources = []
-    args = [
-        ops.convert_to_tensor(args[i]) if i in parameter_positions else args[i]
-        for i in range(len(args))
-    ]
-    args = _ensure_unique_tensor_objects(parameter_positions, args)
-    for i in parameter_positions:
-      sources.append(args[i])
-      tape.watch(args[i])
-    result = f(*args)
-    return result, imperative_grad.imperative_grad(
-        _default_vspace, tape.pop_tape(), nest.flatten(result), sources,
-        output_gradients=nest.flatten(dy) if dy is not None else None)
+    if kwds:
+      raise ValueError("Functions to be differentiated cannot "
+                       "receive keyword arguments.")
+    val, vjp = make_vjp(f, params)(*args, **kwds)
+    return val, vjp(dy=dy)
 
   return decorated
@@ -619,17 +609,20 @@ def make_vjp(f, params=None):
     parameter_positions = _get_arg_spec(f, params, args)
     assert not kwds, "The gradient function can't take keyword arguments."
     tape.push_new_tape()
-    sources = []
-    args = [
-        ops.convert_to_tensor(args[i]) if i in parameter_positions else args[i]
-        for i in range(len(args))
-    ]
-    args = _ensure_unique_tensor_objects(parameter_positions, args)
-    for i in parameter_positions:
-      sources.append(args[i])
-      tape.watch(args[i])
-    result = f(*args)
-    t = tape.pop_tape()
+    try:
+      sources = []
+      args = [
+          ops.convert_to_tensor(args[i])
+          if i in parameter_positions else args[i]
+          for i in range(len(args))
+      ]
+      args = _ensure_unique_tensor_objects(parameter_positions, args)
+      for i in parameter_positions:
+        sources.append(args[i])
+        tape.watch(args[i])
+      result = f(*args)
+    finally:
+      t = tape.pop_tape()
     def vjp(dy=None):
       return imperative_grad.imperative_grad(
           _default_vspace, t, nest.flatten(result), sources,
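
All three backprop.py hunks apply the same pattern: whatever is pushed before running user code must be popped on every exit path, so the pop moves into a finally clause. A stripped-down sketch of that pattern, using a plain Python list as a stand-in for TensorFlow's real tape stack (call_with_tape and _tape_stack are hypothetical names, not part of the codebase):

    _tape_stack = []  # stand-in for the stack behind tape.push_new_tape()

    def call_with_tape(user_fn, *args):
      _tape_stack.append(object())   # push a fresh "tape" before running user code
      try:
        result = user_fn(*args)      # may raise arbitrary exceptions
      finally:
        tape = _tape_stack.pop()     # always runs, so the stack stays balanced
      return result, tape            # only reached when user_fn did not raise

If user_fn raises, the exception still propagates after the pop, which is exactly what the new testExceptionSafety case below checks.
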
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 9ba5913c65..628f254b18 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -389,6 +389,21 @@ class BackpropTest(test.TestCase):
     grad = backprop.gradients_function(f)
     self.assertAllEqual(grad(1.0)[0], 2.0)
 
+  def testExceptionSafety(self):
+
+    def f(unused_x):
+      raise ValueError()
+
+    try:
+      backprop.gradients_function(f)(1.0)
+    except ValueError:
+      pass
+
+    def real_f(x):
+      return x * x
+
+    self.assertAllEqual(backprop.gradients_function(real_f)(1.0)[0], 2.0)
+
   def testMultiValueConvertToTensor(self):
     x = resource_variable_ops.ResourceVariable(
         initial_value=array_ops.constant([1.0]), name='x')