diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-26 15:34:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-26 15:36:40 -0700 |
commit | eee15c1f8ea56dbb516fa9e35392e0a224e99966 (patch) | |
tree | f99c1d58b1cc21f8c13749a7173640d39c4a541b /tensorflow/contrib/layers | |
parent | 3a00d79b16348f0a53379e81b8e98bdd93d4833e (diff) |
Update recompute_grad for TPU
PiperOrigin-RevId: 190536468
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r-- | tensorflow/contrib/layers/python/layers/rev_block_lib.py | 105 | ||||
-rw-r--r-- | tensorflow/contrib/layers/python/layers/rev_block_lib_test.py | 61 |
2 files changed, 146 insertions, 20 deletions
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py index 123275e1fd..0b38c0c3fd 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py @@ -29,6 +29,7 @@ from __future__ import print_function import functools import re +import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.framework.python import ops as contrib_framework_ops @@ -37,6 +38,7 @@ from tensorflow.python.framework import ops as framework_ops from tensorflow.python.layers import base from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope @@ -46,6 +48,7 @@ from tensorflow.python.util import nest __all__ = ["rev_block", "RevBlock", "recompute_grad"] LAYER_RE = re.compile(".*revlayer_([0-9]*)/([fg])/.*") +_USE_DEFAULT = "__rev_block_lib_default" def _acc_grads(*lists_of_grads): @@ -219,7 +222,13 @@ class RevBlock(base.Layer): def _efficient_grad_fn(self, inputs, variables, ys, grad_ys): """Custom gradient fn for a block of reversible residual layers.""" + # Inputs have passed through an Identity. Recover the original Tensors to + # be able to match up side inputs. + assert [u"Identity"] == list(set([x.op.type for x in inputs])) + inputs = [x.op.inputs[0] for x in inputs] side_inputs = inputs[2:] + del inputs + f_side_idxs = [None] * len(self.f_side_input) g_side_idxs = [None] * len(self.g_side_input) assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input) @@ -405,12 +414,36 @@ def rev_block(x1, return block.forward(x1, x2) -def recompute_grad(fn): +def enable_with_args(dec): + """A decorator for decorators to enable their usage with or without args.""" + + @functools.wraps(dec) + def new_dec(*args, **kwargs): + if len(args) == 1 and not kwargs and callable(args[0]): + # Used as decorator without args + fn = args[0] + return dec(fn) + else: + return lambda fn: dec(fn, *args, **kwargs) + + return new_dec + + +@enable_with_args +def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False): """Decorator that recomputes the function on the backwards pass. Args: fn: a function that takes Tensors (all as positional arguments) and returns a tuple of Tensors. + use_data_dep: `bool`, if `True` will use a dummy data dependency to force + the recompute to happen. If `False` will use a control dependency. By + default will be `True` if in an XLA context and `False` otherwise. XLA + ignores control dependencies and so this data dependency is necessary. + tupleize_grads: `bool`, if `True` will use control dependencies to ensure + that all gradients are produced before any are consumed by downstream ops. + If `use_data_dep` is also `True`, will use a data dependency instead of + a control dependency. Returns: A wrapped fn that is identical to fn when called, but its activations will @@ -420,13 +453,25 @@ def recompute_grad(fn): @functools.wraps(fn) def wrapped(*args): - return _recompute_grad(fn, args) + return _recompute_grad( + fn, args, use_data_dep=use_data_dep, tupleize_grads=tupleize_grads) return wrapped -def _recompute_grad(fn, args): +def _is_on_tpu(): + ctxt = framework_ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access + return control_flow_util.GetContainingXLAContext(ctxt) is not None + + +def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): """See recompute_grad.""" + for arg in args: + if not isinstance(arg, framework_ops.Tensor): + raise ValueError("All inputs to function must be Tensors") + use_data_dep_ = use_data_dep + if use_data_dep_ == _USE_DEFAULT: + use_data_dep_ = _is_on_tpu() cached_vs = [] cached_arg_scope = [] @@ -436,6 +481,8 @@ def _recompute_grad(fn, args): del outputs # Recompute outputs with framework_ops.control_dependencies(output_grads): + if use_data_dep_: + inputs = _force_data_dependency(output_grads, inputs) with contrib_framework_ops.arg_scope(cached_arg_scope[0]): with variable_scope.variable_scope(cached_vs[0], reuse=True): outputs = fn(*inputs) @@ -444,6 +491,13 @@ def _recompute_grad(fn, args): outputs = [outputs] outputs = list(outputs) grads = gradients_impl.gradients(outputs, inputs + variables, output_grads) + + if tupleize_grads: + if use_data_dep_: + grads = _tuple_with_data_dep(grads) + else: + grads = control_flow_ops.tuple(grads) + grad_inputs = grads[:len(inputs)] grad_vars = grads[len(inputs):] return grad_inputs, grad_vars @@ -532,7 +586,7 @@ def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False): get_vars_fn = ( vs.global_variables if use_global_vars else vs.trainable_variables) len_before_vars = len(get_vars_fn()) - inputs = list(inputs) + inputs = [array_ops.identity(x) for x in inputs] outputs = fn(*inputs) train_vars = get_vars_fn()[len_before_vars:] @@ -581,3 +635,46 @@ def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False): flat_inputs = nest.flatten(defun_inputs) id_out = identity(*flat_inputs) return id_out + + +def _force_data_dependency(first_compute, then_compute): + """Force all of `then_compute` to depend on all of `first_compute`. + + Uses a dummy data dependency, which is useful when running on TPUs because + XLA ignores control dependencies. Only supports float arguments. + + Args: + first_compute: `list<Tensor>`. These will be made to run before the + `Tensor`s `then_compute`. + then_compute: `list<Tensor>`. These will run after all the `Tensor`s in + `first_compute`. + + Returns: + `list<Tensor>`, same length as `then_compute`. + + Raises: + ValueError: if ranks are unknown or types are not floating. + """ + + def _first_element(x): + if x.get_shape().ndims is None: + raise ValueError("Rank of Tensor %s must be known" % x) + ndims = x.get_shape().ndims + return array_ops.reshape(array_ops.slice(x, [0] * ndims, [1] * ndims), []) + + first_compute_sum = math_ops.add_n( + [_first_element(x) for x in first_compute if x is not None]) + dtype = first_compute_sum.dtype + if not dtype.is_floating: + raise ValueError("_force_data_dependency only supports floating dtypes.") + epsilon = np.finfo(dtype.as_numpy_dtype).tiny + zero = array_ops.stop_gradient(epsilon * first_compute_sum) + + return [ + array_ops.identity(x) + zero if x is not None else None + for x in then_compute + ] + + +def _tuple_with_data_dep(tensors): + return _force_data_dependency(tensors, tensors) diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py index cbcbcd7511..d1ad4e8c98 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py @@ -154,7 +154,7 @@ class RevBlockTest(test.TestCase): y_val, yd_val, gd_val, g_val = sess.run([y, y_rev, grads_rev, grads]) self.assertAllClose(y_val, yd_val) for g1, g2 in zip(gd_val, g_val): - self.assertAllClose(g1, g2) + self.assertAllClose(g1, g2, rtol=1e-5) def testRevBlock(self): self._testRevBlock() @@ -255,25 +255,54 @@ class RecomputeTest(test.TestCase): def fn_recompute(x): return fn(x) + @rev_block_lib.recompute_grad(use_data_dep=True) + def fn_use_data_dep(x): + return fn(x) + + @rev_block_lib.recompute_grad(tupleize_grads=True) + def fn_tupleize(x): + return fn(x) + + @rev_block_lib.recompute_grad(use_data_dep=True, tupleize_grads=True) + def fn_both(x): + return fn(x) + x = random_ops.random_uniform((3, 1, 3)) - recompute_vars = None - with variable_scope.variable_scope("recompute") as vs: - out1 = math_ops.reduce_sum(fn_recompute(x)) - recompute_vars = vs.trainable_variables() - reg_vars = None - with variable_scope.variable_scope("regular") as vs: - out2 = math_ops.reduce_sum(fn(x)) - reg_vars = vs.trainable_variables() - - grad1 = gradients_impl.gradients(out1, recompute_vars) - grad2 = gradients_impl.gradients(out2, reg_vars) + + names_and_fns = [ + ("recompute", fn_recompute), + ("regular", fn), + ("use_data_dep", fn_use_data_dep), + ("tupleize", fn_tupleize), + ("tuple_and_data_dep", fn_both), + ] + outputs_and_vars = [] + for name, wrapped_fn in names_and_fns: + with variable_scope.variable_scope(name) as vs: + out = math_ops.reduce_sum(wrapped_fn(x)) + outputs_and_vars.append((out, vs.trainable_variables())) + + all_grads = [] + for out, scope_vars in outputs_and_vars: + all_grads.append(gradients_impl.gradients(out, scope_vars)) with self.test_session() as sess: sess.run(variables.global_variables_initializer()) - outs = sess.run([out1, out2, grad1, grad2]) - self.assertAllClose(outs[0], outs[1]) - for g1, g2 in zip(outs[2], outs[3]): - self.assertAllClose(g1, g2) + outputs = list(zip(*outputs_and_vars))[0] + outs, all_grads_val = sess.run([outputs, all_grads]) + + # All outputs are the same + current = outs[0] + for out in outs[1:]: + self.assertAllClose(current, out) + current = out + + # All gradients are the same + for grads in zip(all_grads_val): + current = grads[0] + for g in grads[1:]: + self.assertAllClose(current, g) + current = g class FnWithCustomGradTest(test.TestCase): |