aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/framework/python/ops/sort_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/framework/python/ops/sort_ops_test.py')
-rw-r--r--tensorflow/contrib/framework/python/ops/sort_ops_test.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/tensorflow/contrib/framework/python/ops/sort_ops_test.py b/tensorflow/contrib/framework/python/ops/sort_ops_test.py
index a8fb94b245..791b32cd1e 100644
--- a/tensorflow/contrib/framework/python/ops/sort_ops_test.py
+++ b/tensorflow/contrib/framework/python/ops/sort_ops_test.py
@@ -48,7 +48,7 @@ class SortTest(test.TestCase):
sort_axis = np.random.choice(rank)
if negative_axis:
sort_axis = -1 - sort_axis
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
np.sort(arr, axis=sort_axis),
sort_ops.sort(constant_op.constant(arr), axis=sort_axis).eval())
@@ -60,7 +60,7 @@ class SortTest(test.TestCase):
shape = [np.random.randint(1, 4) for _ in range(rank)]
arr = np.random.random(shape)
sort_axis = np.random.choice(rank)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
np.sort(arr, axis=sort_axis),
sort_ops.sort(constant_op.constant(arr), axis=sort_axis).eval())
@@ -73,7 +73,7 @@ class SortTest(test.TestCase):
scalar = array_ops.zeros(zeros_length_1)
sort = sort_ops.sort(scalar)
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(errors.InvalidArgumentError):
sort.eval()
@@ -84,7 +84,7 @@ class SortTest(test.TestCase):
def testDescending(self):
arr = np.random.random((10, 5, 5))
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
np.sort(arr, axis=0)[::-1],
sort_ops.sort(
@@ -111,7 +111,7 @@ class SortTest(test.TestCase):
def testArgsort_1d(self):
arr = np.random.random(42)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
np.sort(arr),
array_ops.gather(arr, sort_ops.argsort(arr)).eval())
@@ -119,7 +119,7 @@ class SortTest(test.TestCase):
def testArgsort(self):
arr = np.random.random((5, 6, 7, 8))
for axis in range(4):
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
np.argsort(arr, axis=axis),
sort_ops.argsort(arr, axis=axis).eval())