diff options
Diffstat (limited to 'tensorflow/contrib/layers/python/layers/rev_block_lib_test.py')
-rw-r--r-- | tensorflow/contrib/layers/python/layers/rev_block_lib_test.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py index c34b5a8017..2c7463acc0 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py @@ -58,7 +58,7 @@ class RevBlockTest(test.TestCase): y1, y2 = block.forward(x1, x2) x1_inv, x2_inv = block.backward(y1, y2) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) x1, x2, x1_inv, x2_inv = sess.run([x1, x2, x1_inv, x2_inv]) @@ -81,7 +81,7 @@ class RevBlockTest(test.TestCase): x1, x2 = block.backward(y1, y2) y1_inv, y2_inv = block.forward(x1, x2) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) y1, y2, y1_inv, y2_inv = sess.run([y1, y2, y1_inv, y2_inv]) @@ -151,7 +151,7 @@ class RevBlockTest(test.TestCase): grads_rev = gradients_impl.gradients(loss_rev, wrt) grads = gradients_impl.gradients(loss, wrt) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) y_val, yd_val, gd_val, g_val = sess.run([y, y_rev, grads_rev, grads]) self.assertAllClose(y_val, yd_val) @@ -286,7 +286,7 @@ class RecomputeTest(test.TestCase): for out, scope_vars in outputs_and_vars: all_grads.append(gradients_impl.gradients(out, scope_vars)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) outputs = list(zip(*outputs_and_vars))[0] outs, all_grads_val = sess.run([outputs, all_grads]) @@ -389,7 +389,7 @@ class RecomputeTest(test.TestCase): layer_list.append(math_ops.sqrt(concat_n_wrap(*layer_list))) grads = gradients_impl.gradients(layer_list[-1], layer_list[0]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(grads) def testErrorOnClosedOverTensor(self): |