author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-31 16:30:12 -0700
---|---|---
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-31 16:49:36 -0700
commit | 145347ffe937dcde4efa7f3d5024d74d1bc6b17c (patch) |
tree | 6620f786b349e64228eda3acc5911ffba2382823 /tensorflow/contrib/layers |
parent | 53f907c682e86d409294fc7c12343376de637a3e (diff) |
Usability improvements to @recompute_grad:

* Error if fn closes over a Tensor or Variable (not always detectable).
* Allow None gradients for some inputs (filter out Nones before control_deps).
PiperOrigin-RevId: 211162615
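To make the first change concrete, here is a usage sketch (illustrative only, not part of this commit: `build_outputs`, `bad_fn`, and `good_fn` are hypothetical names, and a TF 1.x environment with `tensorflow.contrib` available is assumed). The new check fires at decoration time, mirroring the `testErrorOnClosedOverTensor` case added below. Note that the functions are defined inside an enclosing function so that `z` is captured as a true closure cell; the check inspects `fn.__code__.co_freevars`, which does not cover module-level globals, presumably why the message hedges with "not always detectable":

```python
import tensorflow as tf
from tensorflow.contrib.layers.python.layers import rev_block_lib


def build_outputs():
  x = tf.random_uniform((4, 8))
  z = tf.random_uniform((4, 8))

  try:
    # Raises ValueError at decoration time: bad_fn closes over the Tensor
    # `z`, so gradients would silently never reach `z` through the
    # recomputed function.
    @rev_block_lib.recompute_grad
    def bad_fn(a):
      return a * z  # `z` is captured from the enclosing scope
  except ValueError as e:
    print(e)

  # The fix: pass every Tensor the function needs as a positional argument.
  @rev_block_lib.recompute_grad
  def good_fn(a, b):
    return a * b

  return good_fn(x, z)
```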
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r-- | tensorflow/contrib/layers/python/layers/rev_block_lib.py | 127
-rw-r--r-- | tensorflow/contrib/layers/python/layers/rev_block_lib_test.py | 10
2 files changed, 102 insertions, 35 deletions
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
index b25f11b5a6..06da32072f 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
@@ -30,6 +30,7 @@ import functools
 import re
 
 import numpy as np
+import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.contrib.framework.python import ops as contrib_framework_ops
@@ -44,6 +45,7 @@ from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_inspect
@@ -471,7 +473,8 @@ def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
 
   Args:
     fn: a function that takes Tensors (all as positional arguments) and returns
-      a tuple of Tensors.
+      a tuple of Tensors. Note that `fn` should not close over any other
+      Tensors or Variables.
     use_data_dep: `bool`, if `True` will use a dummy data dependency to force
       the recompute to happen. If `False` will use a control dependency. By
       default will be `True` if in an XLA context and `False` otherwise. XLA
@@ -485,7 +488,22 @@ def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
     A wrapped fn that is identical to fn when called, but its activations will
     be discarded and recomputed on the backwards pass (i.e. on a call to
     tf.gradients).
+
+  Raises:
+    ValueError: if `fn` closes over any Tensors or Variables.
   """
+  # Check for closed-over Tensors/Variables
+  if fn.__code__.co_freevars:
+    closed_over_vars = dict(zip(fn.__code__.co_freevars,
+                                [c.cell_contents for c in fn.__closure__]))
+    for var_name, value in six.iteritems(closed_over_vars):
+      if isinstance(value, (framework_ops.Tensor, variables_lib.Variable)):
+        raise ValueError(
+            "fn decorated with @recompute_grad closes over Tensor %s "
+            "(local variable name: %s). The decorated fn must not close over "
+            "Tensors or Variables because gradients will NOT be computed for "
+            "them through fn. To ensure correct gradients, make the "
+            "Tensor an input to fn." % (value.name, var_name))
 
   @_safe_wraps(fn)
   def wrapped(*args):
@@ -500,6 +518,62 @@ def _is_on_tpu():
   return control_flow_util.GetContainingXLAContext(ctxt) is not None
 
 
+def _recomputing_grad_fn(compute_fn,
+                         original_args,
+                         original_vars,
+                         output_grads,
+                         grad_fn_variables,
+                         use_data_dep,
+                         tupleize_grads,
+                         arg_scope,
+                         var_scope,
+                         has_is_recompute_kwarg):
+  """Grad fn for recompute_grad."""
+  variables = grad_fn_variables or []
+
+  # Identity ops around the inputs ensures correct gradient graph-walking.
+  inputs = [array_ops.identity(x) for x in list(original_args)]
+
+  # Recompute outputs
+  # Use a control dependency to ensure that the recompute is not eliminated by
+  # CSE and that it happens on the backwards pass.
+  ctrl_dep_grads = [g for g in output_grads if g is not None]
+  with framework_ops.control_dependencies(ctrl_dep_grads):
+    if use_data_dep:
+      inputs = _force_data_dependency(output_grads, inputs)
+    # Re-enter scopes
+    with contrib_framework_ops.arg_scope(arg_scope):
+      with variable_scope.variable_scope(var_scope, reuse=True):
+        # Re-call the function and ensure that the touched variables are the
+        # same as in the first call.
+        with backprop.GradientTape() as tape:
+          fn_kwargs = {}
+          if has_is_recompute_kwarg:
+            fn_kwargs["is_recomputing"] = True
+          outputs = compute_fn(*inputs, **fn_kwargs)
+        recompute_vars = set(tape.watched_variables())
+        if original_vars != recompute_vars:
+          raise ValueError(_WRONG_VARS_ERR)
+
+  if not isinstance(outputs, (list, tuple)):
+    outputs = [outputs]
+  outputs = list(outputs)
+
+  # Compute gradients
+  grads = gradients_impl.gradients(outputs, inputs + variables,
+                                   output_grads)
+
+  if tupleize_grads:
+    if use_data_dep:
+      grads = _tuple_with_data_dep(grads)
+    else:
+      grads = control_flow_ops.tuple(grads)
+
+  grad_inputs = grads[:len(inputs)]
+  grad_vars = grads[len(inputs):]
+  return grad_inputs, grad_vars
+
+
 def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
   """See recompute_grad."""
   has_is_recompute_kwarg = "is_recomputing" in tf_inspect.getargspec(fn).args
@@ -510,12 +584,16 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
   if use_data_dep_ == _USE_DEFAULT:
     use_data_dep_ = _is_on_tpu()
 
+  # Use custom_gradient and return a grad_fn that recomputes on the backwards
+  # pass.
   @custom_gradient.custom_gradient
   def fn_with_recompute(*args):
     """Wrapper for fn."""
-    # Forward pass
+    # Capture the variable and arg scopes so we can re-enter them when
+    # recomputing.
     vs = variable_scope.get_variable_scope()
     arg_scope = contrib_framework_ops.current_arg_scope()
+    # Track all variables touched in the function.
     with backprop.GradientTape() as tape:
       fn_kwargs = {}
       if has_is_recompute_kwarg:
@@ -523,46 +601,25 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
       outputs = fn(*args, **fn_kwargs)
     original_vars = set(tape.watched_variables())
 
-    # Backward pass
     def _grad_fn(output_grads, variables=None):
-      """Recompute outputs for gradient computation."""
-      variables = variables or []
+      # Validate that custom_gradient passes the right variables into grad_fn.
       if original_vars:
         assert variables, ("Fn created variables but the variables were not "
                            "passed to the gradient fn.")
         if set(variables) != original_vars:
           raise ValueError(_WRONG_VARS_ERR)
-      inputs = [array_ops.identity(x) for x in list(args)]
-
-      # Recompute outputs
-      with framework_ops.control_dependencies(output_grads):
-        if use_data_dep_:
-          inputs = _force_data_dependency(output_grads, inputs)
-        with contrib_framework_ops.arg_scope(arg_scope):
-          with variable_scope.variable_scope(vs, reuse=True):
-            with backprop.GradientTape() as tape:
-              fn_kwargs = {}
-              if has_is_recompute_kwarg:
-                fn_kwargs["is_recomputing"] = True
-              outputs = fn(*inputs, **fn_kwargs)
-            recompute_vars = set(tape.watched_variables())
-            if original_vars != recompute_vars:
-              raise ValueError(_WRONG_VARS_ERR)
-
-      if not isinstance(outputs, (list, tuple)):
-        outputs = [outputs]
-      outputs = list(outputs)
-      grads = gradients_impl.gradients(outputs, inputs + variables,
-                                       output_grads)
-
-      if tupleize_grads:
-        if use_data_dep_:
-          grads = _tuple_with_data_dep(grads)
-        else:
-          grads = control_flow_ops.tuple(grads)
-      grad_inputs = grads[:len(inputs)]
-      grad_vars = grads[len(inputs):]
-      return grad_inputs, grad_vars
+      return _recomputing_grad_fn(
+          compute_fn=fn,
+          original_args=args,
+          original_vars=original_vars,
+          output_grads=output_grads,
+          grad_fn_variables=variables,
+          use_data_dep=use_data_dep_,
+          tupleize_grads=tupleize_grads,
+          arg_scope=arg_scope,
+          var_scope=vs,
+          has_is_recompute_kwarg=has_is_recompute_kwarg)
 
     # custom_gradient inspects the signature of the function to determine
     # whether the user expects variables passed in the grad_fn. If the function
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
index d5971fb9d8..c34b5a8017 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
@@ -392,6 +392,16 @@ class RecomputeTest(test.TestCase):
     with self.test_session() as sess:
       sess.run(grads)
 
+  def testErrorOnClosedOverTensor(self):
+    x = random_ops.random_uniform((4, 8))
+    y = random_ops.random_uniform((4, 8))
+    z = x * y
+
+    with self.assertRaisesWithPredicateMatch(ValueError, "closes over"):
+      @rev_block_lib.recompute_grad
+      def fn_with_capture(a):  # pylint: disable=unused-variable
+        return a * z
+
 
 if __name__ == "__main__":
   test.main()
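The second change is the `ctrl_dep_grads = [g for g in output_grads if g is not None]` filter visible in `_recomputing_grad_fn` above: incoming output gradients can be `None` (for example, for outputs that do not participate in the objective being differentiated), and `tf.control_dependencies` rejects `None` entries in its list. A minimal standalone sketch of that pattern (illustrative only; the tensors here merely stand in for real gradients and the recomputed forward pass, and TF 1.x graph mode is assumed):

```python
import tensorflow as tf

# Some incoming output gradients may be None.
output_grads = [tf.ones((2, 2)), None, tf.ones((2, 2))]

# Drop the Nones before using the remaining gradients as control
# dependencies that force the recompute onto the backward pass; a None
# entry would raise a TypeError inside control_dependencies.
ctrl_dep_grads = [g for g in output_grads if g is not None]
with tf.control_dependencies(ctrl_dep_grads):
  recomputed = tf.zeros((2, 2))  # stands in for the recomputed activations
```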