author      2018-07-27 12:05:43 -0700
committer   2018-07-27 12:12:34 -0700
commit      78d225ef8a6a32423febc67803fabdff05b378c0 (patch)
tree        fdd5b92a9fca5d58df397dfb6427e943a6b62642 /tensorflow/contrib/eager
parent      f6cc77189d84328425f83325be5ff428835bb680 (diff)
Update backward pass to save memory in graph mode.
PiperOrigin-RevId: 206352708
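The saving comes from ordering, not from new math: the reversible backward pass reconstructs each block's input activations from its outputs, and this change sequences that reconstruction after the gradient computation. Below is a minimal sketch of the pattern, using a hypothetical helper `backward_half_step` in place of the real `_Residual.backward_grads` shown in the diff:

import tensorflow as tf

def backward_half_step(y2, gy1, dy1, grads_combined):
  # Hypothetical illustration of the commit's pattern, not the actual
  # RevNet code (see _Residual.backward_grads in the diff below).
  dg = grads_combined[1:]        # gradients w.r.t. g's variables
  dx1 = dy1 + grads_combined[0]  # gradient w.r.t. the first input half
  # A no-op under eager execution; in graph mode it delays the activation
  # reconstruction until after the gradient ops have run, so the two sets
  # of tensors need not be live at the same time.
  with tf.control_dependencies(dg + [dx1]):
    x2 = y2 - gy1
  return x2, dx1, dg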
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r--   tensorflow/contrib/eager/python/examples/revnet/blocks.py   61
1 file changed, 50 insertions, 11 deletions
diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks.py b/tensorflow/contrib/eager/python/examples/revnet/blocks.py
index 63e86803ef..89712f2c45 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/blocks.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/blocks.py
@@ -99,16 +99,11 @@ class RevBlock(tf.keras.Model):
       block = self.blocks[i]
       if i == 0:
         # First block usually contains downsampling that can't be reversed
-        with tf.GradientTape() as tape:
-          tape.watch(x)
-          y = block(x, training=training)
-        grads_combined = tape.gradient(
-            y, [x] + block.trainable_variables, output_gradients=dy)
-        dy = grads_combined[0]
-        grads_all = grads_combined[1:] + grads_all
+        dy, grads = block.backward_grads_with_downsample(
+            x, y, dy, training=True)
       else:
         y, dy, grads = block.backward_grads(y, dy, training=training)
-        grads_all = grads + grads_all
+      grads_all = grads + grads_all
 
     return dy, grads_all
 
@@ -201,16 +196,21 @@ class _Residual(tf.keras.Model):
         gy1, [y1] + self.g.trainable_variables, output_gradients=dy2)
     dg = grads_combined[1:]
     dx1 = dy1 + grads_combined[0]
-    x2 = y2 - gy1
+    # This doesn't affect eager execution, but improves memory efficiency with
+    # graphs
+    with tf.control_dependencies(dg + [dx1]):
+      x2 = y2 - gy1
 
     with tf.GradientTape() as ftape:
       ftape.watch(x2)
       fx2 = self.f(x2, training=training)
     grads_combined = ftape.gradient(
         fx2, [x2] + self.f.trainable_variables, output_gradients=dx1)
-    dx2 = dy2 + grads_combined[0]
     df = grads_combined[1:]
-    x1 = y1 - fx2
+    dx2 = dy2 + grads_combined[0]
+    # Same behavior as above
+    with tf.control_dependencies(df + [dx2]):
+      x1 = y1 - fx2
 
     x = tf.concat([x1, x2], axis=self.axis)
     dx = tf.concat([dx1, dx2], axis=self.axis)
@@ -218,6 +218,45 @@
 
     return x, dx, grads
 
+  def backward_grads_with_downsample(self, x, y, dy, training=True):
+    """Manually compute backward gradients given input and output grads."""
+    # Splitting this from `backward_grads` for better readability
+    x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis)
+    y1, _ = tf.split(y, num_or_size_splits=2, axis=self.axis)
+    dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=self.axis)
+
+    with tf.GradientTape() as gtape:
+      gtape.watch(y1)
+      gy1 = self.g(y1, training=training)
+    grads_combined = gtape.gradient(
+        gy1, [y1] + self.g.trainable_variables, output_gradients=dy2)
+    dg = grads_combined[1:]
+    dz1 = dy1 + grads_combined[0]
+
+    # dx1 needs one more step to backprop through downsample
+    with tf.GradientTape() as x1tape:
+      x1tape.watch(x1)
+      z1 = ops.downsample(x1, self.filters // 2, self.strides, axis=self.axis)
+    dx1 = x1tape.gradient(z1, x1, output_gradients=dz1)
+
+    with tf.GradientTape() as ftape:
+      ftape.watch(x2)
+      fx2 = self.f(x2, training=training)
+    grads_combined = ftape.gradient(
+        fx2, [x2] + self.f.trainable_variables, output_gradients=dz1)
+    dx2, df = grads_combined[0], grads_combined[1:]
+
+    # dx2 needs one more step to backprop through downsample
+    with tf.GradientTape() as x2tape:
+      x2tape.watch(x2)
+      z2 = ops.downsample(x2, self.filters // 2, self.strides, axis=self.axis)
+    dx2 += x2tape.gradient(z2, x2, output_gradients=dy2)
+
+    dx = tf.concat([dx1, dx2], axis=self.axis)
+    grads = df + dg
+
+    return dx, grads
+
 
 # Ideally, the following should be wrapped in `tf.keras.Sequential`, however
 # there are subtle issues with its placeholder insertion policy and batch norm
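The new `backward_grads_with_downsample` also shows how to backpropagate through the non-invertible downsampling: gradients are pushed backward one stage at a time by feeding each tape's result into the next tape's `output_gradients`. A self-contained sketch of that chaining, with hypothetical `stage1`/`stage2` callables standing in for `ops.downsample` and the residual branch:

import tensorflow as tf

def backward_through_stages(x, stage1, stage2, dy):
  # Record stage1 so we can later backprop from its output into x.
  with tf.GradientTape() as tape1:
    tape1.watch(x)
    z = stage1(x)
  # Record stage2 separately, starting from z.
  with tf.GradientTape() as tape2:
    tape2.watch(z)
    y = stage2(z)
  # Vector-Jacobian product through stage2: the gradient arriving at z ...
  dz = tape2.gradient(y, z, output_gradients=dy)
  # ... chained through stage1 to reach x.
  return tape1.gradient(z, x, output_gradients=dz)

Each `gradient(..., output_gradients=...)` call is a vector-Jacobian product, so chaining them is equivalent to backpropagating through the composed function without taping the forward pass end to end.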