diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-09 09:16:10 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-09 09:23:49 -0700 |
commit | 62222398d89df1b4359658f339c9201b33eccf09 (patch) | |
tree | 4c8b3b7d3be3e0c3b835cdc0fc63e2e0f8002848 /tensorflow/contrib/layers | |
parent | deee57e92cf8e2278e613a026a516acafd6eddd1 (diff) |
Update recompute_grad to work for functions without variables.
Wrap recompute inputs in identity to avoid exponential graph traversal in gradients.
PiperOrigin-RevId: 203776272
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r-- | tensorflow/contrib/layers/python/layers/rev_block_lib.py | 26
-rw-r--r-- | tensorflow/contrib/layers/python/layers/rev_block_lib_test.py | 20 |
2 files changed, 38 insertions(+), 8 deletions(-)
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py index 0e35b1aa8b..dad3da3748 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py @@ -514,15 +514,15 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): original_vars = set(tape.watched_variables()) # Backward pass - def grad_fn(*output_grads, **kwargs): + def _grad_fn(output_grads, variables=None): """Recompute outputs for gradient computation.""" - variables = [] + variables = variables or [] if original_vars: - variables = kwargs["variables"] - if set(variables) != original_vars: - raise ValueError(_WRONG_VARS_ERR) - del kwargs - inputs = list(args) + assert variables, ("Fn created variables but the variables were not " + "passed to the gradient fn.") + if set(variables) != original_vars: + raise ValueError(_WRONG_VARS_ERR) + inputs = [array_ops.identity(x) for x in list(args)] # Recompute outputs with framework_ops.control_dependencies(output_grads): if use_data_dep_: @@ -538,7 +538,7 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): if original_vars != recompute_vars: raise ValueError(_WRONG_VARS_ERR) - if not (isinstance(outputs, list) or isinstance(outputs, tuple)): + if not isinstance(outputs, (list, tuple)): outputs = [outputs] outputs = list(outputs) grads = gradients_impl.gradients(outputs, inputs + variables, @@ -554,6 +554,16 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): grad_vars = grads[len(inputs):] return grad_inputs, grad_vars + # custom_gradient inspects the signature of the function to determine + # whether the user expects variables passed in the grad_fn. If the function + # created variables, the grad_fn should accept the "variables" kwarg. 
+ if original_vars: + def grad_fn(*output_grads, **kwargs): + return _grad_fn(output_grads, kwargs["variables"]) + else: + def grad_fn(*output_grads): + return _grad_fn(output_grads) + return outputs, grad_fn return fn_with_recompute(*args) diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py index bc09ba8d43..d5971fb9d8 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py @@ -372,6 +372,26 @@ class RecomputeTest(test.TestCase): self.assertEqual(2, len(update_ops)) self.assertEqual([False, True], kwarg_values) + def testWithoutVariables(self): + + def concat_n(layer_list, num_inputs): + return math_ops.reduce_sum( + array_ops.concat([x for x in layer_list[-num_inputs:]], axis=-1), + axis=1, keepdims=True) + + @rev_block_lib.recompute_grad + def concat_n_wrap(*args): + return concat_n(args, 3) + + # DenseNet-style layers + layer_list = [random_ops.random_uniform((4, 8))] + for _ in range(5): + layer_list.append(math_ops.sqrt(concat_n_wrap(*layer_list))) + + grads = gradients_impl.gradients(layer_list[-1], layer_list[0]) + with self.test_session() as sess: + sess.run(grads) + if __name__ == "__main__": test.main() |