diff options
Diffstat (limited to 'tensorflow/core/ops/math_ops.cc')
-rw-r--r-- | tensorflow/core/ops/math_ops.cc | 151 |
1 files changed, 151 insertions, 0 deletions
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 45ebfa203b..8ea170ba14 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1632,6 +1632,45 @@ Status SparseSegmentReductionGradShapeFn(InferenceContext* c) { return Status::OK(); } +Status SparseSegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) { + ShapeHandle data_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape)); + + ShapeHandle indices_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape)); + + ShapeHandle segment_ids_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape)); + + ShapeHandle num_segments_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &num_segments_shape)); + + // indices and segment_ids should merge cleanly. + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused)); + + ShapeHandle subshape; + TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape)); + + ShapeHandle out; + const Tensor* dim0 = c->input_tensor(3); + if (dim0 == nullptr) { + // We don't have the value at inference time, so the output + // shape is unknown. + TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(InferenceContext::kUnknownDim), + subshape, &out)); + } else { + auto dim0_value = dim0->scalar<int32>()(); + if (dim0_value < 0) { + return errors::InvalidArgument( + "Cannot specify a negative value for num_segments"); + } + TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(dim0_value), subshape, &out)); + } + c->set_output(0, out); + return Status::OK(); +} + Status UnsortedSegmentReductionShapeFn(InferenceContext* c) { ShapeHandle s_data = c->input(0); ShapeHandle s_segment_ids = c->input(1); @@ -1890,6 +1929,7 @@ output: Has same shape as data, except for dimension 0 which has size `num_segments`. )doc"); + REGISTER_OP("SparseSegmentSum") .Input("data: T") .Input("indices: Tidx") @@ -1938,6 +1978,56 @@ output: Has same shape as data, except for dimension 0 which has size `k`, the number of segments. )doc"); +REGISTER_OP("SparseSegmentSumWithNumSegments") + .Input("data: T") + .Input("indices: Tidx") + .Input("segment_ids: int32") + .Input("num_segments: Tnumsegments") + .Output("output: T") + .Attr("T: realnumbertype") + .Attr("Tidx: {int32, int64} = DT_INT32") + .Attr("Tnumsegments: {int32,int64} = DT_INT32") + .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn) + .Doc(R"doc( +Computes the sum along sparse segments of a tensor. + +Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is +misisng, the `output` tensor at that position will be zeroed. + +Read @{$math_ops#segmentation$the section on segmentation} for an explanation of +segments. + +For example: + +```python +c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]) + +tf.sparse_segment_sum_with_num_segments( + c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3) +# => [[0 0 0 0] +# [0 0 0 0] +# [0 0 0 0]] + +tf.sparse_segment_sum_with_num_segments(c, + tf.constant([0, 1]), + tf.constant([0, 2], + num_segments=4)) +# => [[ 1 2 3 4] +# [ 0 0 0 0] +# [-1 -2 -3 -4] +# [ 0 0 0 0]] +``` + +indices: A 1-D tensor. Has same rank as `segment_ids`. + +segment_ids: A 1-D tensor. Values should be sorted and can be repeated. + +num_segments: Should equal the number of distinct segment IDs. + +output: Has same shape as data, except for dimension 0 which + has size `num_segments`. +)doc"); + REGISTER_OP("SparseSegmentMean") .Input("data: T") .Input("indices: Tidx") @@ -1964,6 +2054,35 @@ output: Has same shape as data, except for dimension 0 which )doc"); +REGISTER_OP("SparseSegmentMeanWithNumSegments") + .Input("data: T") + .Input("indices: Tidx") + .Input("segment_ids: int32") + .Input("num_segments: Tnumsegments") + .Output("output: T") + .Attr("T: {float, double}") + .Attr("Tidx: {int32, int64} = DT_INT32") + .Attr("Tnumsegments: {int32,int64} = DT_INT32") + .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn) + .Doc(R"doc( +Computes the mean along sparse segments of a tensor. + +Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is +misisng, the `output` tensor at that position will be zeroed. + +Read @{$math_ops#segmentation$the section on segmentation} for an explanation of +segments. + +indices: A 1-D tensor. Has same rank as `segment_ids`. + +segment_ids: A 1-D tensor. Values should be sorted and can be repeated. + +num_segments: Should equal the number of distinct segment IDs. + +output: Has same shape as data, except for dimension 0 which has size + `num_segments`. +)doc"); + REGISTER_OP("SparseSegmentMeanGrad") .Input("grad: T") .Input("indices: Tidx") @@ -2010,6 +2129,38 @@ output: Has same shape as data, except for dimension 0 which )doc"); +REGISTER_OP("SparseSegmentSqrtNWithNumSegments") + .Input("data: T") + .Input("indices: Tidx") + .Input("segment_ids: int32") + .Input("num_segments: Tnumsegments") + .Output("output: T") + .Attr("T: {float, double}") + .Attr("Tidx: {int32, int64} = DT_INT32") + .Attr("Tnumsegments: {int32,int64} = DT_INT32") + .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn) + .Doc(R"doc( +Computes the sum along sparse segments of a tensor divided by the sqrt of N. + +N is the size of the segment being reduced. + +Like `SparseSegmentSqrtN`, but allows missing ids in `segment_ids`. If an id is +misisng, the `output` tensor at that position will be zeroed. + +Read @{$math_ops#segmentation$the section on segmentation} for an explanation of +segments. + +indices: A 1-D tensor. Has same rank as `segment_ids`. + +segment_ids: A 1-D tensor. Values should be sorted and can be repeated. + +num_segments: Should equal the number of distinct segment IDs. + +output: Has same shape as data, except for dimension 0 which + has size `k`, the number of segments. + +)doc"); + REGISTER_OP("SparseSegmentSqrtNGrad") .Input("grad: T") .Input("indices: Tidx") |