diff options
-rw-r--r-- | tensorflow/core/kernels/segment_reduction_ops.cc | 7 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/segment_reduction_ops_test.py | 19 |
2 files changed, 25 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc index 27b8081eb8..bbf8696531 100644 --- a/tensorflow/core/kernels/segment_reduction_ops.cc +++ b/tensorflow/core/kernels/segment_reduction_ops.cc @@ -616,7 +616,12 @@ class SparseSegmentReductionOpBase : public OpKernel { // we need to explicitly set missing indices to the default value. Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); - if (num_indices == 0) return; + if (num_indices == 0) { + if (output_rows > 0) { + output->flat_outer_dims<T>().setConstant(default_value_); + } + return; + } OP_REQUIRES(context, output_rows > 0, errors::InvalidArgument("segment ids must be >= 0")); auto output_flat = output->flat_outer_dims<T>(); diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py index 5a54f448d0..239a48d273 100644 --- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py @@ -507,6 +507,25 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper): tf_ans = s.eval() self.assertAllClose(np_ans, tf_ans) + def testWithEmptySegments(self): + tf_x = constant_op.constant([], shape=[0, 4], dtype=dtypes_lib.float32) + ops_list = [ + math_ops.sparse_segment_sum_with_num_segments, + math_ops.sparse_segment_mean_with_num_segments + ] + segment_indices = [] + tf_indices = [] + num_segments = 5 + with self.test_session(use_gpu=False): + for tf_op in ops_list: + s = tf_op( + data=tf_x, + indices=tf_indices, + segment_ids=segment_indices, + num_segments=num_segments) + tf_ans = s.eval() + self.assertAllClose(np.zeros([5, 4]), tf_ans) + def testSegmentIdsGreaterThanZero(self): tf_x, np_x = self._input([10, 4], dtype=dtypes_lib.float32) ops_list = [(np.add, None, math_ops.sparse_segment_sum), ( |