author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-04-28 19:47:42 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-04-28 19:50:17 -0700
commit     3a9c513c3f4303e5194474d804367c1f4831e3ee
tree       fcbec0e200d3936fa015a06256b3bfcf7c0ecd82
parent     d07a8d4071b20d10226ea81758c9306ffce21317
Internally rewrite RevBlock to use @custom_gradient
PiperOrigin-RevId: 194679657
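
The change drops the library's private Defun-based `_fn_with_custom_grad` wrapper and expresses RevBlock's memory-efficient backward pass with `tf.custom_gradient` instead. Below is a minimal, illustrative sketch of that decorator's variable-aware pattern in TF 1.x graph mode; `custom_square` and `scale` are made-up names, not part of this commit, and the resource-variable requirement is why the commit also adds `use_resource=True` in `_scope_wrap`.

```python
# Hypothetical sketch of the tf.custom_gradient pattern (TF 1.x, graph mode).
# This is not code from this commit; the names below are illustrative only.
import tensorflow as tf


@tf.custom_gradient
def custom_square(x):
  # Variables created inside the decorated function must be resource variables
  # for custom_gradient to hand them to grad_fn via the `variables` kwarg.
  scale = tf.get_variable(
      "scale", shape=[], initializer=tf.ones_initializer(), use_resource=True)
  y = scale * tf.square(x)

  def grad_fn(dy, variables=None):
    # Return (gradients w.r.t. inputs, gradients w.r.t. `variables`), the same
    # contract RevBlock's _make_efficient_grad_fn now satisfies.
    grad_x = dy * 2.0 * scale * x
    grad_vars = [dy * tf.square(x)] if variables else []
    return grad_x, grad_vars

  return y, grad_fn
```

Unlike the old Defun-based wrapper, this keeps the forward subgraph as plain ops, so it avoids the Defun restrictions the removed docstring calls out (all subgraph ops on one device, no summaries).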
-rw-r--r--  tensorflow/contrib/layers/python/layers/rev_block_lib.py       | 297
-rw-r--r--  tensorflow/contrib/layers/python/layers/rev_block_lib_test.py  |  96
2 files changed, 105 insertions, 288 deletions
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
index 1a439f0a4d..8ed9f446bc 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
@@ -35,7 +35,6 @@ from six.moves import xrange  # pylint: disable=redefined-builtin
 from tensorflow.contrib.framework.python import ops as contrib_framework_ops
 from tensorflow.python.eager import backprop
 from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
 from tensorflow.python.framework import ops as framework_ops
 from tensorflow.python.layers import base
 from tensorflow.python.ops import array_ops
@@ -155,7 +154,7 @@ def _scope_wrap(fn, scope):

   @functools.wraps(fn)
   def wrap(*args, **kwargs):
-    with variable_scope.variable_scope(scope):
+    with variable_scope.variable_scope(scope, use_resource=True):
       return fn(*args, **kwargs)

   return wrap
@@ -230,95 +229,95 @@ class RevBlock(base.Layer):
           "build.")
     self.built = True

-  def _efficient_grad_fn(self, inputs, variables, ys, grad_ys):
-    """Custom gradient fn for a block of reversible residual layers."""
-    # Inputs have passed through an Identity. Recover the original Tensors to
-    # be able to match up side inputs.
-    assert [u"Identity"] == list(set([x.op.type for x in inputs]))
-    inputs = [x.op.inputs[0] for x in inputs]
-    side_inputs = inputs[2:]
-    del inputs
-
-    f_side_idxs = [None] * len(self.f_side_input)
-    g_side_idxs = [None] * len(self.g_side_input)
-    assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input)
-
-    for i, t in enumerate(side_inputs):
-      if t in self.f_side_input:
-        f_side_idxs[self.f_side_input.index(t)] = i
-      elif t in self.g_side_input:
-        g_side_idxs[self.g_side_input.index(t)] = i
-      else:
-        assert False
-
-    f_vars = [[] for _ in range(self.num_layers)]
-    g_vars = [[] for _ in range(self.num_layers)]
-    f_vars_idxs = [[] for _ in range(self.num_layers)]
-    g_vars_idxs = [[] for _ in range(self.num_layers)]
-
-    for i, ref in enumerate(variables):
-      # Use the name to identify the layer number and function (f or g)
-      regex = LAYER_RE.match(ref.name)
-      layer_no = int(regex.group(1))
-      fn_name = regex.group(2)
-      if fn_name == "f":
-        f_vars[layer_no].append(ref)
-        f_vars_idxs[layer_no].append(i)
-      else:
-        assert fn_name == "g"
-        g_vars[layer_no].append(ref)
-        g_vars_idxs[layer_no].append(i)
-
-    f_var_grads = []
-    g_var_grads = []
-    f_side_grads = []
-    g_side_grads = []
-
-    # Reverse variable containers to go backward
-    f_vars.reverse()
-    g_vars.reverse()
-    f = list(self.f)
-    g = list(self.g)
-    f.reverse()
-    g.reverse()
-
-    with variable_scope.variable_scope(self.scope_name, reuse=True):
-      for i in xrange(self.num_layers):
-        ys, grad_ys, f_ret, g_ret = _rev_layer_backward(
-            ys, grad_ys, f[i], g[i], f_vars[i], self.f_side_input, g_vars[i],
-            self.g_side_input)
-
-        grad_f_vars, grad_f_side = f_ret
-        grad_g_vars, grad_g_side = g_ret
-        f_var_grads.append(grad_f_vars)
-        g_var_grads.append(grad_g_vars)
-        f_side_grads.append(grad_f_side)
-        g_side_grads.append(grad_g_side)
-
-    # Accumulate layer gradients for f_side_input and g_side_input
-    acc_f_side_grads = _acc_grads(*f_side_grads)
-    acc_g_side_grads = _acc_grads(*g_side_grads)
-
-    # Use the stored idxs to put gradients in the passed-in order.
-    side_input_grads = [None] * len(side_inputs)
-    variable_grads = [None] * len(variables)
-
-    # Variable gradients were collected in reverse layer order. Reverse to match
-    # idxs.
-    f_var_grads.reverse()
-    g_var_grads.reverse()
-    for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list(
-        zip(g_vars_idxs, g_var_grads)):
-      for i, grad in zip(idxs, grads):
-        variable_grads[i] = grad
-
-    for i, grad in zip(f_side_idxs, acc_f_side_grads):
-      side_input_grads[i] = grad
-    for i, grad in zip(g_side_idxs, acc_g_side_grads):
-      side_input_grads[i] = grad
-
-    grad_x1, grad_x2 = grad_ys
-    return [grad_x1, grad_x2] + side_input_grads, variable_grads
+  def _make_efficient_grad_fn(self, inputs_, ys_):
+    def _efficient_grad_fn(*grad_ys, **kwargs):
+      """Custom gradient fn for a block of reversible residual layers."""
+      inputs = inputs_
+      ys = ys_
+      variables = kwargs["variables"]
+      side_inputs = inputs[2:]
+
+      f_side_idxs = [None] * len(self.f_side_input)
+      g_side_idxs = [None] * len(self.g_side_input)
+      assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input)
+
+      for i, t in enumerate(side_inputs):
+        if t in self.f_side_input:
+          f_side_idxs[self.f_side_input.index(t)] = i
+        elif t in self.g_side_input:
+          g_side_idxs[self.g_side_input.index(t)] = i
+        else:
+          assert False
+
+      f_vars = [[] for _ in range(self.num_layers)]
+      g_vars = [[] for _ in range(self.num_layers)]
+      f_vars_idxs = [[] for _ in range(self.num_layers)]
+      g_vars_idxs = [[] for _ in range(self.num_layers)]
+
+      for i, ref in enumerate(variables):
+        # Use the name to identify the layer number and function (f or g)
+        regex = LAYER_RE.match(ref.name)
+        layer_no = int(regex.group(1))
+        fn_name = regex.group(2)
+        if fn_name == "f":
+          f_vars[layer_no].append(ref)
+          f_vars_idxs[layer_no].append(i)
+        else:
+          assert fn_name == "g"
+          g_vars[layer_no].append(ref)
+          g_vars_idxs[layer_no].append(i)
+
+      f_var_grads = []
+      g_var_grads = []
+      f_side_grads = []
+      g_side_grads = []
+
+      # Reverse variable containers to go backward
+      f_vars.reverse()
+      g_vars.reverse()
+      f = list(self.f)
+      g = list(self.g)
+      f.reverse()
+      g.reverse()
+
+      with variable_scope.variable_scope(self.scope_name, reuse=True):
+        for i in xrange(self.num_layers):
+          ys, grad_ys, f_ret, g_ret = _rev_layer_backward(
+              ys, grad_ys, f[i], g[i], f_vars[i], self.f_side_input, g_vars[i],
+              self.g_side_input)
+
+          grad_f_vars, grad_f_side = f_ret
+          grad_g_vars, grad_g_side = g_ret
+          f_var_grads.append(grad_f_vars)
+          g_var_grads.append(grad_g_vars)
+          f_side_grads.append(grad_f_side)
+          g_side_grads.append(grad_g_side)
+
+      # Accumulate layer gradients for f_side_input and g_side_input
+      acc_f_side_grads = _acc_grads(*f_side_grads)
+      acc_g_side_grads = _acc_grads(*g_side_grads)
+
+      # Use the stored idxs to put gradients in the passed-in order.
+      side_input_grads = [None] * len(side_inputs)
+      variable_grads = [None] * len(variables)
+
+      # Variable gradients were collected in reverse layer order. Reverse to
+      # match idxs.
+      f_var_grads.reverse()
+      g_var_grads.reverse()
+      for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list(
+          zip(g_vars_idxs, g_var_grads)):
+        for i, grad in zip(idxs, grads):
+          variable_grads[i] = grad
+
+      for i, grad in zip(f_side_idxs, acc_f_side_grads):
+        side_input_grads[i] = grad
+      for i, grad in zip(g_side_idxs, acc_g_side_grads):
+        side_input_grads[i] = grad
+
+      grad_x1, grad_x2 = grad_ys
+      return [grad_x1, grad_x2] + side_input_grads, variable_grads
+    return _efficient_grad_fn

   def _forward(self, x1, x2):
     """Run forward through the reversible layers."""
@@ -326,10 +325,6 @@ class RevBlock(base.Layer):
     side_inputs = [self.f_side_input, self.g_side_input]
     flat_side_inputs = nest.flatten(side_inputs)

-    custom_grad_fn = (
-        self._efficient_grad_fn if self._use_efficient_backprop else None)
-
-    @_fn_with_custom_grad(custom_grad_fn)
     def _forward_wrap(x1_, x2_, *flat_side_inputs):
       f_side, g_side = nest.pack_sequence_as(side_inputs, flat_side_inputs)
       return _rev_block_forward(
@@ -342,7 +337,16 @@ class RevBlock(base.Layer):
           g_side_input=g_side,
           gate_outputs=self._use_efficient_backprop)

-    return _forward_wrap(x1, x2, *flat_side_inputs)
+    @custom_gradient.custom_gradient
+    def _forward_with_custom_grad(*args):
+      out = _forward_wrap(*args)  # pylint: disable=no-value-for-parameter
+      grad_fn = self._make_efficient_grad_fn(args, out)
+      return out, grad_fn
+
+    if self._use_efficient_backprop:
+      return _forward_with_custom_grad(x1, x2, *flat_side_inputs)
+    else:
+      return _forward_wrap(x1, x2, *flat_side_inputs)

   def _backward(self, y1, y2):
     """Run backward through the reversible layers."""
@@ -560,107 +564,6 @@ def _underlying_variable_ref(t):
   return None


-def _fn_with_custom_grad(grad_fn, use_global_vars=False):
-  """Decorator to create a subgraph with a custom gradient function.
-
-  The subgraph created by the decorated function is NOT put in a Defun and so
-  does not suffer from the limitations of the Defun (all subgraph ops on the
-  same device, no summaries).
-
-  Args:
-    grad_fn: function with signature
-      (inputs, variables, outputs, output_grads) -> (grad_inputs, grad_vars),
-      all of which are lists of Tensors.
-    use_global_vars: if True, variables will be the global variables created.
-      If False, will be the trainable variables.
-
-  Returns:
-    Decorator for function such that the gradient is defined by grad_fn.
-  """
-
-  def dec(fn):
-
-    @functools.wraps(fn)
-    def wrapped(*args):
-      return _fn_with_custom_grad_internal(
-          fn, args, grad_fn, use_global_vars=use_global_vars)
-
-    return wrapped
-
-  return dec
-
-
-def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False):
-  """Create a subgraph with a custom gradient.
-
-  Args:
-    fn: function that takes inputs as arguments and produces 1 or more Tensors.
-    inputs: list<Tensor>, will be passed as fn(*inputs).
-    grad_fn: function with signature
-      (inputs, vars, outputs, output_grads) -> (grad_inputs, grad_vars),
-      all of which are lists of Tensors.
-    use_global_vars: if True, variables will be the global variables created.
-      If False, will be the trainable variables.
-
-  Returns:
-    fn(*inputs)
-  """
-  vs = variable_scope.get_variable_scope()
-  get_vars_fn = (
-      vs.global_variables if use_global_vars else vs.trainable_variables)
-  len_before_vars = len(get_vars_fn())
-  inputs = [array_ops.identity(x) for x in inputs]
-  outputs = fn(*inputs)
-  train_vars = get_vars_fn()[len_before_vars:]
-
-  if grad_fn is None:
-    return outputs
-
-  if not (isinstance(outputs, tuple) or isinstance(outputs, list)):
-    outputs = [outputs]
-  outputs = list(outputs)
-
-  defun_inputs = [inputs, train_vars, outputs]
-
-  def custom_grad_fn(op, *dys):
-    """Custom grad fn applying grad_fn for identity Defun."""
-    fn_inputs, fn_vars, fn_outputs = nest.pack_sequence_as(
-        defun_inputs, list(op.inputs))
-    fn_vars = [_underlying_variable_ref(v) for v in fn_vars]
-    dys = list(dys)
-    assert len(fn_outputs) == len(outputs)
-    assert len(fn_outputs) == len(dys)
-
-    grad_inputs, grad_vars = grad_fn(fn_inputs, fn_vars, fn_outputs, dys)
-    grad_outputs = [None] * len(fn_outputs)
-    return tuple(grad_inputs + grad_vars + grad_outputs)
-
-  # The Defun takes as input the original inputs, the trainable variables
-  # created in fn, and the outputs. In the forward it passes through the
-  # outputs. In the backwards, it produces gradients for the original inputs
-  # and the trainable variables.
-  in_types = [t.dtype for t in inputs]
-  out_types = [t.dtype for t in outputs]
-  var_types = [t.dtype for t in train_vars]
-
-  # Get a unique name for the Defun
-  with framework_ops.name_scope("identity_custom_grad") as ns:
-    defun_name = ns
-
-  @function.Defun(
-      *(in_types + var_types + out_types),
-      func_name=defun_name,
-      python_grad_func=custom_grad_fn,
-      shape_func=lambda _: [t.get_shape() for t in outputs])
-  def identity(*args):
-    _, _, outs = nest.pack_sequence_as(defun_inputs, args)
-    return tuple([array_ops.identity(t) for t in outs])
-
-  flat_inputs = nest.flatten(defun_inputs)
-  id_out = identity(*flat_inputs)
-  return id_out
-
-
 def _force_data_dependency(first_compute, then_compute):
   """Force all of `then_compute` to depend on all of `first_compute`.
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
index 8107486d7d..997f53b9e1 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
@@ -83,8 +83,8 @@ class RevBlockTest(test.TestCase):
       sess.run(variables.global_variables_initializer())
       y1, y2, y1_inv, y2_inv = sess.run([y1, y2, y1_inv, y2_inv])

-      self.assertAllClose(y1, y1_inv)
-      self.assertAllClose(y2, y2_inv)
+      self.assertAllClose(y1, y1_inv, rtol=1e-5)
+      self.assertAllClose(y2, y2_inv, rtol=1e-5)

   def _testRevBlock(self,
                     x=None,
@@ -179,18 +179,16 @@ class RevBlockTest(test.TestCase):

     self._testRevBlock(f=[f1, f2, f1, f2])

-  # TODO(rsepassi): Recent change to conv seems to have broken this test. Find
-  # out why.
-  def _testConvAndBatchNorm(self):
+  def testConvAndBatchNorm(self):
     x = random_ops.random_uniform(
         [self.BATCH_SIZE, 10, self.CHANNELS], dtype=dtypes.float32)

     def f(x):
       x = convolutional.conv1d(x, self.CHANNELS // 2, 3, padding="same")
-      x = layers.batch_norm(x, is_training=True)
+      x = layers.batch_norm(x, is_training=False)
       x = convolutional.conv1d(x, self.CHANNELS // 2, 3, padding="same")
-      x = layers.batch_norm(x, is_training=True)
+      x = layers.batch_norm(x, is_training=False)
       return x

     self._testRevBlock(x=x, f=f)
@@ -345,89 +343,5 @@ class RecomputeTest(test.TestCase):
       self.assertTrue(grad is not None)


-class FnWithCustomGradTest(test.TestCase):
-
-  def testCorrectness(self):
-
-    w = random_ops.random_uniform([6, 10])
-
-    def fn(a, b, c):
-      return core_layers.dense(
-          a,
-          10,
-          use_bias=False,
-          kernel_initializer=lambda shape, dtype, partition_info: w
-      ) + math_ops.matmul(b, c)
-
-    def grad_fn(inputs, trainable_variables, outputs, grad_outputs):
-      outputs = outputs[0]
-      grad_outputs = grad_outputs[0]
-      grad_inputs = gradients_impl.gradients(
-          outputs, inputs, grad_ys=grad_outputs)
-      grad_vars = gradients_impl.gradients(
-          outputs, trainable_variables, grad_ys=grad_outputs)
-      return grad_inputs, grad_vars
-
-    custom_fn = rev_block_lib._fn_with_custom_grad(grad_fn)(fn)
-
-    a = random_ops.random_uniform([11, 6])
-    b = random_ops.random_uniform([11, 7])
-    c = random_ops.random_uniform([7, 10])
-
-    out = fn(a, b, c)
-    custom_out = custom_fn(a, b, c)
-    self.assertEqual(out.get_shape().as_list(),
-                     custom_out.get_shape().as_list())
-
-    loss = math_ops.reduce_mean(out)
-    custom_loss = math_ops.reduce_mean(custom_out)
-
-    grads = gradients_impl.gradients(
-        loss, [a, b, c] + [variables.trainable_variables()[0]])
-    custom_grads = gradients_impl.gradients(
-        custom_loss, [a, b, c] + [variables.trainable_variables()[1]])
-
-    with self.test_session() as sess:
-      sess.run(variables.global_variables_initializer())
-      out_val, custom_out_val, grads_val, custom_grads_val = sess.run(
-          [out, custom_out, grads, custom_grads])
-      self.assertAllClose(out_val, custom_out_val)
-      for g1, g2 in zip(grads_val, custom_grads_val):
-        self.assertAllClose(g1, g2)
-
-  def testCustomGrad(self):
-
-    def fn(a, b, c):
-      return core_layers.dense(a, 10, use_bias=False) + math_ops.matmul(b, c)
-
-    def grad_fn(inputs, trainable_variables, unused_outputs,
-                unused_grad_outputs):
-      grad_inputs = [
-          array_ops.ones_like(t) * (i + 1.) for i, t in enumerate(inputs)
-      ]
-      grad_vars = [
-          array_ops.ones_like(t) * (i + len(inputs) + 1.)
-          for i, t in enumerate(trainable_variables)
-      ]
-      return grad_inputs, grad_vars
-
-    a = random_ops.random_uniform([11, 6])
-    b = random_ops.random_uniform([11, 7])
-    c = random_ops.random_uniform([7, 10])
-    w = random_ops.random_uniform([6, 10])
-    out = rev_block_lib._fn_with_custom_grad(grad_fn)(fn)(a, b, c)
-    loss = math_ops.reduce_mean(out)
-    grads = gradients_impl.gradients(
-        loss, [a, b, c, variables.trainable_variables()[0]])
-    expected_grads = [
-        array_ops.ones_like(t) * (i + 1.) for i, t in enumerate([a, b, c, w])
-    ]
-    with self.test_session() as sess:
-      sess.run(variables.global_variables_initializer())
-      g_val, eg_val = sess.run([grads, expected_grads])
-      for g1, g2 in zip(g_val, eg_val):
-        self.assertAllClose(g1, g2)
-
-
 if __name__ == "__main__":
   test.main()