aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/manip_ops_test.py
diff options
context:
space:
mode:
authorGravatar Yifei Feng <yifeif@google.com>2018-04-23 21:19:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-23 21:21:38 -0700
commit22f3a97b8b089202f60bb0c7697feb0c8e0713cc (patch)
treed16f95826e4be15bbb3b0f22bed0ca25d3eb5897 /tensorflow/python/kernel_tests/manip_ops_test.py
parent24b7c9a800ab5086d45a7d83ebcd6218424dc9e3 (diff)
Merge changes from github.
PiperOrigin-RevId: 194031845
Diffstat (limited to 'tensorflow/python/kernel_tests/manip_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/manip_ops_test.py55
1 files changed, 47 insertions, 8 deletions
diff --git a/tensorflow/python/kernel_tests/manip_ops_test.py b/tensorflow/python/kernel_tests/manip_ops_test.py
index b8200ac0cb..f31426713c 100644
--- a/tensorflow/python/kernel_tests/manip_ops_test.py
+++ b/tensorflow/python/kernel_tests/manip_ops_test.py
@@ -20,8 +20,10 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import manip_ops
from tensorflow.python.platform import test as test_lib
@@ -88,41 +90,78 @@ class RollTest(test_util.TensorFlowTestCase):
x = np.random.rand(3, 2, 1, 1).astype(t)
self._testAll(x + 1j * x, [2, 1, 1, 0], [0, 3, 1, 2])
+ def testNegativeAxis(self):
+ self._testAll(np.random.randint(-100, 100, (5)).astype(np.int32), 3, -1)
+ self._testAll(np.random.randint(-100, 100, (4, 4)).astype(np.int32), 3, -2)
+ # Make sure negative axis shoudl be 0 <= axis + dims < dims
+ with self.test_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "is out of range"):
+ manip_ops.roll(np.random.randint(-100, 100, (4, 4)).astype(np.int32),
+ 3, -10).eval()
+
+ def testInvalidInputShape(self):
+ # The input should be 1-D or higher, checked in shape function.
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be at least rank 1 but is rank 0"):
+ manip_ops.roll(7, 1, 0)
+
def testRollInputMustVectorHigherRaises(self):
- tensor = 7
+ # The input should be 1-D or higher, checked in kernel.
+ tensor = array_ops.placeholder(dtype=dtypes.int32)
shift = 1
axis = 0
with self.test_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"input must be 1-D or higher"):
- manip_ops.roll(tensor, shift, axis).eval()
+ manip_ops.roll(tensor, shift, axis).eval(feed_dict={tensor: 7})
+
+ def testInvalidAxisShape(self):
+ # The axis should be a scalar or 1-D, checked in shape function.
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be at most rank 1 but is rank 2"):
+ manip_ops.roll([[1, 2], [3, 4]], 1, [[0, 1]])
def testRollAxisMustBeScalarOrVectorRaises(self):
+ # The axis should be a scalar or 1-D, checked in kernel.
tensor = [[1, 2], [3, 4]]
shift = 1
- axis = [[0, 1]]
+ axis = array_ops.placeholder(dtype=dtypes.int32)
with self.test_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"axis must be a scalar or a 1-D vector"):
- manip_ops.roll(tensor, shift, axis).eval()
+ manip_ops.roll(tensor, shift, axis).eval(feed_dict={axis: [[0, 1]]})
+
+ def testInvalidShiftShape(self):
+ # The shift should be a scalar or 1-D, checked in shape function.
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be at most rank 1 but is rank 2"):
+ manip_ops.roll([[1, 2], [3, 4]], [[0, 1]], 1)
def testRollShiftMustBeScalarOrVectorRaises(self):
+ # The shift should be a scalar or 1-D, checked in kernel.
tensor = [[1, 2], [3, 4]]
- shift = [[0, 1]]
+ shift = array_ops.placeholder(dtype=dtypes.int32)
axis = 1
with self.test_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"shift must be a scalar or a 1-D vector"):
- manip_ops.roll(tensor, shift, axis).eval()
+ manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [[0, 1]]})
+
+ def testInvalidShiftAndAxisNotEqualShape(self):
+ # The shift and axis must be same size, checked in shape function.
+ with self.assertRaisesRegexp(ValueError, "both shapes must be equal"):
+ manip_ops.roll([[1, 2], [3, 4]], [1], [0, 1])
def testRollShiftAndAxisMustBeSameSizeRaises(self):
+ # The shift and axis must be same size, checked in kernel.
tensor = [[1, 2], [3, 4]]
- shift = [1]
+ shift = array_ops.placeholder(dtype=dtypes.int32)
axis = [0, 1]
with self.test_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"shift and axis must have the same size"):
- manip_ops.roll(tensor, shift, axis).eval()
+ manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [1]})
def testRollAxisOutOfRangeRaises(self):
tensor = [1, 2]