diff options
Diffstat (limited to 'tensorflow/core/ops/math_ops.cc')
-rw-r--r-- | tensorflow/core/ops/math_ops.cc | 98 |
1 files changed, 68 insertions, 30 deletions
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 92e2823fb2..00876bc18c 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1342,6 +1342,36 @@ Status SparseSegmentReductionGradShapeFn(InferenceContext* c) { return Status::OK(); } +Status UnsortedSegmentReductionShapeFn(InferenceContext* c) { + ShapeHandle s_data = c->input(0); + ShapeHandle s_segment_ids = c->input(1); + ShapeHandle s_num_segments = c->input(2); + TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments)); + + ShapeHandle out; + + // Leading dimensions of data must be compatible with dimensions of + // <s_segment_ids>. + if (c->RankKnown(s_segment_ids)) { + TF_RETURN_IF_ERROR( + c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids)); + + // Get the value of the num_segments input tensor. + DimensionHandle num_segments_dim; + TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim)); + + // Output is {segment_id_rank} + s_data[segment_id_rank:]. + ShapeHandle s_data_suffix; + TF_RETURN_IF_ERROR( + c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix)); + TF_RETURN_IF_ERROR( + c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out)); + } else { + out = c->UnknownShape(); + } + c->set_output(0, out); + return Status::OK(); +} } // namespace REGISTER_OP("SegmentSum") @@ -1495,36 +1525,7 @@ REGISTER_OP("UnsortedSegmentSum") .Output("output: T") .Attr("T: numbertype") .Attr("Tindices: {int32,int64}") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle s_data = c->input(0); - ShapeHandle s_segment_ids = c->input(1); - ShapeHandle s_num_segments = c->input(2); - TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments)); - - ShapeHandle out; - - // Leading dimensions of data must be compatible with dimensions of - // <s_segment_ids>. - if (c->RankKnown(s_segment_ids)) { - TF_RETURN_IF_ERROR( - c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids)); - - // Get the value of the num_segments input tensor. - DimensionHandle num_segments_dim; - TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim)); - - // Output is {segment_id_rank} + s_data[segment_id_rank:]. - ShapeHandle s_data_suffix; - TF_RETURN_IF_ERROR( - c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix)); - TF_RETURN_IF_ERROR( - c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out)); - } else { - out = c->UnknownShape(); - } - c->set_output(0, out); - return Status::OK(); - }) + .SetShapeFn(UnsortedSegmentReductionShapeFn) .Doc(R"doc( Computes the sum along segments of a tensor. @@ -1554,6 +1555,43 @@ output: Has same shape as data, except for the first `segment_ids.rank` )doc"); + +REGISTER_OP("UnsortedSegmentMax") + .Input("data: T") + .Input("segment_ids: Tindices") + .Input("num_segments: int32") + .Output("output: T") + .Attr("T: realnumbertype") + .Attr("Tindices: {int32,int64}") + .SetShapeFn(UnsortedSegmentReductionShapeFn) + .Doc(R"doc( +Computes the Max along segments of a tensor. + +Read [the section on +Segmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation +of segments. + +This operator is similar to the [unsorted segment sum operator](../../api_docs/python/math_ops.md#UnsortedSegmentSum). +Instead of computing the sum over segments, it computes the maximum +such that: + +\\(output_i = \max_j data_j\\) where max is over `j` such +that `segment_ids[j] == i`. + +If the maximum is empty for a given segment ID `i`, it outputs the smallest possible value for specific numeric type, + `output[i] = numeric_limits<T>::min()`. + +<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src="../../images/UnsortedSegmentSum.png" alt> +</div> + +segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s +first dimension. + +output: Has same shape as data, except for dimension 0 which +has size `num_segments`. + +)doc"); REGISTER_OP("SparseSegmentSum") .Input("data: T") .Input("indices: Tidx") |