aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-05 09:01:22 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-05 09:05:49 -0800
commit5e53ba5a33ee116179bc4ac4f09be76811eb3960 (patch)
tree5e6c0d5351b4cc8aabef7be1283555843c0eee20
parent3a2e7635e69b5b1d1f510108d7a601edc570abc8 (diff)
Fix a case in SparseSegmentReduction ops with missing segment IDs, where all segment IDs are empty. Added a test for this case.
PiperOrigin-RevId: 187873356
-rw-r--r--tensorflow/core/kernels/segment_reduction_ops.cc7
-rw-r--r--tensorflow/python/kernel_tests/segment_reduction_ops_test.py19
2 files changed, 25 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc
index 27b8081eb8..bbf8696531 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops.cc
@@ -616,7 +616,12 @@ class SparseSegmentReductionOpBase : public OpKernel {
// we need to explicitly set missing indices to the default value.
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
- if (num_indices == 0) return;
+ if (num_indices == 0) {
+ if (output_rows > 0) {
+ output->flat_outer_dims<T>().setConstant(default_value_);
+ }
+ return;
+ }
OP_REQUIRES(context, output_rows > 0,
errors::InvalidArgument("segment ids must be >= 0"));
auto output_flat = output->flat_outer_dims<T>();
diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
index 5a54f448d0..239a48d273 100644
--- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
@@ -507,6 +507,25 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper):
tf_ans = s.eval()
self.assertAllClose(np_ans, tf_ans)
+ def testWithEmptySegments(self):
+ tf_x = constant_op.constant([], shape=[0, 4], dtype=dtypes_lib.float32)
+ ops_list = [
+ math_ops.sparse_segment_sum_with_num_segments,
+ math_ops.sparse_segment_mean_with_num_segments
+ ]
+ segment_indices = []
+ tf_indices = []
+ num_segments = 5
+ with self.test_session(use_gpu=False):
+ for tf_op in ops_list:
+ s = tf_op(
+ data=tf_x,
+ indices=tf_indices,
+ segment_ids=segment_indices,
+ num_segments=num_segments)
+ tf_ans = s.eval()
+ self.assertAllClose(np.zeros([5, 4]), tf_ans)
+
def testSegmentIdsGreaterThanZero(self):
tf_x, np_x = self._input([10, 4], dtype=dtypes_lib.float32)
ops_list = [(np.add, None, math_ops.sparse_segment_sum), (