aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/layers/python/layers/rev_block_lib_test.py')
-rw-r--r--tensorflow/contrib/layers/python/layers/rev_block_lib_test.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
index bc09ba8d43..d5971fb9d8 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
@@ -372,6 +372,26 @@ class RecomputeTest(test.TestCase):
self.assertEqual(2, len(update_ops))
self.assertEqual([False, True], kwarg_values)
+ def testWithoutVariables(self):
+
+ def concat_n(layer_list, num_inputs):
+ return math_ops.reduce_sum(
+ array_ops.concat([x for x in layer_list[-num_inputs:]], axis=-1),
+ axis=1, keepdims=True)
+
+ @rev_block_lib.recompute_grad
+ def concat_n_wrap(*args):
+ return concat_n(args, 3)
+
+ # DenseNet-style layers
+ layer_list = [random_ops.random_uniform((4, 8))]
+ for _ in range(5):
+ layer_list.append(math_ops.sqrt(concat_n_wrap(*layer_list)))
+
+ grads = gradients_impl.gradients(layer_list[-1], layer_list[0])
+ with self.test_session() as sess:
+ sess.run(grads)
+
if __name__ == "__main__":
test.main()