Diffstat (limited to 'tensorflow/contrib/layers/python/layers/rev_block_lib.py')
 tensorflow/contrib/layers/python/layers/rev_block_lib.py | 26 ++++++++++++++++++--------
 1 file changed, 18 insertions(+), 8 deletions(-)
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
index 0e35b1aa8b..dad3da3748 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
@@ -514,15 +514,15 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
     original_vars = set(tape.watched_variables())
 
     # Backward pass
-    def grad_fn(*output_grads, **kwargs):
+    def _grad_fn(output_grads, variables=None):
       """Recompute outputs for gradient computation."""
-      variables = []
+      variables = variables or []
       if original_vars:
-        variables = kwargs["variables"]
-      if set(variables) != original_vars:
-        raise ValueError(_WRONG_VARS_ERR)
-      del kwargs
-      inputs = list(args)
+        assert variables, ("Fn created variables but the variables were not "
+                           "passed to the gradient fn.")
+        if set(variables) != original_vars:
+          raise ValueError(_WRONG_VARS_ERR)
+      inputs = [array_ops.identity(x) for x in list(args)]
       # Recompute outputs
       with framework_ops.control_dependencies(output_grads):
         if use_data_dep_:
@@ -538,7 +538,7 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
             if original_vars != recompute_vars:
               raise ValueError(_WRONG_VARS_ERR)
 
-      if not (isinstance(outputs, list) or isinstance(outputs, tuple)):
+      if not isinstance(outputs, (list, tuple)):
         outputs = [outputs]
       outputs = list(outputs)
       grads = gradients_impl.gradients(outputs, inputs + variables,
@@ -554,6 +554,16 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
       grad_vars = grads[len(inputs):]
       return grad_inputs, grad_vars
 
+    # custom_gradient inspects the signature of the function to determine
+    # whether the user expects variables passed in the grad_fn. If the function
+    # created variables, the grad_fn should accept the "variables" kwarg.
+    if original_vars:
+      def grad_fn(*output_grads, **kwargs):
+        return _grad_fn(output_grads, kwargs["variables"])
+    else:
+      def grad_fn(*output_grads):
+        return _grad_fn(output_grads)
+
     return outputs, grad_fn
 
   return fn_with_recompute(*args)
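
For context on the comment added in the last hunk: custom_gradient decides whether to pass the watched variables to the gradient function by inspecting that function's signature, so the patch builds grad_fn with one of two signatures and forwards both to the shared _grad_fn. Below is a minimal, pure-Python sketch of that dispatch. make_grad_fn and caller_like_custom_gradient are hypothetical names written only for illustration and are not TensorFlow APIs; the caller merely mimics the signature inspection described in the comment, not the real custom_gradient machinery.

import inspect


def make_grad_fn(created_variables):
  """Builds a grad_fn with the signature the patch would pick (illustrative)."""

  def _grad_fn(output_grads, variables=None):
    # The real _grad_fn recomputes fn and differentiates; here we just echo
    # the inputs so the dispatch itself is visible.
    return output_grads, list(variables or [])

  if created_variables:
    # Fn created variables: expose a **kwargs signature so a signature-
    # inspecting caller knows to pass the "variables" kwarg.
    def grad_fn(*output_grads, **kwargs):
      return _grad_fn(output_grads, kwargs["variables"])
  else:
    def grad_fn(*output_grads):
      return _grad_fn(output_grads)
  return grad_fn


def caller_like_custom_gradient(grad_fn, output_grads, variables):
  """Hypothetical caller: passes variables only if grad_fn can accept them."""
  params = inspect.signature(grad_fn).parameters
  accepts_kwargs = any(p.kind is inspect.Parameter.VAR_KEYWORD
                       for p in params.values())
  if accepts_kwargs:
    return grad_fn(*output_grads, variables=variables)
  return grad_fn(*output_grads)


print(caller_like_custom_gradient(make_grad_fn(True), [1.0], ["w"]))
# -> ((1.0,), ['w'])
print(caller_like_custom_gradient(make_grad_fn(False), [1.0], ["w"]))
# -> ((1.0,), [])

Keeping a single _grad_fn and varying only the wrapper signature means the recompute-and-differentiate logic is written once, while the externally visible signature still tells custom_gradient whether a "variables" kwarg is expected.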