aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager
diff options
context:
space:
mode:
authorGravatar Xuechen Li <lxuechen@google.com>2018-08-08 16:45:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-08 16:49:17 -0700
commitaacb29a4ab88f9fa27c3301977e7f2cc289a3976 (patch)
tree9952ea2f753e187fa524c188f8de1f2aa83e5768 /tensorflow/contrib/eager
parente6921fdc23d020fd24781c8757b97e2877ea491e (diff)
Add a unit test from the blog post code demonstration.
PiperOrigin-RevId: 207968029
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/blocks_test.py34
1 files changed, 34 insertions, 0 deletions
diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py
index fda9020ddf..9ff6b605b9 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py
@@ -188,6 +188,40 @@ class RevBlockTest(tf.test.TestCase):
self._check_grad_angle(dx_true, dx)
self._check_grad_angle(dw_true, dw)
+ def test_backward_grads_with_nativepy(self):
+ if not tf.test.is_gpu_available():
+ self.skipTest("GPU not available")
+
+ input_shape = (128, 8, 8)
+ data_shape = (16,) + input_shape
+ x = tf.random_normal(shape=data_shape, dtype=tf.float64)
+ dy = tf.random_normal(shape=data_shape, dtype=tf.float64)
+ dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=1)
+ block = blocks.RevBlock(
+ n_res=3,
+ filters=128,
+ strides=(1, 1),
+ input_shape=input_shape,
+ fused=False,
+ dtype=tf.float64)
+ with tf.GradientTape() as tape:
+ tape.watch(x)
+ x1, x2 = tf.split(x, num_or_size_splits=2, axis=1)
+ y1, y2 = block((x1, x2), training=True)
+ y = tf.concat((y1, y2), axis=1)
+
+ # Compute true grads
+ dx_true = tape.gradient(y, x, output_gradients=dy)
+
+ # Compute grads from reconstruction
+ (dx1, dx2), _ = block.backward_grads(
+ x=(x1, x2), y=(y1, y2), dy=(dy1, dy2), training=True)
+ dx = tf.concat((dx1, dx2), axis=1)
+
+ thres = 1e-5
+ diff_abs = tf.reshape(abs(dx - dx_true), [-1])
+ assert all(diff_abs < thres)
+
class _ResidualTest(tf.test.TestCase):