aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/in_topk_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/in_topk_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/in_topk_op_test.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/python/kernel_tests/in_topk_op_test.py b/tensorflow/python/kernel_tests/in_topk_op_test.py
index fafeea8ec0..6fdb497bc6 100644
--- a/tensorflow/python/kernel_tests/in_topk_op_test.py
+++ b/tensorflow/python/kernel_tests/in_topk_op_test.py
@@ -30,7 +30,7 @@ class InTopKTest(test.TestCase):
def _validateInTopK(self, predictions, target, k, expected):
np_ans = np.array(expected)
- with self.test_session():
+ with self.cached_session():
precision = nn_ops.in_top_k(predictions, target, k)
out = precision.eval()
self.assertAllClose(np_ans, out)
@@ -65,7 +65,7 @@ class InTopKTest(test.TestCase):
def testBadTarget(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
target = [0, 80000]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"target.*out of range"):
nn_ops.in_top_k(predictions, target, 2).eval()
@@ -75,7 +75,7 @@ class InTopKTest(test.TestCase):
target = [0, 2]
k = constant_op.constant(3)
np_ans = np.array([False, True])
- with self.test_session():
+ with self.cached_session():
precision = nn_ops.in_top_k(predictions, target, k)
out = precision.eval()
self.assertAllClose(np_ans, out)