aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/math_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/math_ops.cc')
-rw-r--r--tensorflow/core/ops/math_ops.cc151
1 files changed, 151 insertions, 0 deletions
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 45ebfa203b..8ea170ba14 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -1632,6 +1632,45 @@ Status SparseSegmentReductionGradShapeFn(InferenceContext* c) {
return Status::OK();
}
+Status SparseSegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) {
+ ShapeHandle data_shape;
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
+
+ ShapeHandle indices_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));
+
+ ShapeHandle segment_ids_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape));
+
+ ShapeHandle num_segments_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &num_segments_shape));
+
+ // indices and segment_ids should merge cleanly.
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused));
+
+ ShapeHandle subshape;
+ TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
+
+ ShapeHandle out;
+ const Tensor* dim0 = c->input_tensor(3);
+ if (dim0 == nullptr) {
+ // We don't have the value at inference time, so the output
+ // shape is unknown.
+ TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(InferenceContext::kUnknownDim),
+ subshape, &out));
+ } else {
+ auto dim0_value = dim0->scalar<int32>()();
+ if (dim0_value < 0) {
+ return errors::InvalidArgument(
+ "Cannot specify a negative value for num_segments");
+ }
+ TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(dim0_value), subshape, &out));
+ }
+ c->set_output(0, out);
+ return Status::OK();
+}
+
Status UnsortedSegmentReductionShapeFn(InferenceContext* c) {
ShapeHandle s_data = c->input(0);
ShapeHandle s_segment_ids = c->input(1);
@@ -1890,6 +1929,7 @@ output: Has same shape as data, except for dimension 0 which
has size `num_segments`.
)doc");
+
REGISTER_OP("SparseSegmentSum")
.Input("data: T")
.Input("indices: Tidx")
@@ -1938,6 +1978,56 @@ output: Has same shape as data, except for dimension 0 which
has size `k`, the number of segments.
)doc");
+REGISTER_OP("SparseSegmentSumWithNumSegments")
+ .Input("data: T")
+ .Input("indices: Tidx")
+ .Input("segment_ids: int32")
+ .Input("num_segments: Tnumsegments")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
+ .Attr("Tnumsegments: {int32,int64} = DT_INT32")
+ .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn)
+ .Doc(R"doc(
+Computes the sum along sparse segments of a tensor.
+
+Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is
+misisng, the `output` tensor at that position will be zeroed.
+
+Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
+segments.
+
+For example:
+
+```python
+c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
+
+tf.sparse_segment_sum_with_num_segments(
+ c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3)
+# => [[0 0 0 0]
+# [0 0 0 0]
+# [0 0 0 0]]
+
+tf.sparse_segment_sum_with_num_segments(c,
+ tf.constant([0, 1]),
+ tf.constant([0, 2],
+ num_segments=4))
+# => [[ 1 2 3 4]
+# [ 0 0 0 0]
+# [-1 -2 -3 -4]
+# [ 0 0 0 0]]
+```
+
+indices: A 1-D tensor. Has same rank as `segment_ids`.
+
+segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+
+num_segments: Should equal the number of distinct segment IDs.
+
+output: Has same shape as data, except for dimension 0 which
+ has size `num_segments`.
+)doc");
+
REGISTER_OP("SparseSegmentMean")
.Input("data: T")
.Input("indices: Tidx")
@@ -1964,6 +2054,35 @@ output: Has same shape as data, except for dimension 0 which
)doc");
+REGISTER_OP("SparseSegmentMeanWithNumSegments")
+ .Input("data: T")
+ .Input("indices: Tidx")
+ .Input("segment_ids: int32")
+ .Input("num_segments: Tnumsegments")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
+ .Attr("Tnumsegments: {int32,int64} = DT_INT32")
+ .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn)
+ .Doc(R"doc(
+Computes the mean along sparse segments of a tensor.
+
+Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is
+misisng, the `output` tensor at that position will be zeroed.
+
+Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
+segments.
+
+indices: A 1-D tensor. Has same rank as `segment_ids`.
+
+segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+
+num_segments: Should equal the number of distinct segment IDs.
+
+output: Has same shape as data, except for dimension 0 which has size
+ `num_segments`.
+)doc");
+
REGISTER_OP("SparseSegmentMeanGrad")
.Input("grad: T")
.Input("indices: Tidx")
@@ -2010,6 +2129,38 @@ output: Has same shape as data, except for dimension 0 which
)doc");
+REGISTER_OP("SparseSegmentSqrtNWithNumSegments")
+ .Input("data: T")
+ .Input("indices: Tidx")
+ .Input("segment_ids: int32")
+ .Input("num_segments: Tnumsegments")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
+ .Attr("Tnumsegments: {int32,int64} = DT_INT32")
+ .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn)
+ .Doc(R"doc(
+Computes the sum along sparse segments of a tensor divided by the sqrt of N.
+
+N is the size of the segment being reduced.
+
+Like `SparseSegmentSqrtN`, but allows missing ids in `segment_ids`. If an id is
+misisng, the `output` tensor at that position will be zeroed.
+
+Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
+segments.
+
+indices: A 1-D tensor. Has same rank as `segment_ids`.
+
+segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+
+num_segments: Should equal the number of distinct segment IDs.
+
+output: Has same shape as data, except for dimension 0 which
+ has size `k`, the number of segments.
+
+)doc");
+
REGISTER_OP("SparseSegmentSqrtNGrad")
.Input("grad: T")
.Input("indices: Tidx")