diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/segment_reduction_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/segment_reduction_ops_test.py | 32 |
1 files changed, 29 insertions, 3 deletions
diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py index d7e3b3e79b..485530d405 100644 --- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py @@ -49,12 +49,21 @@ class SegmentReductionHelper(test.TestCase): slice_shape = x.shape[indices.ndim:] x_flat = x.reshape((indices.size,) + slice_shape) for i, index in enumerate(indices.ravel()): - if output[index] is not None: + if (output[index] is not None) and op1 == np.max: + for j in range(0, output[index].shape[0]): + output[index][j] = op1([output[index][j], x_flat[i][j]]) + elif output[index] is not None: output[index] = op1(output[index], x_flat[i]) else: output[index] = x_flat[i] # zero initialize values that are still uncalcuated. - output = [o if o is not None else np.zeros(slice_shape) for o in output] + # output = [o if o is not None else np.zeros(slice_shape) for o in output] + if not op1 == np.max: + output = [o if o is not None else np.zeros(slice_shape) for o in output] + else: + zeroslice = np.zeros(slice_shape) + zeroslice.fill(dtype.min) + output = [o if o is not None else zeroslice for o in output] if op2 is not None: output = [op2(o) for o in output] output = [o.reshape(slice_shape) for o in output] @@ -245,7 +254,7 @@ class UnsortedSegmentSumTest(SegmentReductionHelper): self._assertAllClose(indices, np_ans, tf_ans) self.assertShapeEqual(np_ans, s) - def testGradient(self): + def testGradientSegmentSum(self): num_cols = 2 indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3]) num_segments = max(indices_flat) + 3 @@ -318,6 +327,23 @@ class UnsortedSegmentSumTest(SegmentReductionHelper): unsorted = math_ops.unsorted_segment_sum(data, segment_ids, 2) self.assertAllEqual(unsorted.eval(), np.zeros((2, 0), dtype=dtype)) + def testGradientSegmentMax(self): + num_cols = 2 + indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3]) + num_segments = max(indices_flat) + 3 + for indices in indices_flat, indices_flat.reshape(5, 2): + shape = indices.shape + (num_cols,) + with self.test_session(): + tf_x, np_x = self._input(shape, dtype=dtypes_lib.float64) + s = math_ops.unsorted_segment_max(data=tf_x, segment_ids=indices, + num_segments=num_segments) + jacob_t, jacob_n = gradient_checker.compute_gradient( + tf_x, + shape, + s, + [num_segments, num_cols], + x_init_value=np_x.astype(np.double), delta=1) + self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3) class UnsortedSegmentSumGpuTest(UnsortedSegmentSumTest): use_gpu = True |