diff options
Diffstat (limited to 'tensorflow/contrib/framework/python/ops/sort_ops_test.py')
-rw-r--r-- | tensorflow/contrib/framework/python/ops/sort_ops_test.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/tensorflow/contrib/framework/python/ops/sort_ops_test.py b/tensorflow/contrib/framework/python/ops/sort_ops_test.py index a8fb94b245..791b32cd1e 100644 --- a/tensorflow/contrib/framework/python/ops/sort_ops_test.py +++ b/tensorflow/contrib/framework/python/ops/sort_ops_test.py @@ -48,7 +48,7 @@ class SortTest(test.TestCase): sort_axis = np.random.choice(rank) if negative_axis: sort_axis = -1 - sort_axis - with self.test_session(): + with self.cached_session(): self.assertAllEqual( np.sort(arr, axis=sort_axis), sort_ops.sort(constant_op.constant(arr), axis=sort_axis).eval()) @@ -60,7 +60,7 @@ class SortTest(test.TestCase): shape = [np.random.randint(1, 4) for _ in range(rank)] arr = np.random.random(shape) sort_axis = np.random.choice(rank) - with self.test_session(): + with self.cached_session(): self.assertAllEqual( np.sort(arr, axis=sort_axis), sort_ops.sort(constant_op.constant(arr), axis=sort_axis).eval()) @@ -73,7 +73,7 @@ class SortTest(test.TestCase): scalar = array_ops.zeros(zeros_length_1) sort = sort_ops.sort(scalar) - with self.test_session(): + with self.cached_session(): with self.assertRaises(errors.InvalidArgumentError): sort.eval() @@ -84,7 +84,7 @@ class SortTest(test.TestCase): def testDescending(self): arr = np.random.random((10, 5, 5)) - with self.test_session(): + with self.cached_session(): self.assertAllEqual( np.sort(arr, axis=0)[::-1], sort_ops.sort( @@ -111,7 +111,7 @@ class SortTest(test.TestCase): def testArgsort_1d(self): arr = np.random.random(42) - with self.test_session(): + with self.cached_session(): self.assertAllEqual( np.sort(arr), array_ops.gather(arr, sort_ops.argsort(arr)).eval()) @@ -119,7 +119,7 @@ class SortTest(test.TestCase): def testArgsort(self): arr = np.random.random((5, 6, 7, 8)) for axis in range(4): - with self.test_session(): + with self.cached_session(): self.assertAllEqual( np.argsort(arr, axis=axis), sort_ops.argsort(arr, axis=axis).eval()) |