diff options
author | Phil <ijund.phil@gmail.com> | 2018-02-07 19:59:59 +0100 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2018-02-07 10:59:59 -0800 |
commit | 3d86d8ce14989ca65a59ad4cf37f690694bf6267 (patch) | |
tree | ae2797cd796b292f8303bb58dae23c80489d4749 /tensorflow/core/ops/math_ops.cc | |
parent | 8aa14cd682053e1e643f0a74ec25cf3b87bf2712 (diff) |
Add unsortedsegment(prod/min/max/sqrt_n/mean). (#15858)
* Add unsortedsegment(prod/min/max/sqrt_n/mean).
This commit adds CPU/GPU implementations for prod/min/max
ops and python implementations for mean/sqrt_n. Also, it adapts and unifies the
corresponding tests of all unsorted reductions.
Note: The new gradient of unsorted_segment_max fixes the crash occuring when
negative indices on CPU are used.
* update golden API
* Fix compilation of atomicAdd for cuda_arch < 600. \n This commit moves the std::complex specialization of atomicAdd below the double specialization of atomicAdd for cuda_arch 600.
* Enable bfloat16, change inline to EIGEN_STRONG_INLINE.
* fix includes of cuda_device_functions; fix typo
Diffstat (limited to 'tensorflow/core/ops/math_ops.cc')
-rw-r--r-- | tensorflow/core/ops/math_ops.cc | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 872ebe98c1..8f33d51d5a 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1065,6 +1065,26 @@ REGISTER_OP("UnsortedSegmentMax") .Attr("Tnumsegments: {int32,int64} = DT_INT32") .SetShapeFn(UnsortedSegmentReductionShapeFn); +REGISTER_OP("UnsortedSegmentMin") + .Input("data: T") + .Input("segment_ids: Tindices") + .Input("num_segments: Tnumsegments") + .Output("output: T") + .Attr("T: realnumbertype") + .Attr("Tindices: {int32,int64}") + .Attr("Tnumsegments: {int32,int64} = DT_INT32") + .SetShapeFn(UnsortedSegmentReductionShapeFn); + +REGISTER_OP("UnsortedSegmentProd") + .Input("data: T") + .Input("segment_ids: Tindices") + .Input("num_segments: Tnumsegments") + .Output("output: T") + .Attr("T: realnumbertype") + .Attr("Tindices: {int32,int64}") + .Attr("Tnumsegments: {int32,int64} = DT_INT32") + .SetShapeFn(UnsortedSegmentReductionShapeFn); + REGISTER_OP("SparseSegmentSum") .Input("data: T") .Input("indices: Tidx") |