author    A. Unique TensorFlower <gardener@tensorflow.org> 2018-07-09 09:16:10 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-07-09 09:23:49 -0700
commit    62222398d89df1b4359658f339c9201b33eccf09 (patch)
tree      4c8b3b7d3be3e0c3b835cdc0fc63e2e0f8002848 /tensorflow/contrib/layers
parent    deee57e92cf8e2278e613a026a516acafd6eddd1 (diff)
Update recompute_grad to work for functions without variables.
Wrap recompute inputs in identity to avoid exponential graph traversal in gradients.
PiperOrigin-RevId: 203776272
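For context, a minimal usage sketch of what this change enables, not part of the patch: it assumes TF 1.x graph mode and the tf.contrib.layers export of recompute_grad, and all names below are illustrative.

# Hypothetical sketch (not part of this commit): recompute_grad applied to a
# function that creates no variables. Activations inside the function are
# recomputed during the backward pass instead of being kept in memory.
import tensorflow as tf

@tf.contrib.layers.recompute_grad
def concat_and_reduce(*layers):
  # No variables are created here.
  return tf.reduce_sum(tf.concat(layers, axis=-1), axis=1, keepdims=True)

x = tf.random_uniform((4, 8))
y = concat_and_reduce(x, 2.0 * x)
# Taking gradients triggers recomputation of concat_and_reduce's forward pass.
grads = tf.gradients(y, [x])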
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r--  tensorflow/contrib/layers/python/layers/rev_block_lib.py       26
-rw-r--r--  tensorflow/contrib/layers/python/layers/rev_block_lib_test.py  20
2 files changed, 38 insertions, 8 deletions
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
index 0e35b1aa8b..dad3da3748 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
@@ -514,15 +514,15 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
     original_vars = set(tape.watched_variables())

     # Backward pass
-    def grad_fn(*output_grads, **kwargs):
+    def _grad_fn(output_grads, variables=None):
       """Recompute outputs for gradient computation."""
-      variables = []
+      variables = variables or []
       if original_vars:
-        variables = kwargs["variables"]
-      if set(variables) != original_vars:
-        raise ValueError(_WRONG_VARS_ERR)
-      del kwargs
-      inputs = list(args)
+        assert variables, ("Fn created variables but the variables were not "
+                           "passed to the gradient fn.")
+        if set(variables) != original_vars:
+          raise ValueError(_WRONG_VARS_ERR)
+      inputs = [array_ops.identity(x) for x in list(args)]
       # Recompute outputs
       with framework_ops.control_dependencies(output_grads):
         if use_data_dep_:
@@ -538,7 +538,7 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
       if original_vars != recompute_vars:
         raise ValueError(_WRONG_VARS_ERR)
-      if not (isinstance(outputs, list) or isinstance(outputs, tuple)):
+      if not isinstance(outputs, (list, tuple)):
         outputs = [outputs]
       outputs = list(outputs)
       grads = gradients_impl.gradients(outputs, inputs + variables,
@@ -554,6 +554,16 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
       grad_vars = grads[len(inputs):]
       return grad_inputs, grad_vars

+    # custom_gradient inspects the signature of the function to determine
+    # whether the user expects variables passed in the grad_fn. If the function
+    # created variables, the grad_fn should accept the "variables" kwarg.
+    if original_vars:
+      def grad_fn(*output_grads, **kwargs):
+        return _grad_fn(output_grads, kwargs["variables"])
+    else:
+      def grad_fn(*output_grads):
+        return _grad_fn(output_grads)
+
     return outputs, grad_fn

   return fn_with_recompute(*args)
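As the new comment in this hunk notes, custom_gradient decides whether to pass a "variables" keyword by inspecting grad_fn's signature. Below is an illustrative sketch of the variable-free case; it uses the public tf.custom_gradient wrapper for brevity (an assumption for illustration, whereas the patch uses the internal custom_gradient module).

# Illustrative sketch, not part of the patch: with tf.custom_gradient, a
# grad_fn that declares no `variables` kwarg is valid as long as the
# decorated function creates no variables.
import tensorflow as tf

@tf.custom_gradient
def double(x):
  y = 2.0 * x
  def grad_fn(dy):  # no `variables` kwarg needed: no variables were created
    return 2.0 * dy
  return y, grad_fn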
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
index bc09ba8d43..d5971fb9d8 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
@@ -372,6 +372,26 @@ class RecomputeTest(test.TestCase):
     self.assertEqual(2, len(update_ops))
     self.assertEqual([False, True], kwarg_values)

+  def testWithoutVariables(self):
+
+    def concat_n(layer_list, num_inputs):
+      return math_ops.reduce_sum(
+          array_ops.concat([x for x in layer_list[-num_inputs:]], axis=-1),
+          axis=1, keepdims=True)
+
+    @rev_block_lib.recompute_grad
+    def concat_n_wrap(*args):
+      return concat_n(args, 3)
+
+    # DenseNet-style layers
+    layer_list = [random_ops.random_uniform((4, 8))]
+    for _ in range(5):
+      layer_list.append(math_ops.sqrt(concat_n_wrap(*layer_list)))
+
+    grads = gradients_impl.gradients(layer_list[-1], layer_list[0])
+    with self.test_session() as sess:
+      sess.run(grads)
+

 if __name__ == "__main__":
   test.main()