aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/segment_reduction_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/segment_reduction_ops_test.py165
1 files changed, 100 insertions, 65 deletions
diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
index 239a48d273..3bca5fadc4 100644
--- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
@@ -46,7 +46,8 @@ class SegmentReductionHelper(test.TestCase):
return constant_op.constant(
np_values, shape=input_shape, dtype=dtype), np_values
- def _segmentReduce(self, indices, x, op1, op2=None, num_segments=None):
+ def _segmentReduce(self, indices, x, op1, op2=None, num_segments=None,
+ initial_value=0):
if not x.size:
return np.array([])
indices = np.asarray(indices)
@@ -64,13 +65,8 @@ class SegmentReductionHelper(test.TestCase):
else:
output[index] = x_flat[i]
# zero initialize values that are still uncalcuated.
- # output = [o if o is not None else np.zeros(slice_shape) for o in output]
- if not op1 == np.max:
- output = [o if o is not None else np.zeros(slice_shape) for o in output]
- else:
- zeroslice = np.zeros(slice_shape)
- zeroslice.fill(dtype.min)
- output = [o if o is not None else zeroslice for o in output]
+ initial_value_slice = np.ones(slice_shape) * initial_value
+ output = [o if o is not None else initial_value_slice for o in output]
if op2 is not None:
output = [op2(o) for o in output]
output = [o.reshape(slice_shape) for o in output]
@@ -82,6 +78,9 @@ class SegmentReductionHelper(test.TestCase):
def _mean_reduce_op(self, x):
return x[0] / x[1] if isinstance(x, tuple) else x
+ def _sqrt_n_reduce_op(self, x):
+ return x[0] / np.sqrt(x[1]) if isinstance(x, tuple) else x
+
class SegmentReductionOpTest(SegmentReductionHelper):
@@ -244,27 +243,61 @@ class SegmentReductionOpTest(SegmentReductionHelper):
self.assertAllClose(jacob_t, jacob_n)
-class UnsortedSegmentSumTest(SegmentReductionHelper):
+class UnsortedSegmentTest(SegmentReductionHelper):
+
+ def __init__(self, methodName='runTest'):
+ # Each item is np_op1, np_op2, tf_op, initial_value functor
+ self.ops_list = [(np.add, None,
+ math_ops.unsorted_segment_sum, lambda t: 0),
+ (self._mean_cum_op, self._mean_reduce_op,
+ math_ops.unsorted_segment_mean, lambda t: 0),
+ (self._mean_cum_op, self._sqrt_n_reduce_op,
+ math_ops.unsorted_segment_sqrt_n, lambda t: 0),
+ (np.ndarray.__mul__, None,
+ math_ops.unsorted_segment_prod, lambda t: 1),
+ (np.minimum, None,
+ math_ops.unsorted_segment_min, lambda t: t.max),
+ (np.maximum, None,
+ math_ops.unsorted_segment_max, lambda t: t.min)]
+
+ # A subset of ops has been enabled for complex numbers
+ self.complex_ops_list = [(np.add, None,
+ math_ops.unsorted_segment_sum, lambda t: 0)]
+ self.differentiable_dtypes = [dtypes_lib.float16, dtypes_lib.float32,
+ dtypes_lib.float64]
+ self.all_dtypes = (self.differentiable_dtypes +
+ [dtypes_lib.bfloat16,
+ dtypes_lib.int64, dtypes_lib.int32,
+ dtypes_lib.complex64, dtypes_lib.complex128])
+ super(UnsortedSegmentTest, self).__init__(methodName=methodName)
def testValues(self):
- dtypes = [
- dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int64,
- dtypes_lib.int32, dtypes_lib.complex64, dtypes_lib.complex128
- ]
indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
num_segments = 12
for indices in indices_flat, indices_flat.reshape(5, 2):
shape = indices.shape + (2,)
- for dtype in dtypes:
- with self.test_session(use_gpu=True):
- tf_x, np_x = self._input(shape, dtype=dtype)
- np_ans = self._segmentReduce(
- indices, np_x, np.add, op2=None, num_segments=num_segments)
- s = math_ops.unsorted_segment_sum(
- data=tf_x, segment_ids=indices, num_segments=num_segments)
- tf_ans = s.eval()
- self.assertAllClose(np_ans, tf_ans)
- self.assertShapeEqual(np_ans, s)
+ for dtype in self.all_dtypes:
+ ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list
+ tf_x, np_x = self._input(shape, dtype=dtype)
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=True):
+ for np_op1, np_op2, tf_op, init_op in ops_list:
+ # sqrt_n doesn't support integers
+ if (np_op2 == self._sqrt_n_reduce_op and dtype.is_integer):
+ continue
+ # todo(philjd): enable this test once real_div supports bfloat16
+ if (np_op2 in [self._sqrt_n_reduce_op, self._mean_reduce_op] and
+ dtype == dtypes_lib.bfloat16):
+ continue
+ np_ans = self._segmentReduce(
+ indices, np_x, np_op1, np_op2, num_segments=num_segments,
+ initial_value=init_op(dtype))
+ s = tf_op(tf_x, segment_ids=indices, num_segments=num_segments)
+ tf_ans = s.eval()
+ if dtype is dtypes_lib.bfloat16:
+ tf_ans = tf_ans.astype(np.float32)
+ self.assertAllClose(np_ans, tf_ans)
+ self.assertShapeEqual(np_ans, s)
def testNumSegmentsTypes(self):
dtypes = [dtypes_lib.int32, dtypes_lib.int64]
@@ -287,25 +320,51 @@ class UnsortedSegmentSumTest(SegmentReductionHelper):
self.assertAllClose(np_ans, tf_ans)
self.assertShapeEqual(np_ans, s)
- def testGradientSegmentSum(self):
+ def testGradients(self):
num_cols = 2
- indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
+ indices_flat = np.array([0, 4, 0, -1, 3, -1, 4, 7, 7, 3])
num_segments = max(indices_flat) + 3
- for dtype in [dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.complex64,
- dtypes_lib.complex128]:
+ for dtype in self.differentiable_dtypes:
+ ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list
for indices in indices_flat, indices_flat.reshape(5, 2):
shape = indices.shape + (num_cols,)
- with self.test_session(use_gpu=True):
- tf_x, np_x = self._input(shape, dtype=dtype)
- s = math_ops.unsorted_segment_sum(
- data=tf_x, segment_ids=indices, num_segments=num_segments)
+ # test CPU and GPU as tf.gather behaves differently on each device
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ for _, _, tf_op, _ in ops_list:
+ tf_x, np_x = self._input(shape, dtype=dtype)
+ s = tf_op(tf_x, indices, num_segments)
+ jacob_t, jacob_n = gradient_checker.compute_gradient(
+ tf_x,
+ shape,
+ s, [num_segments, num_cols],
+ x_init_value=np_x,
+ delta=1)
+ self.assertAllClose(jacob_t, jacob_n)
+
+ def testProdGrad(self):
+ # additional test for the prod gradient to ensure correct handling of zeros
+ values = np.array([0, 0, 1, 0, 2, 2, 3, 3, 3], dtype=np.float32)
+ indices = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2], dtype=np.int32)
+ indices_neg = np.array([-1, 0, 0, -1, 1, 1, -1, 2, 2], dtype=np.int32)
+ values_tf = constant_op.constant(values)
+ # ground truth partial derivatives
+ gradients_indices = np.zeros((9, 3), dtype=np.float32)
+ gradients_indices_neg = np.zeros((9, 3), dtype=np.float32)
+ # the derivative w.r.t. to the other segments is zero, so here we only
+ # explicitly set the grad values for the corresponding segment
+ gradients_indices[range(9), indices] = [0, 0, 0, 4, 0, 0, 9, 9, 9]
+ gradients_indices_neg[range(9), indices_neg] = [0, 1, 0, 0, 2, 2, 0, 3, 3]
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ for ind, grad_gt in [(indices, gradients_indices),
+ (indices_neg, gradients_indices_neg)]:
+ s = math_ops.unsorted_segment_prod(values_tf,
+ constant_op.constant(ind), 3)
jacob_t, jacob_n = gradient_checker.compute_gradient(
- tf_x,
- shape,
- s, [num_segments, num_cols],
- x_init_value=np_x,
- delta=1)
- self.assertAllClose(jacob_t, jacob_n)
+ values_tf, (9,), s, (3,), x_init_value=values, delta=1)
+ self.assertAllClose(jacob_t, jacob_n)
+ self.assertAllClose(jacob_t, grad_gt)
def testGradientMatchesSegmentSum(self):
# Strategy: compute the gradient for UnsortedSegmentSum and SegmentSum
@@ -318,8 +377,7 @@ class UnsortedSegmentSumTest(SegmentReductionHelper):
num_cols = 2
shape = [n, num_cols]
num_segments = max(indices) + 1
- for dtype in [dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.complex64,
- dtypes_lib.complex128]:
+ for dtype in self.differentiable_dtypes:
with self.test_session(use_gpu=True):
tf_x, np_x = self._input(shape, dtype=dtype)
# Results from UnsortedSegmentSum
@@ -353,9 +411,8 @@ class UnsortedSegmentSumTest(SegmentReductionHelper):
unsorted.eval()
def testEmptySecondDimension(self):
- dtypes = [
- np.float32, np.float64, np.int64, np.int32, np.complex64, np.complex128
- ]
+ dtypes = [np.float16, np.float32, np.float64, np.int64, np.int32,
+ np.complex64, np.complex128]
with self.test_session(use_gpu=True):
for dtype in dtypes:
for itype in (np.int32, np.int64):
@@ -364,36 +421,14 @@ class UnsortedSegmentSumTest(SegmentReductionHelper):
unsorted = math_ops.unsorted_segment_sum(data, segment_ids, 2)
self.assertAllEqual(unsorted.eval(), np.zeros((2, 0), dtype=dtype))
- def testGradientSegmentMax(self):
- num_cols = 2
- indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
- num_segments = max(indices_flat) + 3
- for indices in indices_flat, indices_flat.reshape(5, 2):
- shape = indices.shape + (num_cols,)
- with self.test_session(use_gpu=True):
- tf_x, np_x = self._input(shape, dtype=dtypes_lib.float64)
- s = math_ops.unsorted_segment_max(
- data=tf_x, segment_ids=indices, num_segments=num_segments)
- jacob_t, jacob_n = gradient_checker.compute_gradient(
- tf_x,
- shape,
- s,
- [num_segments, num_cols],
- x_init_value=np_x.astype(np.double), delta=1)
- self.assertAllClose(jacob_t, jacob_n)
-
def testDropNegatives(self):
# Note: the test is done by replacing segment_ids with 8 to -1
# for index and replace values generated by numpy with 0.
- dtypes = [
- dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int64,
- dtypes_lib.int32, dtypes_lib.complex64, dtypes_lib.complex128
- ]
indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
num_segments = 12
for indices in indices_flat, indices_flat.reshape(5, 2):
shape = indices.shape + (2,)
- for dtype in dtypes:
+ for dtype in self.all_dtypes:
with self.test_session(use_gpu=True):
tf_x, np_x = self._input(shape, dtype=dtype)
np_ans = self._segmentReduce(