author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-08-31 16:30:12 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-08-31 16:49:36 -0700
commit    145347ffe937dcde4efa7f3d5024d74d1bc6b17c
tree      6620f786b349e64228eda3acc5911ffba2382823 /tensorflow/contrib/layers
parent    53f907c682e86d409294fc7c12343376de637a3e
Usability improvements to @recompute_grad
- Error if fn closes over a Tensor or Variable (not always detectable)
- Allow None gradients to some inputs (filter out Nones before control_deps)

PiperOrigin-RevId: 211162615
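
For context, here is a minimal sketch of what the new check catches, following the import path and shapes used in the test file below (the function names are illustrative). Note the caveat in the commit message: the check only sees true closures over enclosing function scopes (via fn.__code__.co_freevars), so a capture of a module-level global is not detectable.

import tensorflow as tf
from tensorflow.contrib.layers.python.layers import rev_block_lib

def build():
  x = tf.random_uniform((4, 8))
  z = tf.random_uniform((4, 8))

  try:
    # Raises at decoration time: `bad_fn` closes over the Tensor `z`
    # from the enclosing function, so no gradient would be computed
    # for `z` through the recomputed fn.
    @rev_block_lib.recompute_grad
    def bad_fn(a):
      return a * z
  except ValueError as e:
    print(e)  # "fn decorated with @recompute_grad closes over Tensor ..."

  # Fixed: every Tensor the function uses is an explicit input.
  @rev_block_lib.recompute_grad
  def good_fn(a, b):
    return a * b

  return good_fn(x, z)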
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r--  tensorflow/contrib/layers/python/layers/rev_block_lib.py       127
-rw-r--r--  tensorflow/contrib/layers/python/layers/rev_block_lib_test.py   10
2 files changed, 102 insertions, 35 deletions
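
The None-gradient fix matters when the decorated fn returns multiple outputs and only some of them feed the loss: the gradients arriving for the unused outputs are None, and the recompute's control_dependencies previously could not handle them. A hedged sketch of such a scenario (hypothetical; not taken from the test suite):

import tensorflow as tf
from tensorflow.contrib.layers.python.layers import rev_block_lib

@rev_block_lib.recompute_grad
def two_outputs(a):
  return a * 2.0, a * 3.0  # a tuple of Tensors, per the docstring

x = tf.random_uniform((4, 8))
used, unused = two_outputs(x)
loss = tf.reduce_sum(used)   # `unused` never reaches the loss, so its
                             # incoming output grad is None
grads = tf.gradients(loss, [x])  # intended to work after this change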
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
index b25f11b5a6..06da32072f 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
@@ -30,6 +30,7 @@ import functools
import re
import numpy as np
+import six
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.framework.python import ops as contrib_framework_ops
@@ -44,6 +45,7 @@ from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
@@ -471,7 +473,8 @@ def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
Args:
fn: a function that takes Tensors (all as positional arguments) and returns
- a tuple of Tensors.
+ a tuple of Tensors. Note that `fn` should not close over any other
+ Tensors or Variables.
use_data_dep: `bool`, if `True` will use a dummy data dependency to force
the recompute to happen. If `False` will use a control dependency. By
default will be `True` if in an XLA context and `False` otherwise. XLA
@@ -485,7 +488,22 @@ def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
A wrapped fn that is identical to fn when called, but its activations will
be discarded and recomputed on the backwards pass (i.e. on a call to
tf.gradients).
+
+ Raises:
+ ValueError: if `fn` closes over any Tensors or Variables.
"""
+ # Check for closed-over Tensors/Variables
+ if fn.__code__.co_freevars:
+ closed_over_vars = dict(zip(fn.__code__.co_freevars,
+ [c.cell_contents for c in fn.__closure__]))
+ for var_name, value in six.iteritems(closed_over_vars):
+ if isinstance(value, (framework_ops.Tensor, variables_lib.Variable)):
+ raise ValueError(
+ "fn decorated with @recompute_grad closes over Tensor %s "
+ "(local variable name: %s). The decorated fn must not close over "
+ "Tensors or Variables because gradients will NOT be computed for "
+ "them through fn. To ensure correct gradients, make the "
+ "Tensor an input to fn." % (value.name, var_name))
@_safe_wraps(fn)
def wrapped(*args):
@@ -500,6 +518,62 @@ def _is_on_tpu():
return control_flow_util.GetContainingXLAContext(ctxt) is not None
+def _recomputing_grad_fn(compute_fn,
+ original_args,
+ original_vars,
+ output_grads,
+ grad_fn_variables,
+ use_data_dep,
+ tupleize_grads,
+ arg_scope,
+ var_scope,
+ has_is_recompute_kwarg):
+ """Grad fn for recompute_grad."""
+ variables = grad_fn_variables or []
+
+ # Identity ops around the inputs ensure correct gradient graph-walking.
+ inputs = [array_ops.identity(x) for x in list(original_args)]
+
+ # Recompute outputs
+ # Use a control dependency to ensure that the recompute is not eliminated by
+ # CSE and that it happens on the backwards pass.
+ ctrl_dep_grads = [g for g in output_grads if g is not None]
+ with framework_ops.control_dependencies(ctrl_dep_grads):
+ if use_data_dep:
+ inputs = _force_data_dependency(output_grads, inputs)
+ # Re-enter scopes
+ with contrib_framework_ops.arg_scope(arg_scope):
+ with variable_scope.variable_scope(var_scope, reuse=True):
+ # Re-call the function and ensure that the touched variables are the
+ # same as in the first call.
+ with backprop.GradientTape() as tape:
+ fn_kwargs = {}
+ if has_is_recompute_kwarg:
+ fn_kwargs["is_recomputing"] = True
+ outputs = compute_fn(*inputs, **fn_kwargs)
+ recompute_vars = set(tape.watched_variables())
+ if original_vars != recompute_vars:
+ raise ValueError(_WRONG_VARS_ERR)
+
+ if not isinstance(outputs, (list, tuple)):
+ outputs = [outputs]
+ outputs = list(outputs)
+
+ # Compute gradients
+ grads = gradients_impl.gradients(outputs, inputs + variables,
+ output_grads)
+
+ if tupleize_grads:
+ if use_data_dep:
+ grads = _tuple_with_data_dep(grads)
+ else:
+ grads = control_flow_ops.tuple(grads)
+
+ grad_inputs = grads[:len(inputs)]
+ grad_vars = grads[len(inputs):]
+ return grad_inputs, grad_vars
+
+
def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
"""See recompute_grad."""
has_is_recompute_kwarg = "is_recomputing" in tf_inspect.getargspec(fn).args
@@ -510,12 +584,16 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
if use_data_dep_ == _USE_DEFAULT:
use_data_dep_ = _is_on_tpu()
+ # Use custom_gradient and return a grad_fn that recomputes on the backwards
+ # pass.
@custom_gradient.custom_gradient
def fn_with_recompute(*args):
"""Wrapper for fn."""
- # Forward pass
+ # Capture the variable and arg scopes so we can re-enter them when
+ # recomputing.
vs = variable_scope.get_variable_scope()
arg_scope = contrib_framework_ops.current_arg_scope()
+ # Track all variables touched in the function.
with backprop.GradientTape() as tape:
fn_kwargs = {}
if has_is_recompute_kwarg:
@@ -523,46 +601,25 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
outputs = fn(*args, **fn_kwargs)
original_vars = set(tape.watched_variables())
- # Backward pass
def _grad_fn(output_grads, variables=None):
- """Recompute outputs for gradient computation."""
- variables = variables or []
+ # Validate that custom_gradient passes the right variables into grad_fn.
if original_vars:
assert variables, ("Fn created variables but the variables were not "
"passed to the gradient fn.")
if set(variables) != original_vars:
raise ValueError(_WRONG_VARS_ERR)
- inputs = [array_ops.identity(x) for x in list(args)]
- # Recompute outputs
- with framework_ops.control_dependencies(output_grads):
- if use_data_dep_:
- inputs = _force_data_dependency(output_grads, inputs)
- with contrib_framework_ops.arg_scope(arg_scope):
- with variable_scope.variable_scope(vs, reuse=True):
- with backprop.GradientTape() as tape:
- fn_kwargs = {}
- if has_is_recompute_kwarg:
- fn_kwargs["is_recomputing"] = True
- outputs = fn(*inputs, **fn_kwargs)
- recompute_vars = set(tape.watched_variables())
- if original_vars != recompute_vars:
- raise ValueError(_WRONG_VARS_ERR)
-
- if not isinstance(outputs, (list, tuple)):
- outputs = [outputs]
- outputs = list(outputs)
- grads = gradients_impl.gradients(outputs, inputs + variables,
- output_grads)
-
- if tupleize_grads:
- if use_data_dep_:
- grads = _tuple_with_data_dep(grads)
- else:
- grads = control_flow_ops.tuple(grads)
- grad_inputs = grads[:len(inputs)]
- grad_vars = grads[len(inputs):]
- return grad_inputs, grad_vars
+ return _recomputing_grad_fn(
+ compute_fn=fn,
+ original_args=args,
+ original_vars=original_vars,
+ output_grads=output_grads,
+ grad_fn_variables=variables,
+ use_data_dep=use_data_dep_,
+ tupleize_grads=tupleize_grads,
+ arg_scope=arg_scope,
+ var_scope=vs,
+ has_is_recompute_kwarg=has_is_recompute_kwarg)
# custom_gradient inspects the signature of the function to determine
# whether the user expects variables passed in the grad_fn. If the function
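
As an aside, the closure check added above rests on plain Python function introspection; a standalone sketch of the mechanism, with no TensorFlow involved:

def make_fn(z):
  def fn(a):
    return a * z  # `z` becomes a free variable of `fn`
  return fn

fn = make_fn(42)
print(fn.__code__.co_freevars)                    # ('z',)
print([c.cell_contents for c in fn.__closure__])  # [42]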
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
index d5971fb9d8..c34b5a8017 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
@@ -392,6 +392,16 @@ class RecomputeTest(test.TestCase):
with self.test_session() as sess:
sess.run(grads)
+ def testErrorOnClosedOverTensor(self):
+ x = random_ops.random_uniform((4, 8))
+ y = random_ops.random_uniform((4, 8))
+ z = x * y
+
+ with self.assertRaisesWithPredicateMatch(ValueError, "closes over"):
+ @rev_block_lib.recompute_grad
+ def fn_with_capture(a): # pylint: disable=unused-variable
+ return a * z
+
if __name__ == "__main__":
test.main()
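
Finally, the has_is_recompute_kwarg plumbing carried through the refactor means a decorated fn can opt in to knowing when it is being re-executed. A sketch under the same contrib import path (the summary op is only an illustrative side effect worth suppressing on the second execution):

import tensorflow as tf
from tensorflow.contrib.layers.python.layers import rev_block_lib

@rev_block_lib.recompute_grad
def layer(x, is_recomputing=False):
  # The library passes is_recomputing=True on the backwards-pass
  # re-execution, so one-time side effects can be skipped then.
  if not is_recomputing:
    tf.summary.histogram("activations", x)
  return tf.nn.relu(x)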