path: root/tensorflow/contrib/eager
author    Xuechen Li <lxuechen@google.com>                  2018-07-27 12:05:43 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-07-27 12:12:34 -0700
commit    78d225ef8a6a32423febc67803fabdff05b378c0 (patch)
tree      fdd5b92a9fca5d58df397dfb6427e943a6b62642 /tensorflow/contrib/eager
parent    f6cc77189d84328425f83325be5ff428835bb680 (diff)
Update backward pass to save memory in graph mode.
PiperOrigin-RevId: 206352708
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/blocks.py  61
1 file changed, 50 insertions(+), 11 deletions(-)
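
For context, the _Residual block in blocks.py implements the RevNet coupling y1 = x1 + f(x2), y2 = x2 + g(y1), which can be inverted during the backward pass (x2 = y2 - g(y1), x1 = y1 - f(x2)) so activations do not have to be stored. Below is a minimal sketch of that coupling, with toy callables standing in for the block's self.f and self.g sub-networks; the helper names are illustrative and not part of blocks.py.

import tensorflow as tf

def reversible_forward(x1, x2, f, g):
  # Forward coupling used by RevNet-style residual blocks.
  y1 = x1 + f(x2)
  y2 = x2 + g(y1)
  return y1, y2

def reversible_inverse(y1, y2, f, g):
  # Reconstruct the inputs from the outputs; nothing needs to be cached.
  x2 = y2 - g(y1)
  x1 = y1 - f(x2)
  return x1, x2

# Toy residual branches standing in for self.f and self.g.
f = lambda t: tf.tanh(t)
g = lambda t: 0.5 * t
x1 = tf.constant([[1.0, 2.0]])
x2 = tf.constant([[3.0, 4.0]])
y1, y2 = reversible_forward(x1, x2, f, g)
r1, r2 = reversible_inverse(y1, y2, f, g)  # r1 == x1, r2 == x2 up to float error
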
diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks.py b/tensorflow/contrib/eager/python/examples/revnet/blocks.py
index 63e86803ef..89712f2c45 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/blocks.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/blocks.py
@@ -99,16 +99,11 @@ class RevBlock(tf.keras.Model):
       block = self.blocks[i]
       if i == 0:
         # First block usually contains downsampling that can't be reversed
-        with tf.GradientTape() as tape:
-          tape.watch(x)
-          y = block(x, training=training)
-        grads_combined = tape.gradient(
-            y, [x] + block.trainable_variables, output_gradients=dy)
-        dy = grads_combined[0]
-        grads_all = grads_combined[1:] + grads_all
+        dy, grads = block.backward_grads_with_downsample(
+            x, y, dy, training=True)
       else:
         y, dy, grads = block.backward_grads(y, dy, training=training)
-        grads_all = grads + grads_all
+      grads_all = grads + grads_all
     return dy, grads_all
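
The removed lines above rely on tf.GradientTape's output_gradients argument, which seeds the backward pass with the gradient arriving from the block downstream (a vector-Jacobian product), so gradients can be chained one block at a time; backward_grads_with_downsample wraps the same pattern. A hedged, self-contained sketch of that idiom, with a plain Dense layer standing in for the RevNet block:

import tensorflow as tf

# A stand-in layer; the real code uses the RevNet block's f/g sub-networks.
block = tf.keras.layers.Dense(4)
x = tf.ones([2, 4])
dy = tf.ones([2, 4])  # gradient flowing in from the block downstream

with tf.GradientTape() as tape:
  tape.watch(x)
  y = block(x)

# output_gradients seeds the backward pass with dy, giving the gradient
# with respect to x and the per-variable gradients for this block only.
grads_combined = tape.gradient(
    y, [x] + block.trainable_variables, output_gradients=dy)
dx = grads_combined[0]       # passed to the previous block as its "dy"
grads = grads_combined[1:]   # accumulated into grads_all
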
@@ -201,16 +196,21 @@ class _Residual(tf.keras.Model):
         gy1, [y1] + self.g.trainable_variables, output_gradients=dy2)
     dg = grads_combined[1:]
     dx1 = dy1 + grads_combined[0]
-    x2 = y2 - gy1
+    # This doesn't affect eager execution, but improves memory efficiency with
+    # graphs
+    with tf.control_dependencies(dg + [dx1]):
+      x2 = y2 - gy1
     with tf.GradientTape() as ftape:
       ftape.watch(x2)
       fx2 = self.f(x2, training=training)
     grads_combined = ftape.gradient(
         fx2, [x2] + self.f.trainable_variables, output_gradients=dx1)
-    dx2 = dy2 + grads_combined[0]
     df = grads_combined[1:]
-    x1 = y1 - fx2
+    dx2 = dy2 + grads_combined[0]
+    # Same behavior as above
+    with tf.control_dependencies(df + [dx2]):
+      x1 = y1 - fx2
     x = tf.concat([x1, x2], axis=self.axis)
     dx = tf.concat([dx1, dx2], axis=self.axis)
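
The added lines in this hunk are the memory optimization named in the commit message. In graph mode the reconstruction x2 = y2 - gy1 has no data dependency on the gradient ops, so the runtime was free to schedule it early and keep extra buffers alive; wrapping it in tf.control_dependencies(dg + [dx1]) forces it to run only after the gradients are computed, letting their memory be reused. Under eager execution ops already run in program order, so behavior is unchanged. A small sketch of the idiom (the helper name is illustrative only):

import tensorflow as tf

def reconstruct_after(y2, gy1, gradient_tensors):
  # In graph mode, the subtraction below cannot be scheduled until every
  # tensor in `gradient_tensors` has been computed; in eager mode the
  # context manager is effectively a no-op.
  with tf.control_dependencies(gradient_tensors):
    x2 = y2 - gy1
  return x2

# Usage mirroring the diff: x2 = reconstruct_after(y2, gy1, dg + [dx1])
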
@@ -218,6 +218,45 @@ class _Residual(tf.keras.Model):
     return x, dx, grads
+  def backward_grads_with_downsample(self, x, y, dy, training=True):
+    """Manually compute backward gradients given input and output grads."""
+    # Splitting this from `backward_grads` for better readability
+    x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis)
+    y1, _ = tf.split(y, num_or_size_splits=2, axis=self.axis)
+    dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=self.axis)
+
+    with tf.GradientTape() as gtape:
+      gtape.watch(y1)
+      gy1 = self.g(y1, training=training)
+    grads_combined = gtape.gradient(
+        gy1, [y1] + self.g.trainable_variables, output_gradients=dy2)
+    dg = grads_combined[1:]
+    dz1 = dy1 + grads_combined[0]
+
+    # dx1 need one more step to backprop through downsample
+    with tf.GradientTape() as x1tape:
+      x1tape.watch(x1)
+      z1 = ops.downsample(x1, self.filters // 2, self.strides, axis=self.axis)
+    dx1 = x1tape.gradient(z1, x1, output_gradients=dz1)
+
+    with tf.GradientTape() as ftape:
+      ftape.watch(x2)
+      fx2 = self.f(x2, training=training)
+    grads_combined = ftape.gradient(
+        fx2, [x2] + self.f.trainable_variables, output_gradients=dz1)
+    dx2, df = grads_combined[0], grads_combined[1:]
+
+    # dx2 need one more step to backprop through downsample
+    with tf.GradientTape() as x2tape:
+      x2tape.watch(x2)
+      z2 = ops.downsample(x2, self.filters // 2, self.strides, axis=self.axis)
+    dx2 += x2tape.gradient(z2, x2, output_gradients=dy2)
+
+    dx = tf.concat([dx1, dx2], axis=self.axis)
+    grads = df + dg
+
+    return dx, grads
+
 # Ideally, the following should be wrapped in `tf.keras.Sequential`, however
 # there are subtle issues with its placeholder insertion policy and batch norm
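
The new backward_grads_with_downsample handles the first block, whose ops.downsample of the input is not invertible, so instead of reconstructing the activation the incoming gradient is pushed through the downsample with one more tape pass. A hedged sketch of that extra chain-rule step, using average pooling as a stand-in for ops.downsample (the helper name is hypothetical):

import tensorflow as tf

def toy_downsample(x):
  # Stand-in for ops.downsample: halve the spatial resolution.
  return tf.nn.avg_pool2d(x, ksize=2, strides=2, padding="SAME")

x1 = tf.ones([1, 8, 8, 3])
dz1 = tf.ones([1, 4, 4, 3])  # upstream gradient w.r.t. the downsampled tensor

with tf.GradientTape() as x1tape:
  x1tape.watch(x1)
  z1 = toy_downsample(x1)

# Chain the upstream gradient through the (non-invertible) downsample,
# mirroring the dx1 computation in backward_grads_with_downsample.
dx1 = x1tape.gradient(z1, x1, output_gradients=dz1)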