aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/segment_reduction_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/segment_reduction_ops_test.py32
1 files changed, 29 insertions, 3 deletions
diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
index d7e3b3e79b..485530d405 100644
--- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
@@ -49,12 +49,21 @@ class SegmentReductionHelper(test.TestCase):
slice_shape = x.shape[indices.ndim:]
x_flat = x.reshape((indices.size,) + slice_shape)
for i, index in enumerate(indices.ravel()):
- if output[index] is not None:
+ if (output[index] is not None) and op1 == np.max:
+ for j in range(0, output[index].shape[0]):
+ output[index][j] = op1([output[index][j], x_flat[i][j]])
+ elif output[index] is not None:
output[index] = op1(output[index], x_flat[i])
else:
output[index] = x_flat[i]
# zero initialize values that are still uncalcuated.
- output = [o if o is not None else np.zeros(slice_shape) for o in output]
+ # output = [o if o is not None else np.zeros(slice_shape) for o in output]
+ if not op1 == np.max:
+ output = [o if o is not None else np.zeros(slice_shape) for o in output]
+ else:
+ zeroslice = np.zeros(slice_shape)
+ zeroslice.fill(dtype.min)
+ output = [o if o is not None else zeroslice for o in output]
if op2 is not None:
output = [op2(o) for o in output]
output = [o.reshape(slice_shape) for o in output]
@@ -245,7 +254,7 @@ class UnsortedSegmentSumTest(SegmentReductionHelper):
self._assertAllClose(indices, np_ans, tf_ans)
self.assertShapeEqual(np_ans, s)
- def testGradient(self):
+ def testGradientSegmentSum(self):
num_cols = 2
indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
num_segments = max(indices_flat) + 3
@@ -318,6 +327,23 @@ class UnsortedSegmentSumTest(SegmentReductionHelper):
unsorted = math_ops.unsorted_segment_sum(data, segment_ids, 2)
self.assertAllEqual(unsorted.eval(), np.zeros((2, 0), dtype=dtype))
+ def testGradientSegmentMax(self):
+ num_cols = 2
+ indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
+ num_segments = max(indices_flat) + 3
+ for indices in indices_flat, indices_flat.reshape(5, 2):
+ shape = indices.shape + (num_cols,)
+ with self.test_session():
+ tf_x, np_x = self._input(shape, dtype=dtypes_lib.float64)
+ s = math_ops.unsorted_segment_max(data=tf_x, segment_ids=indices,
+ num_segments=num_segments)
+ jacob_t, jacob_n = gradient_checker.compute_gradient(
+ tf_x,
+ shape,
+ s,
+ [num_segments, num_cols],
+ x_init_value=np_x.astype(np.double), delta=1)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
class UnsortedSegmentSumGpuTest(UnsortedSegmentSumTest):
use_gpu = True