diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/manip_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/manip_ops_test.py | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/tensorflow/python/kernel_tests/manip_ops_test.py b/tensorflow/python/kernel_tests/manip_ops_test.py index dc3ea38671..f71857a3cb 100644 --- a/tensorflow/python/kernel_tests/manip_ops_test.py +++ b/tensorflow/python/kernel_tests/manip_ops_test.py @@ -42,12 +42,12 @@ class RollTest(test_util.TensorFlowTestCase): def _testRoll(self, np_input, shift, axis): expected_roll = np.roll(np_input, shift, axis) - with self.test_session(): + with self.cached_session(): roll = manip_ops.roll(np_input, shift, axis) self.assertAllEqual(roll.eval(), expected_roll) def _testGradient(self, np_input, shift, axis): - with self.test_session(): + with self.cached_session(): inx = constant_op.constant(np_input.tolist()) xs = list(np_input.shape) y = manip_ops.roll(inx, shift, axis) @@ -94,7 +94,7 @@ class RollTest(test_util.TensorFlowTestCase): self._testAll(np.random.randint(-100, 100, (5)).astype(np.int32), 3, -1) self._testAll(np.random.randint(-100, 100, (4, 4)).astype(np.int32), 3, -2) # Make sure negative axis should be 0 <= axis + dims < dims - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "is out of range"): manip_ops.roll(np.random.randint(-100, 100, (4, 4)).astype(np.int32), @@ -111,7 +111,7 @@ class RollTest(test_util.TensorFlowTestCase): tensor = array_ops.placeholder(dtype=dtypes.int32) shift = 1 axis = 0 - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "input must be 1-D or higher"): manip_ops.roll(tensor, shift, axis).eval(feed_dict={tensor: 7}) @@ -127,7 +127,7 @@ class RollTest(test_util.TensorFlowTestCase): tensor = [[1, 2], [3, 4]] shift = 1 axis = array_ops.placeholder(dtype=dtypes.int32) - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "axis must be a scalar or a 1-D vector"): manip_ops.roll(tensor, shift, axis).eval(feed_dict={axis: [[0, 1]]}) @@ -143,7 +143,7 @@ class RollTest(test_util.TensorFlowTestCase): tensor = [[1, 2], [3, 4]] shift = array_ops.placeholder(dtype=dtypes.int32) axis = 1 - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "shift must be a scalar or a 1-D vector"): manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [[0, 1]]}) @@ -158,7 +158,7 @@ class RollTest(test_util.TensorFlowTestCase): tensor = [[1, 2], [3, 4]] shift = array_ops.placeholder(dtype=dtypes.int32) axis = [0, 1] - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "shift and axis must have the same size"): manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [1]}) @@ -167,7 +167,7 @@ class RollTest(test_util.TensorFlowTestCase): tensor = [1, 2] shift = 1 axis = 1 - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "is out of range"): manip_ops.roll(tensor, shift, axis).eval() |