aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/unstack_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/unstack_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/unstack_op_test.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/python/kernel_tests/unstack_op_test.py b/tensorflow/python/kernel_tests/unstack_op_test.py
index 1ee6e0866a..b373c419b6 100644
--- a/tensorflow/python/kernel_tests/unstack_op_test.py
+++ b/tensorflow/python/kernel_tests/unstack_op_test.py
@@ -99,7 +99,7 @@ class UnstackOpTest(test.TestCase):
self.assertLess(err, 1e-6)
def testInferNum(self):
- with self.test_session():
+ with self.cached_session():
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
x = array_ops.placeholder(np.float32, shape=shape)
cs = array_ops.unstack(x)
@@ -131,13 +131,13 @@ class UnstackOpTest(test.TestCase):
for j in range(-i, i):
expected = np_split_squeeze(a, j)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
actual_unstack = sess.run(array_ops.unstack(a, axis=j))
self.assertAllEqual(expected, actual_unstack)
def testAxis0Default(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = constant_op.constant([[1, 2, 3], [4, 5, 6]], name='a')
unstacked = sess.run(array_ops.unstack(a))
@@ -156,7 +156,7 @@ class UnstackOpTest(test.TestCase):
array_ops.unstack(a, axis=-3)
def testZeroLengthDim(self):
- with self.test_session():
+ with self.cached_session():
x = array_ops.zeros(shape=(0, 1, 2))
y = array_ops.unstack(x, axis=1)[0].eval()
self.assertEqual(y.shape, (0, 2))