aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/manip_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/manip_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/manip_ops_test.py16
1 files changed, 8 insertions, 8 deletions
diff --git a/tensorflow/python/kernel_tests/manip_ops_test.py b/tensorflow/python/kernel_tests/manip_ops_test.py
index dc3ea38671..f71857a3cb 100644
--- a/tensorflow/python/kernel_tests/manip_ops_test.py
+++ b/tensorflow/python/kernel_tests/manip_ops_test.py
@@ -42,12 +42,12 @@ class RollTest(test_util.TensorFlowTestCase):
def _testRoll(self, np_input, shift, axis):
expected_roll = np.roll(np_input, shift, axis)
- with self.test_session():
+ with self.cached_session():
roll = manip_ops.roll(np_input, shift, axis)
self.assertAllEqual(roll.eval(), expected_roll)
def _testGradient(self, np_input, shift, axis):
- with self.test_session():
+ with self.cached_session():
inx = constant_op.constant(np_input.tolist())
xs = list(np_input.shape)
y = manip_ops.roll(inx, shift, axis)
@@ -94,7 +94,7 @@ class RollTest(test_util.TensorFlowTestCase):
self._testAll(np.random.randint(-100, 100, (5)).astype(np.int32), 3, -1)
self._testAll(np.random.randint(-100, 100, (4, 4)).astype(np.int32), 3, -2)
# Make sure negative axis should be 0 <= axis + dims < dims
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"is out of range"):
manip_ops.roll(np.random.randint(-100, 100, (4, 4)).astype(np.int32),
@@ -111,7 +111,7 @@ class RollTest(test_util.TensorFlowTestCase):
tensor = array_ops.placeholder(dtype=dtypes.int32)
shift = 1
axis = 0
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"input must be 1-D or higher"):
manip_ops.roll(tensor, shift, axis).eval(feed_dict={tensor: 7})
@@ -127,7 +127,7 @@ class RollTest(test_util.TensorFlowTestCase):
tensor = [[1, 2], [3, 4]]
shift = 1
axis = array_ops.placeholder(dtype=dtypes.int32)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"axis must be a scalar or a 1-D vector"):
manip_ops.roll(tensor, shift, axis).eval(feed_dict={axis: [[0, 1]]})
@@ -143,7 +143,7 @@ class RollTest(test_util.TensorFlowTestCase):
tensor = [[1, 2], [3, 4]]
shift = array_ops.placeholder(dtype=dtypes.int32)
axis = 1
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"shift must be a scalar or a 1-D vector"):
manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [[0, 1]]})
@@ -158,7 +158,7 @@ class RollTest(test_util.TensorFlowTestCase):
tensor = [[1, 2], [3, 4]]
shift = array_ops.placeholder(dtype=dtypes.int32)
axis = [0, 1]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"shift and axis must have the same size"):
manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [1]})
@@ -167,7 +167,7 @@ class RollTest(test_util.TensorFlowTestCase):
tensor = [1, 2]
shift = 1
axis = 1
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"is out of range"):
manip_ops.roll(tensor, shift, axis).eval()