aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/layers/python/layers/rev_block_lib_test.py')
-rw-r--r--tensorflow/contrib/layers/python/layers/rev_block_lib_test.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
index c34b5a8017..2c7463acc0 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
@@ -58,7 +58,7 @@ class RevBlockTest(test.TestCase):
y1, y2 = block.forward(x1, x2)
x1_inv, x2_inv = block.backward(y1, y2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
x1, x2, x1_inv, x2_inv = sess.run([x1, x2, x1_inv, x2_inv])
@@ -81,7 +81,7 @@ class RevBlockTest(test.TestCase):
x1, x2 = block.backward(y1, y2)
y1_inv, y2_inv = block.forward(x1, x2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
y1, y2, y1_inv, y2_inv = sess.run([y1, y2, y1_inv, y2_inv])
@@ -151,7 +151,7 @@ class RevBlockTest(test.TestCase):
grads_rev = gradients_impl.gradients(loss_rev, wrt)
grads = gradients_impl.gradients(loss, wrt)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
y_val, yd_val, gd_val, g_val = sess.run([y, y_rev, grads_rev, grads])
self.assertAllClose(y_val, yd_val)
@@ -286,7 +286,7 @@ class RecomputeTest(test.TestCase):
for out, scope_vars in outputs_and_vars:
all_grads.append(gradients_impl.gradients(out, scope_vars))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
outputs = list(zip(*outputs_and_vars))[0]
outs, all_grads_val = sess.run([outputs, all_grads])
@@ -389,7 +389,7 @@ class RecomputeTest(test.TestCase):
layer_list.append(math_ops.sqrt(concat_n_wrap(*layer_list)))
grads = gradients_impl.gradients(layer_list[-1], layer_list[0])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(grads)
def testErrorOnClosedOverTensor(self):