author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-03-26 15:34:21 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-03-26 15:36:40 -0700
commit    eee15c1f8ea56dbb516fa9e35392e0a224e99966 (patch)
tree      f99c1d58b1cc21f8c13749a7173640d39c4a541b /tensorflow/contrib/layers
parent    3a00d79b16348f0a53379e81b8e98bdd93d4833e (diff)
Update recompute_grad for TPU
PiperOrigin-RevId: 190536468
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r--  tensorflow/contrib/layers/python/layers/rev_block_lib.py      | 105
-rw-r--r--  tensorflow/contrib/layers/python/layers/rev_block_lib_test.py |  61
2 files changed, 146 insertions(+), 20 deletions(-)
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
index 123275e1fd..0b38c0c3fd 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
@@ -29,6 +29,7 @@ from __future__ import print_function
import functools
import re
+import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.framework.python import ops as contrib_framework_ops
@@ -37,6 +38,7 @@ from tensorflow.python.framework import ops as framework_ops
from tensorflow.python.layers import base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
@@ -46,6 +48,7 @@ from tensorflow.python.util import nest
__all__ = ["rev_block", "RevBlock", "recompute_grad"]
LAYER_RE = re.compile(".*revlayer_([0-9]*)/([fg])/.*")
+_USE_DEFAULT = "__rev_block_lib_default"
def _acc_grads(*lists_of_grads):
@@ -219,7 +222,13 @@ class RevBlock(base.Layer):
def _efficient_grad_fn(self, inputs, variables, ys, grad_ys):
"""Custom gradient fn for a block of reversible residual layers."""
+ # Inputs have passed through an Identity. Recover the original Tensors to
+ # be able to match up side inputs.
+ assert [u"Identity"] == list(set([x.op.type for x in inputs]))
+ inputs = [x.op.inputs[0] for x in inputs]
side_inputs = inputs[2:]
+ del inputs
+
f_side_idxs = [None] * len(self.f_side_input)
g_side_idxs = [None] * len(self.g_side_input)
assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input)
@@ -405,12 +414,36 @@ def rev_block(x1,
return block.forward(x1, x2)
-def recompute_grad(fn):
+def enable_with_args(dec):
+ """A decorator for decorators to enable their usage with or without args."""
+
+ @functools.wraps(dec)
+ def new_dec(*args, **kwargs):
+ if len(args) == 1 and not kwargs and callable(args[0]):
+ # Used as decorator without args
+ fn = args[0]
+ return dec(fn)
+ else:
+ return lambda fn: dec(fn, *args, **kwargs)
+
+ return new_dec
+
+
+@enable_with_args
+def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
"""Decorator that recomputes the function on the backwards pass.
Args:
fn: a function that takes Tensors (all as positional arguments) and returns
a tuple of Tensors.
+ use_data_dep: `bool`, if `True` will use a dummy data dependency to force
+ the recompute to happen. If `False` will use a control dependency. By
+ default will be `True` if in an XLA context and `False` otherwise. XLA
+ ignores control dependencies and so this data dependency is necessary.
+ tupleize_grads: `bool`, if `True` will use control dependencies to ensure
+ that all gradients are produced before any are consumed by downstream ops.
+ If `use_data_dep` is also `True`, will use a data dependency instead of
+ a control dependency.
Returns:
A wrapped fn that is identical to fn when called, but its activations will
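[Editor's note] The `enable_with_args` helper added above is what lets `recompute_grad` be applied either bare (`@recompute_grad`) or parameterized (`@recompute_grad(use_data_dep=True, tupleize_grads=True)`), as the updated tests below do. A minimal standalone sketch of the same pattern, independent of this patch; `log_calls`, `label`, `double` and `triple` are hypothetical names used only for illustration:

import functools

def enable_with_args(dec):
  """Let a decorator be used either bare or with arguments."""
  @functools.wraps(dec)
  def new_dec(*args, **kwargs):
    if len(args) == 1 and not kwargs and callable(args[0]):
      # Bare usage: the decorated function is the only argument.
      return dec(args[0])
    # Parameterized usage: defer until the decorated function arrives.
    return lambda fn: dec(fn, *args, **kwargs)
  return new_dec

@enable_with_args
def log_calls(fn, label="default"):
  @functools.wraps(fn)
  def wrapped(*args, **kwargs):
    print("calling %s [%s]" % (fn.__name__, label))
    return fn(*args, **kwargs)
  return wrapped

@log_calls                  # bare
def double(x):
  return 2 * x

@log_calls(label="triple")  # with a keyword argument
def triple(x):
  return 3 * x

Both `double(2)` and `triple(2)` behave as usual; only the logged label differs.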
@@ -420,13 +453,25 @@ def recompute_grad(fn):
@functools.wraps(fn)
def wrapped(*args):
- return _recompute_grad(fn, args)
+ return _recompute_grad(
+ fn, args, use_data_dep=use_data_dep, tupleize_grads=tupleize_grads)
return wrapped
-def _recompute_grad(fn, args):
+def _is_on_tpu():
+ ctxt = framework_ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access
+ return control_flow_util.GetContainingXLAContext(ctxt) is not None
+
+
+def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
"""See recompute_grad."""
+ for arg in args:
+ if not isinstance(arg, framework_ops.Tensor):
+ raise ValueError("All inputs to function must be Tensors")
+ use_data_dep_ = use_data_dep
+ if use_data_dep_ == _USE_DEFAULT:
+ use_data_dep_ = _is_on_tpu()
cached_vs = []
cached_arg_scope = []
@@ -436,6 +481,8 @@ def _recompute_grad(fn, args):
del outputs
# Recompute outputs
with framework_ops.control_dependencies(output_grads):
+ if use_data_dep_:
+ inputs = _force_data_dependency(output_grads, inputs)
with contrib_framework_ops.arg_scope(cached_arg_scope[0]):
with variable_scope.variable_scope(cached_vs[0], reuse=True):
outputs = fn(*inputs)
@@ -444,6 +491,13 @@ def _recompute_grad(fn, args):
outputs = [outputs]
outputs = list(outputs)
grads = gradients_impl.gradients(outputs, inputs + variables, output_grads)
+
+ if tupleize_grads:
+ if use_data_dep_:
+ grads = _tuple_with_data_dep(grads)
+ else:
+ grads = control_flow_ops.tuple(grads)
+
grad_inputs = grads[:len(inputs)]
grad_vars = grads[len(inputs):]
return grad_inputs, grad_vars
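[Editor's note] When `tupleize_grads` is set, the recomputed gradients are routed through `control_flow_ops.tuple` (or through the data-dependency variant when `use_data_dep_` is also set), so no gradient is handed to downstream ops until all of them have been produced. A minimal sketch of that barrier, assuming the TF 1.x graph-mode API this file targets; `w1`, `w2` and `loss` are hypothetical names, not part of the patch:

import tensorflow as tf

w1 = tf.get_variable("w1", [3])
w2 = tf.get_variable("w2", [3])
loss = tf.reduce_sum(w1 * w2)

grads = tf.gradients(loss, [w1, w2])
# tf.tuple returns tensors with the same values, but each output only becomes
# available once every input in the list has been computed.
g1, g2 = tf.tuple(grads)

On TPU the same effect is obtained with `_tuple_with_data_dep` instead, because XLA discards plain control dependencies.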
@@ -532,7 +586,7 @@ def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False):
get_vars_fn = (
vs.global_variables if use_global_vars else vs.trainable_variables)
len_before_vars = len(get_vars_fn())
- inputs = list(inputs)
+ inputs = [array_ops.identity(x) for x in inputs]
outputs = fn(*inputs)
train_vars = get_vars_fn()[len_before_vars:]
@@ -581,3 +635,46 @@ def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False):
flat_inputs = nest.flatten(defun_inputs)
id_out = identity(*flat_inputs)
return id_out
+
+
+def _force_data_dependency(first_compute, then_compute):
+ """Force all of `then_compute` to depend on all of `first_compute`.
+
+ Uses a dummy data dependency, which is useful when running on TPUs because
+ XLA ignores control dependencies. Only supports float arguments.
+
+ Args:
+ first_compute: `list<Tensor>`. These will be made to run before the
+ `Tensor`s `then_compute`.
+ then_compute: `list<Tensor>`. These will run after all the `Tensor`s in
+ `first_compute`.
+
+ Returns:
+ `list<Tensor>`, same length as `then_compute`.
+
+ Raises:
+ ValueError: if ranks are unknown or types are not floating.
+ """
+
+ def _first_element(x):
+ if x.get_shape().ndims is None:
+ raise ValueError("Rank of Tensor %s must be known" % x)
+ ndims = x.get_shape().ndims
+ return array_ops.reshape(array_ops.slice(x, [0] * ndims, [1] * ndims), [])
+
+ first_compute_sum = math_ops.add_n(
+ [_first_element(x) for x in first_compute if x is not None])
+ dtype = first_compute_sum.dtype
+ if not dtype.is_floating:
+ raise ValueError("_force_data_dependency only supports floating dtypes.")
+ epsilon = np.finfo(dtype.as_numpy_dtype).tiny
+ zero = array_ops.stop_gradient(epsilon * first_compute_sum)
+
+ return [
+ array_ops.identity(x) + zero if x is not None else None
+ for x in then_compute
+ ]
+
+
+def _tuple_with_data_dep(tensors):
+ return _force_data_dependency(tensors, tensors)
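[Editor's note] `_force_data_dependency` creates the ordering by threading a vanishingly small, gradient-stopped scalar derived from `first_compute` into every tensor of `then_compute`: the values are left numerically unchanged for all practical purposes, but XLA now sees a real data edge that it cannot prune, unlike a control dependency. A standalone sketch of the same trick, assuming TF 1.x graph mode and NumPy; `a` and `b` are hypothetical tensors, not part of the patch:

import numpy as np
import tensorflow as tf

a = tf.random_uniform([4], dtype=tf.float32)  # must be computed first
b = tf.random_uniform([4], dtype=tf.float32)  # should wait for `a`

# Scale a scalar slice of `a` by the smallest positive float so that adding it
# to `b` is numerically negligible, yet `b_dep` now depends on `a`'s data.
eps = np.finfo(np.float32).tiny
zero_ish = tf.stop_gradient(eps * tf.reshape(tf.slice(a, [0], [1]), []))
b_dep = tf.identity(b) + zero_ish

`_tuple_with_data_dep` simply applies this helper with the same list on both sides, so every gradient ends up carrying a data dependency on the first element of every other gradient.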
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
index cbcbcd7511..d1ad4e8c98 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
@@ -154,7 +154,7 @@ class RevBlockTest(test.TestCase):
y_val, yd_val, gd_val, g_val = sess.run([y, y_rev, grads_rev, grads])
self.assertAllClose(y_val, yd_val)
for g1, g2 in zip(gd_val, g_val):
- self.assertAllClose(g1, g2)
+ self.assertAllClose(g1, g2, rtol=1e-5)
def testRevBlock(self):
self._testRevBlock()
@@ -255,25 +255,54 @@ class RecomputeTest(test.TestCase):
def fn_recompute(x):
return fn(x)
+ @rev_block_lib.recompute_grad(use_data_dep=True)
+ def fn_use_data_dep(x):
+ return fn(x)
+
+ @rev_block_lib.recompute_grad(tupleize_grads=True)
+ def fn_tupleize(x):
+ return fn(x)
+
+ @rev_block_lib.recompute_grad(use_data_dep=True, tupleize_grads=True)
+ def fn_both(x):
+ return fn(x)
+
x = random_ops.random_uniform((3, 1, 3))
- recompute_vars = None
- with variable_scope.variable_scope("recompute") as vs:
- out1 = math_ops.reduce_sum(fn_recompute(x))
- recompute_vars = vs.trainable_variables()
- reg_vars = None
- with variable_scope.variable_scope("regular") as vs:
- out2 = math_ops.reduce_sum(fn(x))
- reg_vars = vs.trainable_variables()
-
- grad1 = gradients_impl.gradients(out1, recompute_vars)
- grad2 = gradients_impl.gradients(out2, reg_vars)
+
+ names_and_fns = [
+ ("recompute", fn_recompute),
+ ("regular", fn),
+ ("use_data_dep", fn_use_data_dep),
+ ("tupleize", fn_tupleize),
+ ("tuple_and_data_dep", fn_both),
+ ]
+ outputs_and_vars = []
+ for name, wrapped_fn in names_and_fns:
+ with variable_scope.variable_scope(name) as vs:
+ out = math_ops.reduce_sum(wrapped_fn(x))
+ outputs_and_vars.append((out, vs.trainable_variables()))
+
+ all_grads = []
+ for out, scope_vars in outputs_and_vars:
+ all_grads.append(gradients_impl.gradients(out, scope_vars))
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
- outs = sess.run([out1, out2, grad1, grad2])
- self.assertAllClose(outs[0], outs[1])
- for g1, g2 in zip(outs[2], outs[3]):
- self.assertAllClose(g1, g2)
+ outputs = list(zip(*outputs_and_vars))[0]
+ outs, all_grads_val = sess.run([outputs, all_grads])
+
+ # All outputs are the same
+ current = outs[0]
+ for out in outs[1:]:
+ self.assertAllClose(current, out)
+ current = out
+
+ # All gradients are the same
+ for grads in zip(*all_grads_val):
+ current = grads[0]
+ for g in grads[1:]:
+ self.assertAllClose(current, g)
+ current = g
class FnWithCustomGradTest(test.TestCase):