aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/dynamic_stitch_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/dynamic_stitch_op_test.py21
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
index c4d4ce780b..49b9569e2b 100644
--- a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
@@ -104,6 +104,27 @@ class DynamicStitchTestBase(object):
# Dimension 0 is max(flatten(indices))+1.
self.assertEqual([8, 2], stitched_t.get_shape().as_list())
+ def testZeroSizeTensor(self):
+ with self.test_session(use_gpu=True):
+ indices = [
+ constant_op.constant([0, 4, 7]),
+ constant_op.constant([1, 6]),
+ constant_op.constant([2, 3, 5]),
+ array_ops.zeros([0], dtype=dtypes.int32)
+ ]
+ data = [
+ constant_op.constant([[0, 1], [40, 41], [70, 71]]),
+ constant_op.constant([[10, 11], [60, 61]]),
+ constant_op.constant([[20, 21], [30, 31], [50, 51]]),
+ array_ops.zeros([0, 2], dtype=dtypes.int32)
+ ]
+ stitched_t = self.stitch_op(indices, data)
+ stitched_val = stitched_t.eval()
+ self.assertAllEqual([[0, 1], [10, 11], [20, 21], [30, 31], [40, 41],
+ [50, 51], [60, 61], [70, 71]], stitched_val)
+ # Dimension 0 is max(flatten(indices))+1.
+ self.assertEqual([8, 2], stitched_t.get_shape().as_list())
+
def testHigherRank(self):
with self.test_session(use_gpu=True) as sess:
indices = [