aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/math_ops.cc
diff options
context:
space:
mode:
authorGravatar Phil <ijund.phil@gmail.com>2018-02-07 19:59:59 +0100
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2018-02-07 10:59:59 -0800
commit3d86d8ce14989ca65a59ad4cf37f690694bf6267 (patch)
treeae2797cd796b292f8303bb58dae23c80489d4749 /tensorflow/core/ops/math_ops.cc
parent8aa14cd682053e1e643f0a74ec25cf3b87bf2712 (diff)
Add unsortedsegment(prod/min/max/sqrt_n/mean). (#15858)
* Add unsortedsegment(prod/min/max/sqrt_n/mean). This commit adds CPU/GPU implementations for prod/min/max ops and python implementations for mean/sqrt_n. Also, it adapts and unifies the corresponding tests of all unsorted reductions. Note: The new gradient of unsorted_segment_max fixes the crash occuring when negative indices on CPU are used. * update golden API * Fix compilation of atomicAdd for cuda_arch < 600. \n This commit moves the std::complex specialization of atomicAdd below the double specialization of atomicAdd for cuda_arch 600. * Enable bfloat16, change inline to EIGEN_STRONG_INLINE. * fix includes of cuda_device_functions; fix typo
Diffstat (limited to 'tensorflow/core/ops/math_ops.cc')
-rw-r--r--tensorflow/core/ops/math_ops.cc20
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 872ebe98c1..8f33d51d5a 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -1065,6 +1065,26 @@ REGISTER_OP("UnsortedSegmentMax")
.Attr("Tnumsegments: {int32,int64} = DT_INT32")
.SetShapeFn(UnsortedSegmentReductionShapeFn);
+REGISTER_OP("UnsortedSegmentMin")
+ .Input("data: T")
+ .Input("segment_ids: Tindices")
+ .Input("num_segments: Tnumsegments")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .Attr("Tindices: {int32,int64}")
+ .Attr("Tnumsegments: {int32,int64} = DT_INT32")
+ .SetShapeFn(UnsortedSegmentReductionShapeFn);
+
+REGISTER_OP("UnsortedSegmentProd")
+ .Input("data: T")
+ .Input("segment_ids: Tindices")
+ .Input("num_segments: Tnumsegments")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .Attr("Tindices: {int32,int64}")
+ .Attr("Tnumsegments: {int32,int64} = DT_INT32")
+ .SetShapeFn(UnsortedSegmentReductionShapeFn);
+
REGISTER_OP("SparseSegmentSum")
.Input("data: T")
.Input("indices: Tidx")