aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/math_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/math_ops.cc')
-rw-r--r--tensorflow/core/ops/math_ops.cc98
1 files changed, 68 insertions, 30 deletions
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 92e2823fb2..00876bc18c 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -1342,6 +1342,36 @@ Status SparseSegmentReductionGradShapeFn(InferenceContext* c) {
return Status::OK();
}
+Status UnsortedSegmentReductionShapeFn(InferenceContext* c) {
+ ShapeHandle s_data = c->input(0);
+ ShapeHandle s_segment_ids = c->input(1);
+ ShapeHandle s_num_segments = c->input(2);
+ TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments));
+
+ ShapeHandle out;
+
+ // Leading dimensions of data must be compatible with dimensions of
+ // <s_segment_ids>.
+ if (c->RankKnown(s_segment_ids)) {
+ TF_RETURN_IF_ERROR(
+ c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids));
+
+ // Get the value of the num_segments input tensor.
+ DimensionHandle num_segments_dim;
+ TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim));
+
+ // Output is {segment_id_rank} + s_data[segment_id_rank:].
+ ShapeHandle s_data_suffix;
+ TF_RETURN_IF_ERROR(
+ c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix));
+ TF_RETURN_IF_ERROR(
+ c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out));
+ } else {
+ out = c->UnknownShape();
+ }
+ c->set_output(0, out);
+ return Status::OK();
+}
} // namespace
REGISTER_OP("SegmentSum")
@@ -1495,36 +1525,7 @@ REGISTER_OP("UnsortedSegmentSum")
.Output("output: T")
.Attr("T: numbertype")
.Attr("Tindices: {int32,int64}")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle s_data = c->input(0);
- ShapeHandle s_segment_ids = c->input(1);
- ShapeHandle s_num_segments = c->input(2);
- TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments));
-
- ShapeHandle out;
-
- // Leading dimensions of data must be compatible with dimensions of
- // <s_segment_ids>.
- if (c->RankKnown(s_segment_ids)) {
- TF_RETURN_IF_ERROR(
- c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids));
-
- // Get the value of the num_segments input tensor.
- DimensionHandle num_segments_dim;
- TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim));
-
- // Output is {segment_id_rank} + s_data[segment_id_rank:].
- ShapeHandle s_data_suffix;
- TF_RETURN_IF_ERROR(
- c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix));
- TF_RETURN_IF_ERROR(
- c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out));
- } else {
- out = c->UnknownShape();
- }
- c->set_output(0, out);
- return Status::OK();
- })
+ .SetShapeFn(UnsortedSegmentReductionShapeFn)
.Doc(R"doc(
Computes the sum along segments of a tensor.
@@ -1554,6 +1555,43 @@ output: Has same shape as data, except for the first `segment_ids.rank`
)doc");
+
+REGISTER_OP("UnsortedSegmentMax")
+ .Input("data: T")
+ .Input("segment_ids: Tindices")
+ .Input("num_segments: int32")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .Attr("Tindices: {int32,int64}")
+ .SetShapeFn(UnsortedSegmentReductionShapeFn)
+ .Doc(R"doc(
+Computes the Max along segments of a tensor.
+
+Read [the section on
+Segmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation
+of segments.
+
+This operator is similar to the [unsorted segment sum operator](../../api_docs/python/math_ops.md#UnsortedSegmentSum).
+Instead of computing the sum over segments, it computes the maximum
+such that:
+
+\\(output_i = \max_j data_j\\) where max is over `j` such
+that `segment_ids[j] == i`.
+
+If the maximum is empty for a given segment ID `i`, it outputs the smallest possible value for specific numeric type,
+ `output[i] = numeric_limits<T>::min()`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../../images/UnsortedSegmentSum.png" alt>
+</div>
+
+segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
+first dimension.
+
+output: Has same shape as data, except for dimension 0 which
+has size `num_segments`.
+
+)doc");
REGISTER_OP("SparseSegmentSum")
.Input("data: T")
.Input("indices: Tidx")