diff options
author | 2017-12-06 12:02:02 -0800 | |
---|---|---|
committer | 2017-12-06 12:06:12 -0800 | |
commit | 91c75ecc66c630f541a2215844b2012b9f5e6df6 (patch) | |
tree | f53b18b1c7bd6d3d945d5ec2abab1095525d08e3 | |
parent | 93aeebad51f29e6d90d091be6e28986079805d3a (diff) |
Allow SparseSegmentReduction ops to have missing segment IDs.
PiperOrigin-RevId: 178131721
9 files changed, 702 insertions, 45 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt new file mode 100644 index 0000000000..d6e1054003 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt @@ -0,0 +1,36 @@ +op { + graph_op_name: "SparseSegmentMeanWithNumSegments" + in_arg { + name: "indices" + description: <<END +A 1-D tensor. Has same rank as `segment_ids`. +END + } + in_arg { + name: "segment_ids" + description: <<END +A 1-D tensor. Values should be sorted and can be repeated. +END + } + in_arg { + name: "num_segments" + description: <<END +Should equal the number of distinct segment IDs. +END + } + out_arg { + name: "output" + description: <<END +Has same shape as data, except for dimension 0 which has size +`num_segments`. +END + } + summary: "Computes the mean along sparse segments of a tensor." + description: <<END +Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is +misisng, the `output` tensor at that position will be zeroed. + +Read @{$math_ops#segmentation$the section on segmentation} for an explanation of +segments. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt new file mode 100644 index 0000000000..9ba98b8191 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt @@ -0,0 +1,38 @@ +op { + graph_op_name: "SparseSegmentSqrtNWithNumSegments" + in_arg { + name: "indices" + description: <<END +A 1-D tensor. Has same rank as `segment_ids`. +END + } + in_arg { + name: "segment_ids" + description: <<END +A 1-D tensor. Values should be sorted and can be repeated. +END + } + in_arg { + name: "num_segments" + description: <<END +Should equal the number of distinct segment IDs. +END + } + out_arg { + name: "output" + description: <<END +Has same shape as data, except for dimension 0 which +has size `k`, the number of segments. +END + } + summary: "Computes the sum along sparse segments of a tensor divided by the sqrt of N." + description: <<END +N is the size of the segment being reduced. + +Like `SparseSegmentSqrtN`, but allows missing ids in `segment_ids`. If an id is +misisng, the `output` tensor at that position will be zeroed. + +Read @{$math_ops#segmentation$the section on segmentation} for an explanation of +segments. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt new file mode 100644 index 0000000000..3aeaba38e9 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt @@ -0,0 +1,57 @@ +op { + graph_op_name: "SparseSegmentSumWithNumSegments" + in_arg { + name: "indices" + description: <<END +A 1-D tensor. Has same rank as `segment_ids`. +END + } + in_arg { + name: "segment_ids" + description: <<END +A 1-D tensor. Values should be sorted and can be repeated. +END + } + in_arg { + name: "num_segments" + description: <<END +Should equal the number of distinct segment IDs. +END + } + out_arg { + name: "output" + description: <<END +Has same shape as data, except for dimension 0 which +has size `num_segments`. +END + } + summary: "Computes the sum along sparse segments of a tensor." + description: <<END +Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is +misisng, the `output` tensor at that position will be zeroed. + +Read @{$math_ops#segmentation$the section on segmentation} for an explanation of +segments. + +For example: + +```python +c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]) + +tf.sparse_segment_sum_with_num_segments( + c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3) +# => [[0 0 0 0] +# [0 0 0 0] +# [0 0 0 0]] + +tf.sparse_segment_sum_with_num_segments(c, + tf.constant([0, 1]), + tf.constant([0, 2], + num_segments=4)) +# => [[ 1 2 3 4] +# [ 0 0 0 0] +# [-1 -2 -3 -4] +# [ 0 0 0 0]] +``` +END +} diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc index 2334e50f1d..3ef1cd1e06 100644 --- a/tensorflow/core/kernels/segment_reduction_ops.cc +++ b/tensorflow/core/kernels/segment_reduction_ops.cc @@ -553,10 +553,11 @@ class SparseSegmentReductionOpBase : public OpKernel { public: explicit SparseSegmentReductionOpBase(OpKernelConstruction* context, bool is_mean, bool is_sqrtn, - T default_value) + bool has_num_segments, T default_value) : OpKernel(context), is_mean_(is_mean), is_sqrtn_(is_sqrtn), + has_num_segments_(has_num_segments), default_value_(default_value) {} void Compute(OpKernelContext* context) override { @@ -564,6 +565,19 @@ class SparseSegmentReductionOpBase : public OpKernel { const Tensor& indices = context->input(1); const Tensor& segment_ids = context->input(2); + Index output_rows = -1; + if (has_num_segments_) { + const Tensor& num_segments = context->input(3); + + OP_REQUIRES( + context, num_segments.shape().dims() == 0, + errors::InvalidArgument("num_segments should be a scalar, not shape ", + num_segments.shape().DebugString())); + output_rows = internal::SubtleMustCopy(num_segments.scalar<int32>()()); + OP_REQUIRES(context, output_rows >= 0, + errors::InvalidArgument("segment ids must be >= 0")); + } + OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()), errors::InvalidArgument("indices should be a vector.")); OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()), @@ -581,10 +595,17 @@ class SparseSegmentReductionOpBase : public OpKernel { const auto segment_vec = segment_ids.vec<OutputRow>(); // Note that the current implementation assumes that segment_vec values are // sorted. - const OutputRow output_rows = + const OutputRow last_segment_id_plus_one = num_indices > 0 ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1 : 0; + if (has_num_segments_) { + OP_REQUIRES( + context, output_rows >= last_segment_id_plus_one, + errors::InvalidArgument("segment ids must be < num_segments")); + } else { + output_rows = last_segment_id_plus_one; + } OP_REQUIRES(context, output_rows >= 0, errors::InvalidArgument("segment ids must be >= 0")); @@ -646,11 +667,20 @@ class SparseSegmentReductionOpBase : public OpKernel { indices_vec(start + bad_offset), " out of range [0, ", input_flat.dimension(0), ")")); - if (end >= num_indices) break; start = end; ++end; uninitialized_index = out_index + 1; out_index = next_index; + if (end > num_indices) break; + } + + // Fill the gap at the end with the default value. + if (uninitialized_index < output_rows) { + Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape( + output_rows - uninitialized_index, num_col); + Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned> + gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape); + gap_slice.setConstant(default_value_); } } @@ -786,6 +816,7 @@ class SparseSegmentReductionOpBase : public OpKernel { const bool is_mean_; const bool is_sqrtn_; + const bool has_num_segments_; const T default_value_; }; @@ -794,9 +825,20 @@ class SparseSegmentReductionMeanOp : public SparseSegmentReductionOpBase<Device, T> { public: explicit SparseSegmentReductionMeanOp(OpKernelConstruction* context) - : SparseSegmentReductionOpBase<Device, T>(context, true /*is_mean*/, - false /*is_sqrtn*/, - T(0) /* default_value */) {} + : SparseSegmentReductionOpBase<Device, T>( + context, true /*is_mean*/, false /*is_sqrtn*/, + false /* has_num_segments */, T(0) /* default_value */) {} +}; + +template <typename Device, class T> +class SparseSegmentReductionMeanWithNumSegmentsOp + : public SparseSegmentReductionOpBase<Device, T> { + public: + explicit SparseSegmentReductionMeanWithNumSegmentsOp( + OpKernelConstruction* context) + : SparseSegmentReductionOpBase<Device, T>( + context, true /*is_mean*/, false /*is_sqrtn*/, + true /* has_num_segments */, T(0) /* default_value */) {} }; template <typename Device, class T> @@ -804,9 +846,20 @@ class SparseSegmentReductionSqrtNOp : public SparseSegmentReductionOpBase<Device, T> { public: explicit SparseSegmentReductionSqrtNOp(OpKernelConstruction* context) - : SparseSegmentReductionOpBase<Device, T>(context, false /*is_mean*/, - true /*is_sqrtn*/, - T(0) /* default_value */) {} + : SparseSegmentReductionOpBase<Device, T>( + context, false /*is_mean*/, true /*is_sqrtn*/, + false /* has_num_segments */, T(0) /* default_value */) {} +}; + +template <typename Device, class T> +class SparseSegmentReductionSqrtNWithNumSegmentsOp + : public SparseSegmentReductionOpBase<Device, T> { + public: + explicit SparseSegmentReductionSqrtNWithNumSegmentsOp( + OpKernelConstruction* context) + : SparseSegmentReductionOpBase<Device, T>( + context, false /*is_mean*/, true /*is_sqrtn*/, + true /* has_num_segments */, T(0) /* default_value */) {} }; template <typename Device, class T> @@ -814,37 +867,65 @@ class SparseSegmentReductionSumOp : public SparseSegmentReductionOpBase<Device, T> { public: explicit SparseSegmentReductionSumOp(OpKernelConstruction* context) - : SparseSegmentReductionOpBase<Device, T>(context, false /*is_mean*/, - false /*is_sqrtn*/, - T(0) /* default_value */) {} + : SparseSegmentReductionOpBase<Device, T>( + context, false /*is_mean*/, false /*is_sqrtn*/, + false /* has_num_segments */, T(0) /* default_value */) {} }; -#define REGISTER_CPU_SPARSE_KERNELS(type) \ - REGISTER_KERNEL_BUILDER(Name("SparseSegmentSum") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<type>("T") \ - .TypeConstraint<int32>("Tidx"), \ - SparseSegmentReductionSumOp<CPUDevice, type>); +template <typename Device, class T> +class SparseSegmentReductionSumWithNumSegmentsOp + : public SparseSegmentReductionOpBase<Device, T> { + public: + explicit SparseSegmentReductionSumWithNumSegmentsOp( + OpKernelConstruction* context) + : SparseSegmentReductionOpBase<Device, T>( + context, false /*is_mean*/, false /*is_sqrtn*/, + true /* has_num_segments */, T(0) /* default_value */) {} +}; +#define REGISTER_CPU_SPARSE_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("SparseSegmentSum") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<int32>("Tidx"), \ + SparseSegmentReductionSumOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("SparseSegmentSumWithNumSegments") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<int32>("Tidx"), \ + SparseSegmentReductionSumWithNumSegmentsOp<CPUDevice, type>); TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_SPARSE_KERNELS); #undef REGISTER_CPU_SPARSE_KERNELS -#define REGISTER_CPU_SPARSE_KERNELS(type) \ - REGISTER_KERNEL_BUILDER(Name("SparseSegmentMean") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<type>("T") \ - .TypeConstraint<int32>("Tidx"), \ - SparseSegmentReductionMeanOp<CPUDevice, type>); +#define REGISTER_CPU_SPARSE_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("SparseSegmentMean") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<int32>("Tidx"), \ + SparseSegmentReductionMeanOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("SparseSegmentMeanWithNumSegments") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<int32>("Tidx"), \ + SparseSegmentReductionMeanWithNumSegmentsOp<CPUDevice, type>); REGISTER_CPU_SPARSE_KERNELS(float); REGISTER_CPU_SPARSE_KERNELS(double); #undef REGISTER_CPU_SPARSE_KERNELS -#define REGISTER_CPU_SPARSE_KERNELS(type) \ - REGISTER_KERNEL_BUILDER(Name("SparseSegmentSqrtN") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<type>("T") \ - .TypeConstraint<int32>("Tidx"), \ - SparseSegmentReductionSqrtNOp<CPUDevice, type>); +#define REGISTER_CPU_SPARSE_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("SparseSegmentSqrtN") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<int32>("Tidx"), \ + SparseSegmentReductionSqrtNOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("SparseSegmentSqrtNWithNumSegments") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<int32>("Tidx"), \ + SparseSegmentReductionSqrtNWithNumSegmentsOp<CPUDevice, type>); REGISTER_CPU_SPARSE_KERNELS(float); REGISTER_CPU_SPARSE_KERNELS(double); #undef REGISTER_CPU_SPARSE_KERNELS @@ -889,9 +970,10 @@ class SparseSegmentGradOpBase : public OpKernel { // Note that similar to SparseSegmentMean, we assume that segment_vec is // already sorted and has non-negative values. - const SegmentId num_segments = + const SegmentId num_segments = input.dim_size(0); + const SegmentId last_segment_id_plus_one = internal::SubtleMustCopy(segment_vec(N - 1)) + 1; - OP_REQUIRES(context, input.dim_size(0) == num_segments, + OP_REQUIRES(context, last_segment_id_plus_one <= num_segments, errors::InvalidArgument("Invalid number of segments")); // Compute scaling factors for input. diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 45ebfa203b..8ea170ba14 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1632,6 +1632,45 @@ Status SparseSegmentReductionGradShapeFn(InferenceContext* c) { return Status::OK(); } +Status SparseSegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) { + ShapeHandle data_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape)); + + ShapeHandle indices_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape)); + + ShapeHandle segment_ids_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape)); + + ShapeHandle num_segments_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &num_segments_shape)); + + // indices and segment_ids should merge cleanly. + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused)); + + ShapeHandle subshape; + TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape)); + + ShapeHandle out; + const Tensor* dim0 = c->input_tensor(3); + if (dim0 == nullptr) { + // We don't have the value at inference time, so the output + // shape is unknown. + TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(InferenceContext::kUnknownDim), + subshape, &out)); + } else { + auto dim0_value = dim0->scalar<int32>()(); + if (dim0_value < 0) { + return errors::InvalidArgument( + "Cannot specify a negative value for num_segments"); + } + TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(dim0_value), subshape, &out)); + } + c->set_output(0, out); + return Status::OK(); +} + Status UnsortedSegmentReductionShapeFn(InferenceContext* c) { ShapeHandle s_data = c->input(0); ShapeHandle s_segment_ids = c->input(1); @@ -1890,6 +1929,7 @@ output: Has same shape as data, except for dimension 0 which has size `num_segments`. )doc"); + REGISTER_OP("SparseSegmentSum") .Input("data: T") .Input("indices: Tidx") @@ -1938,6 +1978,56 @@ output: Has same shape as data, except for dimension 0 which has size `k`, the number of segments. )doc"); +REGISTER_OP("SparseSegmentSumWithNumSegments") + .Input("data: T") + .Input("indices: Tidx") + .Input("segment_ids: int32") + .Input("num_segments: Tnumsegments") + .Output("output: T") + .Attr("T: realnumbertype") + .Attr("Tidx: {int32, int64} = DT_INT32") + .Attr("Tnumsegments: {int32,int64} = DT_INT32") + .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn) + .Doc(R"doc( +Computes the sum along sparse segments of a tensor. + +Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is +misisng, the `output` tensor at that position will be zeroed. + +Read @{$math_ops#segmentation$the section on segmentation} for an explanation of +segments. + +For example: + +```python +c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]) + +tf.sparse_segment_sum_with_num_segments( + c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3) +# => [[0 0 0 0] +# [0 0 0 0] +# [0 0 0 0]] + +tf.sparse_segment_sum_with_num_segments(c, + tf.constant([0, 1]), + tf.constant([0, 2], + num_segments=4)) +# => [[ 1 2 3 4] +# [ 0 0 0 0] +# [-1 -2 -3 -4] +# [ 0 0 0 0]] +``` + +indices: A 1-D tensor. Has same rank as `segment_ids`. + +segment_ids: A 1-D tensor. Values should be sorted and can be repeated. + +num_segments: Should equal the number of distinct segment IDs. + +output: Has same shape as data, except for dimension 0 which + has size `num_segments`. +)doc"); + REGISTER_OP("SparseSegmentMean") .Input("data: T") .Input("indices: Tidx") @@ -1964,6 +2054,35 @@ output: Has same shape as data, except for dimension 0 which )doc"); +REGISTER_OP("SparseSegmentMeanWithNumSegments") + .Input("data: T") + .Input("indices: Tidx") + .Input("segment_ids: int32") + .Input("num_segments: Tnumsegments") + .Output("output: T") + .Attr("T: {float, double}") + .Attr("Tidx: {int32, int64} = DT_INT32") + .Attr("Tnumsegments: {int32,int64} = DT_INT32") + .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn) + .Doc(R"doc( +Computes the mean along sparse segments of a tensor. + +Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is +misisng, the `output` tensor at that position will be zeroed. + +Read @{$math_ops#segmentation$the section on segmentation} for an explanation of +segments. + +indices: A 1-D tensor. Has same rank as `segment_ids`. + +segment_ids: A 1-D tensor. Values should be sorted and can be repeated. + +num_segments: Should equal the number of distinct segment IDs. + +output: Has same shape as data, except for dimension 0 which has size + `num_segments`. +)doc"); + REGISTER_OP("SparseSegmentMeanGrad") .Input("grad: T") .Input("indices: Tidx") @@ -2010,6 +2129,38 @@ output: Has same shape as data, except for dimension 0 which )doc"); +REGISTER_OP("SparseSegmentSqrtNWithNumSegments") + .Input("data: T") + .Input("indices: Tidx") + .Input("segment_ids: int32") + .Input("num_segments: Tnumsegments") + .Output("output: T") + .Attr("T: {float, double}") + .Attr("Tidx: {int32, int64} = DT_INT32") + .Attr("Tnumsegments: {int32,int64} = DT_INT32") + .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn) + .Doc(R"doc( +Computes the sum along sparse segments of a tensor divided by the sqrt of N. + +N is the size of the segment being reduced. + +Like `SparseSegmentSqrtN`, but allows missing ids in `segment_ids`. If an id is +misisng, the `output` tensor at that position will be zeroed. + +Read @{$math_ops#segmentation$the section on segmentation} for an explanation of +segments. + +indices: A 1-D tensor. Has same rank as `segment_ids`. + +segment_ids: A 1-D tensor. Values should be sorted and can be repeated. + +num_segments: Should equal the number of distinct segment IDs. + +output: Has same shape as data, except for dimension 0 which + has size `k`, the number of segments. + +)doc"); + REGISTER_OP("SparseSegmentSqrtNGrad") .Input("grad: T") .Input("indices: Tidx") diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py index fd58cdb170..5a54f448d0 100644 --- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py @@ -46,13 +46,13 @@ class SegmentReductionHelper(test.TestCase): return constant_op.constant( np_values, shape=input_shape, dtype=dtype), np_values - def _segmentReduce(self, indices, x, op1, op2=None, num_out_rows=None): + def _segmentReduce(self, indices, x, op1, op2=None, num_segments=None): if not x.size: return np.array([]) indices = np.asarray(indices) - if num_out_rows is None: - num_out_rows = indices[-1] + 1 - output = [None] * num_out_rows + if num_segments is None: + num_segments = indices[-1] + 1 + output = [None] * num_segments slice_shape = x.shape[indices.ndim:] x_flat = x.reshape((indices.size,) + slice_shape) for i, index in enumerate(indices.ravel()): @@ -259,7 +259,7 @@ class UnsortedSegmentSumTest(SegmentReductionHelper): with self.test_session(use_gpu=True): tf_x, np_x = self._input(shape, dtype=dtype) np_ans = self._segmentReduce( - indices, np_x, np.add, op2=None, num_out_rows=num_segments) + indices, np_x, np.add, op2=None, num_segments=num_segments) s = math_ops.unsorted_segment_sum( data=tf_x, segment_ids=indices, num_segments=num_segments) tf_ans = s.eval() @@ -278,7 +278,7 @@ class UnsortedSegmentSumTest(SegmentReductionHelper): num_segments_constant = constant_op.constant( num_segments, dtype=dtype) np_ans = self._segmentReduce( - indices, np_x, np.add, op2=None, num_out_rows=num_segments) + indices, np_x, np.add, op2=None, num_segments=num_segments) s = math_ops.unsorted_segment_sum( data=tf_x, segment_ids=indices, @@ -397,7 +397,7 @@ class UnsortedSegmentSumTest(SegmentReductionHelper): with self.test_session(use_gpu=True): tf_x, np_x = self._input(shape, dtype=dtype) np_ans = self._segmentReduce( - indices, np_x, np.add, op2=None, num_out_rows=num_segments) + indices, np_x, np.add, op2=None, num_segments=num_segments) # Replace np_ans[8] with 0 for the value np_ans[8:] = 0 # Replace 8 with -1 in indices @@ -417,8 +417,15 @@ class SparseSegmentReductionHelper(SegmentReductionHelper): return (constant_op.constant( indices, dtype=dtypes_lib.int32), indices, a, b) - def _sparseSegmentReduce(self, x, indices, segment_indices, op1, op2=None): - return self._segmentReduce(segment_indices, x[indices], op1, op2) + def _sparseSegmentReduce(self, + x, + indices, + segment_indices, + op1, + op2=None, + num_segments=None): + return self._segmentReduce( + segment_indices, x[indices], op1, op2, num_segments=num_segments) class SparseSegmentReductionOpTest(SparseSegmentReductionHelper): @@ -475,6 +482,31 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper): tf_ans = s.eval() self.assertAllClose(np_ans, tf_ans) + def testWithNumSegments(self): + tf_x, np_x = self._input([10, 4], dtype=dtypes_lib.float32) + ops_list = [(np.add, None, math_ops.sparse_segment_sum_with_num_segments), + (self._mean_cum_op, self._mean_reduce_op, + math_ops.sparse_segment_mean_with_num_segments)] + segment_indices = [0, 2, 2, 2] + tf_indices = [8, 3, 0, 9] + num_segments = 5 + with self.test_session(use_gpu=False): + for np_op1, np_op2, tf_op in ops_list: + np_ans = self._sparseSegmentReduce( + np_x, + tf_indices, + segment_indices, + np_op1, + np_op2, + num_segments=num_segments) + s = tf_op( + data=tf_x, + indices=tf_indices, + segment_ids=segment_indices, + num_segments=num_segments) + tf_ans = s.eval() + self.assertAllClose(np_ans, tf_ans) + def testSegmentIdsGreaterThanZero(self): tf_x, np_x = self._input([10, 4], dtype=dtypes_lib.float32) ops_list = [(np.add, None, math_ops.sparse_segment_sum), ( @@ -583,6 +615,63 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper): with self.assertRaisesOpError("segment ids must be >= 0"): s.eval() + def testSegmentWithNumSegmentsValid(self): + # Baseline for the test*WithNumSegmentsInvalid* methods below. + tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32) + ops_list = [ + math_ops.sparse_segment_sum_with_num_segments, + math_ops.sparse_segment_mean_with_num_segments, + ] + num_segments = 5 + segment_indices = [0, 1, 3, 3] + tf_indices = [8, 3, 0, 9] + with self.test_session(use_gpu=False): + for tf_op in ops_list: + s = tf_op( + data=tf_x, + indices=tf_indices, + segment_ids=segment_indices, + num_segments=num_segments) + s.eval() + + def testSegmentWithNumSegmentsInvalid1(self): + tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32) + ops_list = [ + math_ops.sparse_segment_sum_with_num_segments, + math_ops.sparse_segment_mean_with_num_segments, + ] + num_segments = 5 + segment_indices = [0, 1, 3, 5] + tf_indices = [8, 3, 0, 9] + with self.test_session(use_gpu=False): + for tf_op in ops_list: + s = tf_op( + data=tf_x, + indices=tf_indices, + segment_ids=segment_indices, + num_segments=num_segments) + with self.assertRaisesOpError("segment ids must be < num_segments"): + s.eval() + + def testSegmentWithNumSegmentsInvalid2(self): + tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32) + ops_list = [ + math_ops.sparse_segment_sum_with_num_segments, + math_ops.sparse_segment_mean_with_num_segments, + ] + num_segments = -2 + segment_indices = [0, 1, 3, 3] + tf_indices = [8, 3, 0, 9] + with self.test_session(use_gpu=False): + for tf_op in ops_list: + with self.assertRaisesRegexp( + ValueError, "Cannot specify a negative value for num_segments"): + tf_op( + data=tf_x, + indices=tf_indices, + segment_ids=segment_indices, + num_segments=num_segments) + def testGradient(self): shape = [10, 4] @@ -601,6 +690,32 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper): delta=1) self.assertAllClose(jacob_t, jacob_n) + def testGradientWithEmptySegmentsAtEnd(self): + shape = [10, 4] + + num_segments = 5 + segment_indices = [0, 1, 2, 2] + num_indices = len(segment_indices) + for tf_op in [ + math_ops.sparse_segment_sum_with_num_segments, + math_ops.sparse_segment_mean_with_num_segments, + ]: + with self.test_session(): + tf_indices, _, tf_x, np_x = self._sparse_input( + shape, num_indices, dtype=dtypes_lib.float64) + s = tf_op( + data=tf_x, + indices=tf_indices, + segment_ids=segment_indices, + num_segments=num_segments) + jacob_t, jacob_n = gradient_checker.compute_gradient( + tf_x, + shape, + s, [5, 4], + x_init_value=np_x.astype(np.double), + delta=1) + self.assertAllClose(jacob_t, jacob_n) + def testGradientValid(self): # Baseline for the testGradient*Invalid* methods below. tf_x, _ = self._input([3, 4], dtype=dtypes_lib.float32) @@ -646,7 +761,7 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper): ops_list = [ math_ops.sparse_segment_mean_grad, math_ops.sparse_segment_sqrt_n_grad ] - segment_indices = [0, 1, 1, 1] # 2 segments + segment_indices = [0, 1, 1, 4] # 5 segments tf_indices = [8, 3, 0, 9] with self.test_session(use_gpu=False): for tf_op in ops_list: diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 38fe093ba7..0239396ae3 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -184,6 +184,15 @@ def _SparseSegmentSumGrad(op, grad): None) +@ops.RegisterGradient("SparseSegmentSumWithNumSegments") +def _SparseSegmentSumWithNumSegmentsGrad(op, grad): + """Gradient for SparseSegmentSumWithNumSegments.""" + input_rows = array_ops.shape(op.inputs[0])[0] + return (math_ops.unsorted_segment_sum( + array_ops.gather(grad, op.inputs[2]), op.inputs[1], input_rows), None, + None, None) + + @ops.RegisterGradient("SparseSegmentMean") def _SparseSegmentMeanGrad(op, grad): """Gradient for SparseSegmentMean.""" @@ -192,6 +201,14 @@ def _SparseSegmentMeanGrad(op, grad): dim0), None, None) +@ops.RegisterGradient("SparseSegmentMeanWithNumSegments") +def _SparseSegmentMeanWithNumSegmentsGrad(op, grad): + """Gradient for SparseSegmentMeanWithNumSegments.""" + dim0 = array_ops.shape(op.inputs[0])[0] + return (math_ops.sparse_segment_mean_grad(grad, op.inputs[1], op.inputs[2], + dim0), None, None, None) + + @ops.RegisterGradient("SparseSegmentSqrtN") def _SparseSegmentSqrtNGrad(op, grad): """Gradient for SparseSegmentSqrtN.""" @@ -200,6 +217,14 @@ def _SparseSegmentSqrtNGrad(op, grad): dim0), None, None) +@ops.RegisterGradient("SparseSegmentSqrtNWithNumSegments") +def _SparseSegmentSqrtNWithNumSegmentsGrad(op, grad): + """Gradient for SparseSegmentSqrtNWithNumSegmnets.""" + dim0 = array_ops.shape(op.inputs[0])[0] + return (math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1], op.inputs[2], + dim0), None, None, None) + + def _SegmentMinOrMaxGrad(op, grad, is_sorted): """Gradient for SegmentMin and (unsorted) SegmentMax. They share similar code.""" zeros = array_ops.zeros(array_ops.shape(op.inputs[0]), diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index f9538be6c9..6af36343d5 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -2495,6 +2495,159 @@ def reduced_shape(input_shape, axes): ]) # [1, 1] +def sparse_segment_sum(data, indices, segment_ids, name=None, + num_segments=None): + r"""Computes the sum along sparse segments of a tensor. + + Read @{$math_ops#segmentation$the section on segmentation} for an explanation + of segments. + + Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first + dimension, selecting a subset of dimension 0, specified by `indices`. + `segment_ids` is allowed to have missing ids, in which case the output will + be zeros at those indices. In those cases `num_segments` is used to determine + the size of the output. + + For example: + + ```python + c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]) + + # Select two rows, one segment. + tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0])) + # => [[0 0 0 0]] + + # Select two rows, two segment. + tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1])) + # => [[ 1 2 3 4] + # [-1 -2 -3 -4]] + + # With missing segment ids. + tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 2]), + num_segments=4) + # => [[ 1 2 3 4] + # [ 0 0 0 0] + # [-1 -2 -3 -4] + # [ 0 0 0 0]] + + # Select all rows, two segments. + tf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1])) + # => [[0 0 0 0] + # [5 6 7 8]] + + # Which is equivalent to: + tf.segment_sum(c, tf.constant([0, 0, 1])) + ``` + + Args: + data: A `Tensor` with data that will be assembled in the output. + indices: A 1-D `Tensor` with indices into `data`. Has same rank as + `segment_ids`. + segment_ids: A 1-D `Tensor` with indices into the output `Tensor`. + Values should be sorted and can be repeated. + name: A name for the operation (optional). + num_segments: An optional int32 scalar. Indicates the size of the output + `Tensor`. + + Returns: + A `tensor` of the shape as data, except for dimension 0 which + has size `k`, the number of segments specified via `num_segments` or + inferred for the last element in `segments_ids`. + """ + if num_segments is not None: + return gen_math_ops.sparse_segment_sum_with_num_segments( + data=data, + indices=indices, + segment_ids=segment_ids, + num_segments=num_segments, + name=name) + else: + return gen_math_ops.sparse_segment_sum( + data=data, + indices=indices, + segment_ids=segment_ids, + name=name) + + +def sparse_segment_mean(data, indices, segment_ids, name=None, + num_segments=None): + r"""Computes the mean along sparse segments of a tensor. + + Read @{$math_ops#segmentation$the section on segmentation} for an explanation + of segments. + + Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first + dimension, selecting a subset of dimension 0, specified by `indices`. + `segment_ids` is allowed to have missing ids, in which case the output will + be zeros at those indices. In those cases `num_segments` is used to determine + the size of the output. + + Args: + data: A `Tensor` with data that will be assembled in the output. + indices: A 1-D `Tensor` with indices into `data`. Has same rank as + `segment_ids`. + segment_ids: A 1-D `Tensor` with indices into the output `Tensor`. + Values should be sorted and can be repeated. + name: A name for the operation (optional). + num_segments: An optional int32 scalar. Indicates the size of the output + `Tensor`. + + Returns: + A `tensor` of the shape as data, except for dimension 0 which + has size `k`, the number of segments specified via `num_segments` or + inferred for the last element in `segments_ids`. + """ + if num_segments is not None: + return gen_math_ops.sparse_segment_mean_with_num_segments( + data=data, + indices=indices, + segment_ids=segment_ids, + num_segments=num_segments, + name=name) + else: + return gen_math_ops.sparse_segment_mean( + data=data, + indices=indices, + segment_ids=segment_ids, + name=name) + + +def sparse_segment_sqrt_n(data, indices, segment_ids, name=None, + num_segments=None): + r"""Computes the sum along sparse segments of a tensor divided by the sqrt(N). + + `N` is the size of the segment being reduced. + + Args: + data: A `Tensor` with data that will be assembled in the output. + indices: A 1-D `Tensor` with indices into `data`. Has same rank as + `segment_ids`. + segment_ids: A 1-D `Tensor` with indices into the output `Tensor`. + Values should be sorted and can be repeated. + name: A name for the operation (optional). + num_segments: An optional int32 scalar. Indicates the size of the output + `Tensor`. + + Returns: + A `tensor` of the shape as data, except for dimension 0 which + has size `k`, the number of segments specified via `num_segments` or + inferred for the last element in `segments_ids`. + """ + if num_segments is not None: + return gen_math_ops.sparse_segment_sqrt_n_with_num_segments( + data=data, + indices=indices, + segment_ids=segment_ids, + num_segments=num_segments, + name=name) + else: + return gen_math_ops.sparse_segment_sqrt_n( + data=data, + indices=indices, + segment_ids=segment_ids, + name=name) + + def tensordot(a, b, axes, name=None): r"""Tensor contraction of a and b along specified axes. diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt index b12cf5a864..4b33aa218c 100644 --- a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -1842,15 +1842,15 @@ tf_module { } member_method { name: "sparse_segment_mean" - argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "sparse_segment_sqrt_n" - argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "sparse_segment_sum" - argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "sparse_slice" |