author    | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-23 20:39:31 -0700
committer | TensorFlower Gardener <gardener@tensorflow.org>  | 2018-05-23 20:42:12 -0700
commit    | 8f863f3d71542c47390f2d40348b72296ed5c4be (patch)
tree      | 58adc3f4ea26c56c9416045981490b256da1c895 /tensorflow/contrib/layers
parent    | 42e50daa384183d2f64e0ab5ae3f9bed07128e07 (diff)
Add support for an optional is_recomputing kwarg to functions decorated with recompute_grad
PiperOrigin-RevId: 197834316
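
For orientation before the diff: a minimal usage sketch of the new keyword, mirroring the test added below (TF 1.x graph mode is assumed, and layer_with_recompute is an illustrative name):

    from tensorflow.contrib.layers.python.layers import rev_block_lib
    from tensorflow.python.framework import ops
    from tensorflow.python.layers import core as core_layers
    from tensorflow.python.layers import normalization as normalization_layers

    @rev_block_lib.recompute_grad
    def layer_with_recompute(inputs, is_recomputing=False):
      # is_recomputing is False on the forward pass and True when the function
      # is re-run to recompute activations during the backward pass.
      out = core_layers.dense(inputs, 2)
      out = normalization_layers.batch_normalization(out, training=True)
      if is_recomputing:
        # The recompute pass re-creates the batch-norm update ops; drop the
        # duplicates so each moving-average update runs only once per step.
        update_ops = ops.get_collection_ref(ops.GraphKeys.UPDATE_OPS)
        update_ops.pop()
        update_ops.pop()
      return out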
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r-- | tensorflow/contrib/layers/python/layers/rev_block_lib.py      | 21
-rw-r--r-- | tensorflow/contrib/layers/python/layers/rev_block_lib_test.py | 30
2 files changed, 49 insertions, 2 deletions
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
index 8ed9f446bc..0e35b1aa8b 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
@@ -46,6 +46,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
+from tensorflow.python.util import tf_inspect
 
 __all__ = ["rev_block", "RevBlock", "recompute_grad"]
 
@@ -449,6 +450,15 @@ def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
   `variable_scope(name, use_resource=True)`, which are the default in Eager
   mode and when running on TPU.
 
+  Warning: Because the function will be called again on the backwards pass, the
+  user should be careful to not use ops in their function that mutate state or
+  have randomness (for example, batch normalization or dropout). If the function
+  does have such operations, it is recommended that the function take the
+  `is_recomputing` keyword argument which will be `False` on the forward pass
+  and `True` on the backwards pass so that it can disable state changes when
+  `is_recomputing=True` (for example, not updating the moving averages in batch
+  normalization).
+
   Args:
     fn: a function that takes Tensors (all as positional arguments) and returns
       a tuple of Tensors.
@@ -482,6 +492,7 @@ def _is_on_tpu():
 
 def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
   """See recompute_grad."""
+  has_is_recompute_kwarg = "is_recomputing" in tf_inspect.getargspec(fn).args
   for arg in args:
     if not isinstance(arg, framework_ops.Tensor):
       raise ValueError("All inputs to function must be Tensors")
@@ -496,7 +507,10 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
     vs = variable_scope.get_variable_scope()
     arg_scope = contrib_framework_ops.current_arg_scope()
     with backprop.GradientTape() as tape:
-      outputs = fn(*args)
+      fn_kwargs = {}
+      if has_is_recompute_kwarg:
+        fn_kwargs["is_recomputing"] = False
+      outputs = fn(*args, **fn_kwargs)
     original_vars = set(tape.watched_variables())
 
   # Backward pass
@@ -516,7 +530,10 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
     with contrib_framework_ops.arg_scope(arg_scope):
       with variable_scope.variable_scope(vs, reuse=True):
         with backprop.GradientTape() as tape:
-          outputs = fn(*inputs)
+          fn_kwargs = {}
+          if has_is_recompute_kwarg:
+            fn_kwargs["is_recomputing"] = True
+          outputs = fn(*inputs, **fn_kwargs)
         recompute_vars = set(tape.watched_variables())
         if original_vars != recompute_vars:
           raise ValueError(_WRONG_VARS_ERR)
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
index 997f53b9e1..bc09ba8d43 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
@@ -21,9 +21,11 @@ from __future__ import print_function
 from tensorflow.contrib.layers.python.layers import layers
 from tensorflow.contrib.layers.python.layers import rev_block_lib
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
 from tensorflow.python.layers import convolutional
 from tensorflow.python.layers import core as core_layers
+from tensorflow.python.layers import normalization as normalization_layers
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
@@ -342,6 +344,34 @@ class RecomputeTest(test.TestCase):
       for grad in grads:
         self.assertTrue(grad is not None)
 
+  def testWithIsRecomputeKwarg(self):
+
+    kwarg_values = []
+
+    @rev_block_lib.recompute_grad
+    def layer_with_recompute(inputs, is_recomputing=False):
+      kwarg_values.append(is_recomputing)
+      out = core_layers.dense(inputs, 2)
+      out = normalization_layers.batch_normalization(out, training=True)
+      if is_recomputing:
+        # Ensure that the updates are not duplicated by popping off the latest
+        # 2 additions.
+        update_ops = ops.get_collection_ref(ops.GraphKeys.UPDATE_OPS)
+        update_ops.pop()
+        update_ops.pop()
+      return out
+
+    x = array_ops.ones((2, 4), dtypes.float32)
+    with variable_scope.variable_scope("layer1", use_resource=True):
+      y = layer_with_recompute(x)
+      loss = math_ops.reduce_sum(y)
+      tvars = variables.trainable_variables()
+      gradients_impl.gradients(loss, [x] + tvars)
+
+    update_ops = ops.get_collection(ops.GraphKeys.UPDATE_OPS)
+    self.assertEqual(2, len(update_ops))
+    self.assertEqual([False, True], kwarg_values)
+
 
 if __name__ == "__main__":
   test.main()
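
As a side note, the signature check that gates the new kwarg can be sketched outside TensorFlow with the stdlib analogue of tf_inspect.getargspec (inspect.getfullargspec here; fn is a hypothetical stand-in for a decorated function):

    import inspect

    def fn(a, b, is_recomputing=False):
      return a + b

    # Mirrors the check added in _recompute_grad: the kwarg is forwarded only
    # when the wrapped function actually declares it, so existing functions
    # without the parameter keep working unchanged.
    has_kwarg = "is_recomputing" in inspect.getfullargspec(fn).args

    fn_kwargs = {"is_recomputing": True} if has_kwarg else {}
    print(fn(1, 2, **fn_kwargs))  # prints 3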