author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-05-23 20:39:31 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-05-23 20:42:12 -0700
commit    8f863f3d71542c47390f2d40348b72296ed5c4be (patch)
tree      58adc3f4ea26c56c9416045981490b256da1c895 /tensorflow/contrib/layers
parent    42e50daa384183d2f64e0ab5ae3f9bed07128e07 (diff)
Add support for is_recomputing optional kwarg to functions decorated with recompute_grad
PiperOrigin-RevId: 197834316
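
In short: a function decorated with `recompute_grad` may now declare an
`is_recomputing` keyword argument; the wrapper passes `False` on the forward
pass and `True` when the function is re-executed during the backward pass.
A minimal usage sketch, mirroring the test added below (the layer body is
illustrative, not part of the API):

    from tensorflow.contrib.layers.python.layers import rev_block_lib
    from tensorflow.python.framework import ops
    from tensorflow.python.layers import core as core_layers
    from tensorflow.python.layers import normalization as normalization_layers

    @rev_block_lib.recompute_grad
    def layer_with_recompute(inputs, is_recomputing=False):
      out = core_layers.dense(inputs, 2)
      out = normalization_layers.batch_normalization(out, training=True)
      if is_recomputing:
        # The forward pass already queued these batch-norm update ops;
        # drop the duplicates added by this re-execution.
        update_ops = ops.get_collection_ref(ops.GraphKeys.UPDATE_OPS)
        update_ops.pop()
        update_ops.pop()
      return out
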
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r--  tensorflow/contrib/layers/python/layers/rev_block_lib.py       21
-rw-r--r--  tensorflow/contrib/layers/python/layers/rev_block_lib_test.py  30
2 files changed, 49 insertions(+), 2 deletions(-)
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
index 8ed9f446bc..0e35b1aa8b 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
@@ -46,6 +46,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
+from tensorflow.python.util import tf_inspect
__all__ = ["rev_block", "RevBlock", "recompute_grad"]
@@ -449,6 +450,15 @@ def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
`variable_scope(name, use_resource=True)`, which are the default in Eager mode
and when running on TPU.
+ Warning: Because the function will be called again on the backward pass, the
+ user should be careful not to use ops in their function that mutate state or
+ have randomness (for example, batch normalization or dropout). If the function
+ does have such operations, it is recommended that the function take the
+ `is_recomputing` keyword argument, which will be `False` on the forward pass
+ and `True` on the backward pass, so that it can disable state changes when
+ `is_recomputing=True` (for example, not updating the moving averages in batch
+ normalization).
+
Args:
fn: a function that takes Tensors (all as positional arguments) and returns
a tuple of Tensors.
@@ -482,6 +492,7 @@ def _is_on_tpu():
def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
"""See recompute_grad."""
+ has_is_recompute_kwarg = "is_recomputing" in tf_inspect.getargspec(fn).args
for arg in args:
if not isinstance(arg, framework_ops.Tensor):
raise ValueError("All inputs to function must be Tensors")
@@ -496,7 +507,10 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
vs = variable_scope.get_variable_scope()
arg_scope = contrib_framework_ops.current_arg_scope()
with backprop.GradientTape() as tape:
- outputs = fn(*args)
+ fn_kwargs = {}
+ if has_is_recompute_kwarg:
+ fn_kwargs["is_recomputing"] = False
+ outputs = fn(*args, **fn_kwargs)
original_vars = set(tape.watched_variables())
# Backward pass
@@ -516,7 +530,10 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
with contrib_framework_ops.arg_scope(arg_scope):
with variable_scope.variable_scope(vs, reuse=True):
with backprop.GradientTape() as tape:
- outputs = fn(*inputs)
+ fn_kwargs = {}
+ if has_is_recompute_kwarg:
+ fn_kwargs["is_recomputing"] = True
+ outputs = fn(*inputs, **fn_kwargs)
recompute_vars = set(tape.watched_variables())
if original_vars != recompute_vars:
raise ValueError(_WRONG_VARS_ERR)
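
The wrapper forwards the new kwarg only when the wrapped function actually
declares it, so existing callers are unaffected. A standalone sketch of that
dispatch, using the standard-library `inspect` in place of `tf_inspect` (the
helper name `_call_fn` is illustrative, not from the patch):

    import inspect

    def _call_fn(fn, args, is_recomputing):
      # Forward the flag only if `fn` declares the kwarg; otherwise call
      # the function exactly as before, so older functions keep working.
      fn_kwargs = {}
      if "is_recomputing" in inspect.getfullargspec(fn).args:
        fn_kwargs["is_recomputing"] = is_recomputing
      return fn(*args, **fn_kwargs)

The forward pass would call `_call_fn(fn, args, is_recomputing=False)`; the
recompute inside the gradient computation calls it again with
`is_recomputing=True`, matching the two `fn_kwargs` sites in the diff above.
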
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
index 997f53b9e1..bc09ba8d43 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
@@ -21,9 +21,11 @@ from __future__ import print_function
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.layers.python.layers import rev_block_lib
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.layers import convolutional
from tensorflow.python.layers import core as core_layers
+from tensorflow.python.layers import normalization as normalization_layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
@@ -342,6 +344,34 @@ class RecomputeTest(test.TestCase):
for grad in grads:
self.assertTrue(grad is not None)
+ def testWithIsRecomputeKwarg(self):
+
+ kwarg_values = []
+
+ @rev_block_lib.recompute_grad
+ def layer_with_recompute(inputs, is_recomputing=False):
+ kwarg_values.append(is_recomputing)
+ out = core_layers.dense(inputs, 2)
+ out = normalization_layers.batch_normalization(out, training=True)
+ if is_recomputing:
+ # Ensure that the updates are not duplicated by popping off the latest
+ # 2 additions.
+ update_ops = ops.get_collection_ref(ops.GraphKeys.UPDATE_OPS)
+ update_ops.pop()
+ update_ops.pop()
+ return out
+
+ x = array_ops.ones((2, 4), dtypes.float32)
+ with variable_scope.variable_scope("layer1", use_resource=True):
+ y = layer_with_recompute(x)
+ loss = math_ops.reduce_sum(y)
+ tvars = variables.trainable_variables()
+ gradients_impl.gradients(loss, [x] + tvars)
+
+ update_ops = ops.get_collection(ops.GraphKeys.UPDATE_OPS)
+ self.assertEqual(2, len(update_ops))
+ self.assertEqual([False, True], kwarg_values)
+
if __name__ == "__main__":
test.main()