aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/slice_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/slice_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/slice_op_test.py25
1 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py
index 051a25080b..6cdc7872f9 100644
--- a/tensorflow/python/kernel_tests/slice_op_test.py
+++ b/tensorflow/python/kernel_tests/slice_op_test.py
@@ -217,6 +217,30 @@ class SliceTest(test.TestCase):
self.assertEqual(expected_val.shape, slice_t.get_shape())
self.assertEqual(expected_val.shape, slice2_t.get_shape())
+ def testRandomHighRank(self):
+ # Random dims of rank 8
+ input_shape = np.random.randint(0, 20, size=8)
+ inp = np.random.rand(*input_shape).astype("f")
+ with self.test_session(use_gpu=True) as sess:
+ a = constant_op.constant(
+ [float(x) for x in inp.ravel(order="C")],
+ shape=input_shape,
+ dtype=dtypes.float32)
+ indices = [0 if x == 0 else np.random.randint(x) for x in input_shape]
+ sizes = [
+ np.random.randint(0, input_shape[i] - indices[i] + 1)
+ for i in range(8)
+ ]
+ slice_t = array_ops.slice(a, indices, sizes)
+ slice_val = sess.run(slice_t)
+
+ expected_val = inp[indices[0]:indices[0] + sizes[0], indices[1]:indices[1] + sizes[
+ 1], indices[2]:indices[2] + sizes[2], indices[3]:indices[3] + sizes[3], indices[
+ 4]:indices[4] + sizes[4], indices[5]:indices[5] + sizes[5], indices[6]:indices[
+ 6] + sizes[6], indices[7]:indices[7] + sizes[7]]
+ self.assertAllEqual(slice_val, expected_val)
+ self.assertEqual(expected_val.shape, slice_t.get_shape())
+
def testPartialShapeInference(self):
z = array_ops.zeros((1, 2, 3))
self.assertAllEqual(z.get_shape().as_list(), [1, 2, 3])
@@ -227,7 +251,6 @@ class SliceTest(test.TestCase):
m2 = array_ops.slice(z, [0, 0, 0], [constant_op.constant(1) + 0, 2, -1])
self.assertAllEqual(m2.get_shape().as_list(), [None, 2, None])
-
def _testGradientSlice(self, input_shape, slice_begin, slice_size):
with self.test_session(use_gpu=True):
num_inputs = np.prod(input_shape)