author     2017-10-23 14:26:02 -0700
committer  2017-10-23 14:29:33 -0700
commit     3f30e6424fa3b8e890f9360d1661e61c2d1625a5
tree       1f93f484b27f1038bf4c697456f2933681bc419d
parent     a0ee701f73cc80a56b41c2452006e166e0b835e6
Makes gradients_function exception-safe.
PiperOrigin-RevId: 173170394
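
The change wraps the user-supplied function in try/finally so that the tape pushed by tape.push_new_tape() is popped even when that function raises. A minimal stand-alone sketch of the pattern, using a plain Python list as a hypothetical stand-in for TensorFlow's internal tape stack (the names below are illustrative, not the real API):

    # Hypothetical stand-in for the tape stack; the real code uses
    # tape.push_new_tape() / tape.pop_tape() as shown in the diff below.
    _tape_stack = []

    def exception_safe_trace(f, *args):
      _tape_stack.append(object())   # push a new tape before tracing
      try:
        return f(*args)              # user code; may raise
      finally:
        _tape_stack.pop()            # always pop, so later traces see a clean stack

Without the finally, an exception in f would leave a stale tape on the stack and affect every subsequent gradient computation.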
 tensorflow/python/eager/backprop.py      | 57
 tensorflow/python/eager/backprop_test.py | 15
 2 files changed, 40 insertions, 32 deletions
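
In the diff below, val_and_grad_function is also rewritten as a thin wrapper over make_vjp. A hedged usage sketch of that pair, assuming eager execution is enabled (module path as in the diff; not an official example):

    from tensorflow.python.eager import backprop

    def f(x):
      return x * x

    # make_vjp(f)(...) returns (value, vjp), and vjp(dy=...) returns the
    # gradients, per the call pattern val, vjp = make_vjp(f, params)(*args)
    # used in the rewritten val_and_grad_function.
    val, vjp = backprop.make_vjp(f)(3.0)
    grads = vjp()   # gradient of val with respect to the positional input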
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 9580e84847..9d86ac77f8 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -335,15 +335,18 @@ def implicit_val_and_grad(f):
   def grad_fn(*args):
     """Computes the gradient of the wrapped function."""
     tape.push_new_tape()
-    end_node = f(*args)
-    variables = tape.top_tape_watched_variables()
+    try:
+      end_node = f(*args)
+      variables = tape.top_tape_watched_variables()
+    finally:
+      popped_tape = tape.pop_tape()
     sources = [x.handle for x in variables]
     if not sources:
       raise ValueError("no trainable variables were accessed while the "
                        "function was being computed.")
     grad = imperative_grad.imperative_grad(_default_vspace,
-                                           tape.pop_tape(),
+                                           popped_tape,
                                            nest.flatten(end_node),
                                            sources)
     return end_node, list(zip(grad, variables))
@@ -561,25 +564,12 @@ def val_and_grad_function(f, params=None):

   def decorated(*args, **kwds):
     """Computes the value and gradient of the decorated function."""
-    parameter_positions = _get_arg_spec(f, params, args)
     dy = kwds.pop("dy", None)
-    if dy is not None:
-      dy = ops.convert_to_tensor(dy)
-    assert not kwds, "The gradient function can't take keyword arguments."
-    tape.push_new_tape()
-    sources = []
-    args = [
-        ops.convert_to_tensor(args[i]) if i in parameter_positions else args[i]
-        for i in range(len(args))
-    ]
-    args = _ensure_unique_tensor_objects(parameter_positions, args)
-    for i in parameter_positions:
-      sources.append(args[i])
-      tape.watch(args[i])
-    result = f(*args)
-    return result, imperative_grad.imperative_grad(
-        _default_vspace, tape.pop_tape(), nest.flatten(result), sources,
-        output_gradients=nest.flatten(dy) if dy is not None else None)
+    if kwds:
+      raise ValueError("Functions to be differentiated cannot "
+                       "receive keyword arguments.")
+    val, vjp = make_vjp(f, params)(*args, **kwds)
+    return val, vjp(dy=dy)

   return decorated
@@ -619,17 +609,20 @@ def make_vjp(f, params=None):
     parameter_positions = _get_arg_spec(f, params, args)
     assert not kwds, "The gradient function can't take keyword arguments."
     tape.push_new_tape()
-    sources = []
-    args = [
-        ops.convert_to_tensor(args[i]) if i in parameter_positions else args[i]
-        for i in range(len(args))
-    ]
-    args = _ensure_unique_tensor_objects(parameter_positions, args)
-    for i in parameter_positions:
-      sources.append(args[i])
-      tape.watch(args[i])
-    result = f(*args)
-    t = tape.pop_tape()
+    try:
+      sources = []
+      args = [
+          ops.convert_to_tensor(args[i])
+          if i in parameter_positions else args[i]
+          for i in range(len(args))
+      ]
+      args = _ensure_unique_tensor_objects(parameter_positions, args)
+      for i in parameter_positions:
+        sources.append(args[i])
+        tape.watch(args[i])
+      result = f(*args)
+    finally:
+      t = tape.pop_tape()

     def vjp(dy=None):
       return imperative_grad.imperative_grad(
           _default_vspace, t, nest.flatten(result), sources,
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 9ba5913c65..628f254b18 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -389,6 +389,21 @@ class BackpropTest(test.TestCase):
     grad = backprop.gradients_function(f)
     self.assertAllEqual(grad(1.0)[0], 2.0)

+  def testExceptionSafety(self):
+
+    def f(unused_x):
+      raise ValueError()
+
+    try:
+      backprop.gradients_function(f)(1.0)
+    except ValueError:
+      pass
+
+    def real_f(x):
+      return x * x
+
+    self.assertAllEqual(backprop.gradients_function(real_f)(1.0)[0], 2.0)
+
   def testMultiValueConvertToTensor(self):
     x = resource_variable_ops.ResourceVariable(
         initial_value=array_ops.constant([1.0]), name='x')
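
Outside the test harness, the behavior covered by the new testExceptionSafety looks roughly like this (a sketch assuming eager execution is enabled; gradients_function and its call pattern are taken from the test above):

    from tensorflow.python.eager import backprop

    def bad(unused_x):
      raise ValueError()   # simulates user code failing while being traced

    try:
      backprop.gradients_function(bad)(1.0)
    except ValueError:
      pass                 # before this change, a tape was left pushed here

    def square(x):
      return x * x

    # The finally block restored the tape stack, so later gradient
    # computations still work: d(x*x)/dx at x = 1.0 is 2.0.
    print(backprop.gradients_function(square)(1.0)[0])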