diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/segment_reduction_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/segment_reduction_ops_test.py | 29 |
1 files changed, 28 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py index 516a9d000e..3a02f24902 100644 --- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py @@ -323,8 +323,9 @@ class UnsortedSegmentSumTest(SegmentReductionHelper): def testBadIndices(self): # Note: GPU kernel does not return the out-of-range error needed for this # test, so this test is marked as cpu-only. + # Note: With PR #13055 a negative index will be ignored silently. with self.test_session(use_gpu=False): - for bad in [[-1]], [[7]]: + for bad in [[2]], [[7]]: unsorted = math_ops.unsorted_segment_sum([[17]], bad, num_segments=2) with self.assertRaisesOpError( r"segment_ids\[0,0\] = %d is out of range \[0, 2\)" % bad[0][0]): @@ -360,6 +361,32 @@ class UnsortedSegmentSumTest(SegmentReductionHelper): x_init_value=np_x.astype(np.double), delta=1) self.assertAllClose(jacob_t, jacob_n) + def testDropNegatives(self): + # Note: the test is done by replacing segment_ids with 8 to -1 + # for index and replace values generated by numpy with 0. + dtypes = [ + dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int64, + dtypes_lib.int32, dtypes_lib.complex64, dtypes_lib.complex128 + ] + indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3]) + num_segments = 12 + for indices in indices_flat, indices_flat.reshape(5, 2): + shape = indices.shape + (2,) + for dtype in dtypes: + with self.test_session(use_gpu=True): + tf_x, np_x = self._input(shape, dtype=dtype) + np_ans = self._segmentReduce( + indices, np_x, np.add, op2=None, num_out_rows=num_segments) + # Replace np_ans[8] with 0 for the value + np_ans[8:] = 0 + # Replace 8 with -1 in indices + np.place(indices, indices==8, [-1]) + s = math_ops.unsorted_segment_sum( + data=tf_x, segment_ids=indices, num_segments=num_segments) + tf_ans = s.eval() + self.assertAllClose(np_ans, tf_ans) + self.assertShapeEqual(np_ans, s) + class SparseSegmentReductionHelper(SegmentReductionHelper): |