author     A. Unique TensorFlower <gardener@tensorflow.org>    2018-04-28 19:47:42 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>     2018-04-28 19:50:17 -0700
commit     3a9c513c3f4303e5194474d804367c1f4831e3ee (patch)
tree       fcbec0e200d3936fa015a06256b3bfcf7c0ecd82
parent     d07a8d4071b20d10226ea81758c9306ffce21317 (diff)
Internally rewrite RevBlock to use @custom_gradient
PiperOrigin-RevId: 194679657
-rw-r--r--  tensorflow/contrib/layers/python/layers/rev_block_lib.py       297
-rw-r--r--  tensorflow/contrib/layers/python/layers/rev_block_lib_test.py   96
2 files changed, 105 insertions, 288 deletions
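
For orientation: the rewrite below drops the Defun-based _fn_with_custom_grad helper in favor of the @custom_gradient decorator. With that decorator, the forward function returns its outputs together with a grad_fn; when the wrapped computation captures variables, grad_fn additionally receives them through the `variables` keyword and returns (input_grads, variable_grads), which is the hook _make_efficient_grad_fn uses via kwargs["variables"] in the diff. A minimal illustrative sketch of that pattern follows (not code from this commit; log1pexp is just an example function, and it assumes a recent TensorFlow release exposing tf.custom_gradient):

import tensorflow as tf

@tf.custom_gradient
def log1pexp(x):
  """Numerically stable log(1 + exp(x)) with a hand-written gradient."""
  e = tf.exp(x)

  def grad_fn(dy):
    # Gradient with respect to the input x. If the wrapped function used
    # tf.Variables, the signature would be grad_fn(*dys, variables=None) and
    # it would return (input_grads, variable_grads) instead.
    return dy * (1.0 - 1.0 / (1.0 + e))

  return tf.math.log(1.0 + e), grad_fn

Differentiating log1pexp's output (via tf.gradients or GradientTape.gradient) then routes through grad_fn rather than the autodiff graph, which is the same mechanism RevBlock uses to substitute its memory-efficient backward pass.
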
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
index 1a439f0a4d..8ed9f446bc 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
@@ -35,7 +35,6 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.framework.python import ops as contrib_framework_ops
from tensorflow.python.eager import backprop
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
from tensorflow.python.framework import ops as framework_ops
from tensorflow.python.layers import base
from tensorflow.python.ops import array_ops
@@ -155,7 +154,7 @@ def _scope_wrap(fn, scope):
@functools.wraps(fn)
def wrap(*args, **kwargs):
- with variable_scope.variable_scope(scope):
+ with variable_scope.variable_scope(scope, use_resource=True):
return fn(*args, **kwargs)
return wrap
@@ -230,95 +229,95 @@ class RevBlock(base.Layer):
"build.")
self.built = True
- def _efficient_grad_fn(self, inputs, variables, ys, grad_ys):
- """Custom gradient fn for a block of reversible residual layers."""
- # Inputs have passed through an Identity. Recover the original Tensors to
- # be able to match up side inputs.
- assert [u"Identity"] == list(set([x.op.type for x in inputs]))
- inputs = [x.op.inputs[0] for x in inputs]
- side_inputs = inputs[2:]
- del inputs
-
- f_side_idxs = [None] * len(self.f_side_input)
- g_side_idxs = [None] * len(self.g_side_input)
- assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input)
-
- for i, t in enumerate(side_inputs):
- if t in self.f_side_input:
- f_side_idxs[self.f_side_input.index(t)] = i
- elif t in self.g_side_input:
- g_side_idxs[self.g_side_input.index(t)] = i
- else:
- assert False
-
- f_vars = [[] for _ in range(self.num_layers)]
- g_vars = [[] for _ in range(self.num_layers)]
- f_vars_idxs = [[] for _ in range(self.num_layers)]
- g_vars_idxs = [[] for _ in range(self.num_layers)]
-
- for i, ref in enumerate(variables):
- # Use the name to identify the layer number and function (f or g)
- regex = LAYER_RE.match(ref.name)
- layer_no = int(regex.group(1))
- fn_name = regex.group(2)
- if fn_name == "f":
- f_vars[layer_no].append(ref)
- f_vars_idxs[layer_no].append(i)
- else:
- assert fn_name == "g"
- g_vars[layer_no].append(ref)
- g_vars_idxs[layer_no].append(i)
-
- f_var_grads = []
- g_var_grads = []
- f_side_grads = []
- g_side_grads = []
-
- # Reverse variable containers to go backward
- f_vars.reverse()
- g_vars.reverse()
- f = list(self.f)
- g = list(self.g)
- f.reverse()
- g.reverse()
-
- with variable_scope.variable_scope(self.scope_name, reuse=True):
- for i in xrange(self.num_layers):
- ys, grad_ys, f_ret, g_ret = _rev_layer_backward(
- ys, grad_ys, f[i], g[i], f_vars[i], self.f_side_input, g_vars[i],
- self.g_side_input)
-
- grad_f_vars, grad_f_side = f_ret
- grad_g_vars, grad_g_side = g_ret
- f_var_grads.append(grad_f_vars)
- g_var_grads.append(grad_g_vars)
- f_side_grads.append(grad_f_side)
- g_side_grads.append(grad_g_side)
-
- # Accumulate layer gradients for f_side_input and g_side_input
- acc_f_side_grads = _acc_grads(*f_side_grads)
- acc_g_side_grads = _acc_grads(*g_side_grads)
-
- # Use the stored idxs to put gradients in the passed-in order.
- side_input_grads = [None] * len(side_inputs)
- variable_grads = [None] * len(variables)
-
- # Variable gradients were collected in reverse layer order. Reverse to match
- # idxs.
- f_var_grads.reverse()
- g_var_grads.reverse()
- for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list(
- zip(g_vars_idxs, g_var_grads)):
- for i, grad in zip(idxs, grads):
- variable_grads[i] = grad
-
- for i, grad in zip(f_side_idxs, acc_f_side_grads):
- side_input_grads[i] = grad
- for i, grad in zip(g_side_idxs, acc_g_side_grads):
- side_input_grads[i] = grad
-
- grad_x1, grad_x2 = grad_ys
- return [grad_x1, grad_x2] + side_input_grads, variable_grads
+ def _make_efficient_grad_fn(self, inputs_, ys_):
+ def _efficient_grad_fn(*grad_ys, **kwargs):
+ """Custom gradient fn for a block of reversible residual layers."""
+ inputs = inputs_
+ ys = ys_
+ variables = kwargs["variables"]
+ side_inputs = inputs[2:]
+
+ f_side_idxs = [None] * len(self.f_side_input)
+ g_side_idxs = [None] * len(self.g_side_input)
+ assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input)
+
+ for i, t in enumerate(side_inputs):
+ if t in self.f_side_input:
+ f_side_idxs[self.f_side_input.index(t)] = i
+ elif t in self.g_side_input:
+ g_side_idxs[self.g_side_input.index(t)] = i
+ else:
+ assert False
+
+ f_vars = [[] for _ in range(self.num_layers)]
+ g_vars = [[] for _ in range(self.num_layers)]
+ f_vars_idxs = [[] for _ in range(self.num_layers)]
+ g_vars_idxs = [[] for _ in range(self.num_layers)]
+
+ for i, ref in enumerate(variables):
+ # Use the name to identify the layer number and function (f or g)
+ regex = LAYER_RE.match(ref.name)
+ layer_no = int(regex.group(1))
+ fn_name = regex.group(2)
+ if fn_name == "f":
+ f_vars[layer_no].append(ref)
+ f_vars_idxs[layer_no].append(i)
+ else:
+ assert fn_name == "g"
+ g_vars[layer_no].append(ref)
+ g_vars_idxs[layer_no].append(i)
+
+ f_var_grads = []
+ g_var_grads = []
+ f_side_grads = []
+ g_side_grads = []
+
+ # Reverse variable containers to go backward
+ f_vars.reverse()
+ g_vars.reverse()
+ f = list(self.f)
+ g = list(self.g)
+ f.reverse()
+ g.reverse()
+
+ with variable_scope.variable_scope(self.scope_name, reuse=True):
+ for i in xrange(self.num_layers):
+ ys, grad_ys, f_ret, g_ret = _rev_layer_backward(
+ ys, grad_ys, f[i], g[i], f_vars[i], self.f_side_input, g_vars[i],
+ self.g_side_input)
+
+ grad_f_vars, grad_f_side = f_ret
+ grad_g_vars, grad_g_side = g_ret
+ f_var_grads.append(grad_f_vars)
+ g_var_grads.append(grad_g_vars)
+ f_side_grads.append(grad_f_side)
+ g_side_grads.append(grad_g_side)
+
+ # Accumulate layer gradients for f_side_input and g_side_input
+ acc_f_side_grads = _acc_grads(*f_side_grads)
+ acc_g_side_grads = _acc_grads(*g_side_grads)
+
+ # Use the stored idxs to put gradients in the passed-in order.
+ side_input_grads = [None] * len(side_inputs)
+ variable_grads = [None] * len(variables)
+
+ # Variable gradients were collected in reverse layer order. Reverse to
+ # match idxs.
+ f_var_grads.reverse()
+ g_var_grads.reverse()
+ for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list(
+ zip(g_vars_idxs, g_var_grads)):
+ for i, grad in zip(idxs, grads):
+ variable_grads[i] = grad
+
+ for i, grad in zip(f_side_idxs, acc_f_side_grads):
+ side_input_grads[i] = grad
+ for i, grad in zip(g_side_idxs, acc_g_side_grads):
+ side_input_grads[i] = grad
+
+ grad_x1, grad_x2 = grad_ys
+ return [grad_x1, grad_x2] + side_input_grads, variable_grads
+ return _efficient_grad_fn
def _forward(self, x1, x2):
"""Run forward through the reversible layers."""
@@ -326,10 +325,6 @@ class RevBlock(base.Layer):
side_inputs = [self.f_side_input, self.g_side_input]
flat_side_inputs = nest.flatten(side_inputs)
- custom_grad_fn = (
- self._efficient_grad_fn if self._use_efficient_backprop else None)
-
- @_fn_with_custom_grad(custom_grad_fn)
def _forward_wrap(x1_, x2_, *flat_side_inputs):
f_side, g_side = nest.pack_sequence_as(side_inputs, flat_side_inputs)
return _rev_block_forward(
@@ -342,7 +337,16 @@ class RevBlock(base.Layer):
g_side_input=g_side,
gate_outputs=self._use_efficient_backprop)
- return _forward_wrap(x1, x2, *flat_side_inputs)
+ @custom_gradient.custom_gradient
+ def _forward_with_custom_grad(*args):
+ out = _forward_wrap(*args) # pylint: disable=no-value-for-parameter
+ grad_fn = self._make_efficient_grad_fn(args, out)
+ return out, grad_fn
+
+ if self._use_efficient_backprop:
+ return _forward_with_custom_grad(x1, x2, *flat_side_inputs)
+ else:
+ return _forward_wrap(x1, x2, *flat_side_inputs)
def _backward(self, y1, y2):
"""Run backward through the reversible layers."""
@@ -560,107 +564,6 @@ def _underlying_variable_ref(t):
return None
-def _fn_with_custom_grad(grad_fn, use_global_vars=False):
- """Decorator to create a subgraph with a custom gradient function.
-
- The subgraph created by the decorated function is NOT put in a Defun and so
- does not suffer from the limitations of the Defun (all subgraph ops on the
- same device, no summaries).
-
- Args:
- grad_fn: function with signature
- (inputs, variables, outputs, output_grads) -> (grad_inputs, grad_vars),
- all of which are lists of Tensors.
- use_global_vars: if True, variables will be the global variables created.
- If False, will be the trainable variables.
-
- Returns:
- Decorator for function such that the gradient is defined by grad_fn.
- """
-
- def dec(fn):
-
- @functools.wraps(fn)
- def wrapped(*args):
- return _fn_with_custom_grad_internal(
- fn, args, grad_fn, use_global_vars=use_global_vars)
-
- return wrapped
-
- return dec
-
-
-def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False):
- """Create a subgraph with a custom gradient.
-
- Args:
- fn: function that takes inputs as arguments and produces 1 or more Tensors.
- inputs: list<Tensor>, will be passed as fn(*inputs).
- grad_fn: function with signature
- (inputs, vars, outputs, output_grads) -> (grad_inputs, grad_vars),
- all of which are lists of Tensors.
- use_global_vars: if True, variables will be the global variables created.
- If False, will be the trainable variables.
-
- Returns:
- fn(*inputs)
- """
- vs = variable_scope.get_variable_scope()
- get_vars_fn = (
- vs.global_variables if use_global_vars else vs.trainable_variables)
- len_before_vars = len(get_vars_fn())
- inputs = [array_ops.identity(x) for x in inputs]
- outputs = fn(*inputs)
- train_vars = get_vars_fn()[len_before_vars:]
-
- if grad_fn is None:
- return outputs
-
- if not (isinstance(outputs, tuple) or isinstance(outputs, list)):
- outputs = [outputs]
- outputs = list(outputs)
-
- defun_inputs = [inputs, train_vars, outputs]
-
- def custom_grad_fn(op, *dys):
- """Custom grad fn applying grad_fn for identity Defun."""
- fn_inputs, fn_vars, fn_outputs = nest.pack_sequence_as(
- defun_inputs, list(op.inputs))
- fn_vars = [_underlying_variable_ref(v) for v in fn_vars]
- dys = list(dys)
- assert len(fn_outputs) == len(outputs)
- assert len(fn_outputs) == len(dys)
-
- grad_inputs, grad_vars = grad_fn(fn_inputs, fn_vars, fn_outputs, dys)
- grad_outputs = [None] * len(fn_outputs)
- return tuple(grad_inputs + grad_vars + grad_outputs)
-
- # The Defun takes as input the original inputs, the trainable variables
- # created in fn, and the outputs. In the forward it passes through the
- # outputs. In the backwards, it produces gradients for the original inputs
- # and the trainable variables.
- in_types = [t.dtype for t in inputs]
- out_types = [t.dtype for t in outputs]
- var_types = [t.dtype for t in train_vars]
-
- # Get a unique name for the Defun
- with framework_ops.name_scope("identity_custom_grad") as ns:
- defun_name = ns
-
- @function.Defun(
- *(in_types + var_types + out_types),
- func_name=defun_name,
- python_grad_func=custom_grad_fn,
- shape_func=lambda _: [t.get_shape() for t in outputs])
- def identity(*args):
- _, _, outs = nest.pack_sequence_as(defun_inputs, args)
- return tuple([array_ops.identity(t) for t in outs])
-
- flat_inputs = nest.flatten(defun_inputs)
- id_out = identity(*flat_inputs)
- return id_out
-
-
def _force_data_dependency(first_compute, then_compute):
"""Force all of `then_compute` to depend on all of `first_compute`.
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
index 8107486d7d..997f53b9e1 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
@@ -83,8 +83,8 @@ class RevBlockTest(test.TestCase):
sess.run(variables.global_variables_initializer())
y1, y2, y1_inv, y2_inv = sess.run([y1, y2, y1_inv, y2_inv])
- self.assertAllClose(y1, y1_inv)
- self.assertAllClose(y2, y2_inv)
+ self.assertAllClose(y1, y1_inv, rtol=1e-5)
+ self.assertAllClose(y2, y2_inv, rtol=1e-5)
def _testRevBlock(self,
x=None,
@@ -179,18 +179,16 @@ class RevBlockTest(test.TestCase):
self._testRevBlock(f=[f1, f2, f1, f2])
- # TODO(rsepassi): Recent change to conv seems to have broken this test. Find
- # out why.
- def _testConvAndBatchNorm(self):
+ def testConvAndBatchNorm(self):
x = random_ops.random_uniform(
[self.BATCH_SIZE, 10, self.CHANNELS], dtype=dtypes.float32)
def f(x):
x = convolutional.conv1d(x, self.CHANNELS // 2, 3, padding="same")
- x = layers.batch_norm(x, is_training=True)
+ x = layers.batch_norm(x, is_training=False)
x = convolutional.conv1d(x, self.CHANNELS // 2, 3, padding="same")
- x = layers.batch_norm(x, is_training=True)
+ x = layers.batch_norm(x, is_training=False)
return x
self._testRevBlock(x=x, f=f)
@@ -345,89 +343,5 @@ class RecomputeTest(test.TestCase):
self.assertTrue(grad is not None)
-class FnWithCustomGradTest(test.TestCase):
-
- def testCorrectness(self):
-
- w = random_ops.random_uniform([6, 10])
-
- def fn(a, b, c):
- return core_layers.dense(
- a,
- 10,
- use_bias=False,
- kernel_initializer=lambda shape, dtype, partition_info: w
- ) + math_ops.matmul(b, c)
-
- def grad_fn(inputs, trainable_variables, outputs, grad_outputs):
- outputs = outputs[0]
- grad_outputs = grad_outputs[0]
- grad_inputs = gradients_impl.gradients(
- outputs, inputs, grad_ys=grad_outputs)
- grad_vars = gradients_impl.gradients(
- outputs, trainable_variables, grad_ys=grad_outputs)
- return grad_inputs, grad_vars
-
- custom_fn = rev_block_lib._fn_with_custom_grad(grad_fn)(fn)
-
- a = random_ops.random_uniform([11, 6])
- b = random_ops.random_uniform([11, 7])
- c = random_ops.random_uniform([7, 10])
-
- out = fn(a, b, c)
- custom_out = custom_fn(a, b, c)
- self.assertEqual(out.get_shape().as_list(),
- custom_out.get_shape().as_list())
-
- loss = math_ops.reduce_mean(out)
- custom_loss = math_ops.reduce_mean(custom_out)
-
- grads = gradients_impl.gradients(
- loss, [a, b, c] + [variables.trainable_variables()[0]])
- custom_grads = gradients_impl.gradients(
- custom_loss, [a, b, c] + [variables.trainable_variables()[1]])
-
- with self.test_session() as sess:
- sess.run(variables.global_variables_initializer())
- out_val, custom_out_val, grads_val, custom_grads_val = sess.run(
- [out, custom_out, grads, custom_grads])
- self.assertAllClose(out_val, custom_out_val)
- for g1, g2 in zip(grads_val, custom_grads_val):
- self.assertAllClose(g1, g2)
-
- def testCustomGrad(self):
-
- def fn(a, b, c):
- return core_layers.dense(a, 10, use_bias=False) + math_ops.matmul(b, c)
-
- def grad_fn(inputs, trainable_variables, unused_outputs,
- unused_grad_outputs):
- grad_inputs = [
- array_ops.ones_like(t) * (i + 1.) for i, t in enumerate(inputs)
- ]
- grad_vars = [
- array_ops.ones_like(t) * (i + len(inputs) + 1.)
- for i, t in enumerate(trainable_variables)
- ]
- return grad_inputs, grad_vars
-
- a = random_ops.random_uniform([11, 6])
- b = random_ops.random_uniform([11, 7])
- c = random_ops.random_uniform([7, 10])
- w = random_ops.random_uniform([6, 10])
- out = rev_block_lib._fn_with_custom_grad(grad_fn)(fn)(a, b, c)
- loss = math_ops.reduce_mean(out)
- grads = gradients_impl.gradients(
- loss, [a, b, c, variables.trainable_variables()[0]])
- expected_grads = [
- array_ops.ones_like(t) * (i + 1.) for i, t in enumerate([a, b, c, w])
- ]
- with self.test_session() as sess:
- sess.run(variables.global_variables_initializer())
- g_val, eg_val = sess.run([grads, expected_grads])
- for g1, g2 in zip(g_val, eg_val):
- self.assertAllClose(g1, g2)
-
-
if __name__ == "__main__":
test.main()