diff options
author | Xuechen Li <lxuechen@google.com> | 2018-08-08 16:45:10 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-08 16:49:17 -0700 |
commit | aacb29a4ab88f9fa27c3301977e7f2cc289a3976 (patch) | |
tree | 9952ea2f753e187fa524c188f8de1f2aa83e5768 /tensorflow/contrib/eager | |
parent | e6921fdc23d020fd24781c8757b97e2877ea491e (diff) |
Add a unit test from the blog post code demonstration.
PiperOrigin-RevId: 207968029
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r-- | tensorflow/contrib/eager/python/examples/revnet/blocks_test.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py index fda9020ddf..9ff6b605b9 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py +++ b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py @@ -188,6 +188,40 @@ class RevBlockTest(tf.test.TestCase): self._check_grad_angle(dx_true, dx) self._check_grad_angle(dw_true, dw) + def test_backward_grads_with_nativepy(self): + if not tf.test.is_gpu_available(): + self.skipTest("GPU not available") + + input_shape = (128, 8, 8) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=1) + block = blocks.RevBlock( + n_res=3, + filters=128, + strides=(1, 1), + input_shape=input_shape, + fused=False, + dtype=tf.float64) + with tf.GradientTape() as tape: + tape.watch(x) + x1, x2 = tf.split(x, num_or_size_splits=2, axis=1) + y1, y2 = block((x1, x2), training=True) + y = tf.concat((y1, y2), axis=1) + + # Compute true grads + dx_true = tape.gradient(y, x, output_gradients=dy) + + # Compute grads from reconstruction + (dx1, dx2), _ = block.backward_grads( + x=(x1, x2), y=(y1, y2), dy=(dy1, dy2), training=True) + dx = tf.concat((dx1, dx2), axis=1) + + thres = 1e-5 + diff_abs = tf.reshape(abs(dx - dx_true), [-1]) + assert all(diff_abs < thres) + class _ResidualTest(tf.test.TestCase): |