aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/split_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/split_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/split_op_test.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/split_op_test.py b/tensorflow/python/kernel_tests/split_op_test.py
index 5f8a3f3ab2..8ea2d7ecda 100644
--- a/tensorflow/python/kernel_tests/split_op_test.py
+++ b/tensorflow/python/kernel_tests/split_op_test.py
@@ -128,6 +128,14 @@ class SplitOpTest(test.TestCase):
self.assertAllEqual(result[:, 0:1], inp_grads[0])
self.assertAllEqual(result[:, 1:4], inp_grads[1])
+ def testOutputShape(self):
+ with self.test_session(use_gpu=False):
+ tensor = array_ops.placeholder(dtypes.float32, shape=[None, 12])
+ size_splits = [3, 7, 2]
+ outputs = array_ops.split(tensor, size_splits, 1)
+ for i, output in enumerate(outputs):
+ self.assertEqual(output.get_shape().as_list(), [None, size_splits[i]])
+
def _compare(self, x, dim, num, use_gpu):
np_ans = np.split(x, num, dim)
with self.test_session(use_gpu=use_gpu) as sess: