Diffstat (limited to 'tensorflow/core/kernels/segment_reduction_ops.cc')
-rw-r--r--  tensorflow/core/kernels/segment_reduction_ops.cc  113
1 file changed, 93 insertions(+), 20 deletions(-)
diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc
index fee16cdb78..5bd4362801 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops.cc
@@ -220,13 +220,15 @@ REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
namespace functor {
// UnsortedSegmentSumFunctor implementation for CPUDevice.
+// TODO: Remove duplicate code in UnsortedSegmentSumFunctor and UnsortedSegmentMaxFunctor.
template <typename T, typename Index>
-struct UnsortedSegmentSumFunctor<CPUDevice, T, Index> {
+struct UnsortedSegmentSumFunctor<CPUDevice, T, Index>
+ : UnsortedSegmentBaseFunctor<CPUDevice, T, Index> {
void operator()(OpKernelContext* ctx, const CPUDevice& d,
const Index output_rows, const TensorShape& segment_ids_shape,
typename TTypes<Index>::ConstFlat segment_ids,
const Index data_size, const T* data,
- typename TTypes<T, 2>::Tensor output) {
+ typename TTypes<T, 2>::Tensor output) override {
output.setZero();
if (data_size == 0) {
return;
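
The UnsortedSegmentBaseFunctor that both CPU functors now derive from is declared in segment_reduction_ops.h, which is not part of this diff. Judging from the signature being overridden above, its declaration is presumably along these lines (a sketch, not the actual header):

template <typename Device, typename T, typename Index>
struct UnsortedSegmentBaseFunctor {
  virtual ~UnsortedSegmentBaseFunctor() {}
  // Virtual so UnsortedSegmentBaseOp (below) can invoke any concrete
  // reduction through a base-class reference.
  virtual void operator()(OpKernelContext* ctx, const Device& d,
                          const Index output_rows,
                          const TensorShape& segment_ids_shape,
                          typename TTypes<Index>::ConstFlat segment_ids,
                          const Index data_size, const T* data,
                          typename TTypes<T, 2>::Tensor output) {}
};
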
@@ -243,16 +245,44 @@ struct UnsortedSegmentSumFunctor<CPUDevice, T, Index> {
}
}
};
-
+// UnsortedSegmentMaxFunctor implementation for CPUDevice.
+template <typename T, typename Index>
+struct UnsortedSegmentMaxFunctor<CPUDevice, T, Index>
+ : UnsortedSegmentBaseFunctor<CPUDevice, T, Index> {
+ void operator()(OpKernelContext* ctx, const CPUDevice& d,
+ const Index output_rows, const TensorShape& segment_ids_shape,
+ typename TTypes<Index>::ConstFlat segment_ids,
+ const Index data_size, const T* data,
+ typename TTypes<T, 2>::Tensor output) override {
+ output.setConstant(std::numeric_limits<T>::lowest());
+ if (data_size == 0) {
+ return;
+ }
+ const int64 N = segment_ids.dimension(0);
+ auto data_flat = typename TTypes<T, 2>::ConstTensor(data, N, data_size / N);
+ for (int64 i = 0; i < N; ++i) {
+ Index j = internal::SubtleMustCopy(segment_ids(i));
+ OP_REQUIRES(ctx, FastBoundsCheck(j, output_rows),
+ errors::InvalidArgument(
+ "segment_ids", SliceDebugString(segment_ids_shape, i),
+ " = ", j, " is out of range [0, ", output_rows, ")"));
+ output.template chip<0>(j) =
+ data_flat.template chip<0>(i).cwiseMax(output.template chip<0>(j));
+ }
+ }
+};
} // namespace functor
-// Similar to SegmentReductionOp but can handle unsorted segment definitions and
-// specifying size of output.
+// Base class for SegmentReductionOps that can handle unsorted segment
+// definitions and an explicitly specified output size; the reduction itself
+// is supplied as a functor.
template <typename Device, class T, class Index>
-class UnsortedSegmentSumOp : public OpKernel {
+class UnsortedSegmentBaseOp : public OpKernel {
public:
- explicit UnsortedSegmentSumOp(OpKernelConstruction* context)
- : OpKernel(context) {}
+ explicit UnsortedSegmentBaseOp(
+ OpKernelConstruction* context,
+ functor::UnsortedSegmentBaseFunctor<Device, T, Index>& functor)
+ : OpKernel(context), reduction_functor_(functor) {}
void Compute(OpKernelContext* context) override {
const Tensor& data = context->input(0);
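
For reference, the semantics of the UnsortedSegmentMaxFunctor added above can be sketched outside of TensorFlow as a plain C++ loop (a hypothetical standalone helper, not part of this change):

#include <algorithm>
#include <cstddef>
#include <limits>
#include <vector>

// Row i of `data` is folded into row segment_ids[i] of the output with an
// element-wise max; segments that receive no rows keep the initial value.
// lowest() is the identity element for max -- for floating-point types,
// min() would be the smallest *positive* value, not the most negative.
std::vector<std::vector<float>> SegmentMaxSketch(
    const std::vector<std::vector<float>>& data,
    const std::vector<int>& segment_ids, int num_segments) {
  const std::size_t cols = data.empty() ? 0 : data[0].size();
  std::vector<std::vector<float>> out(
      num_segments,
      std::vector<float>(cols, std::numeric_limits<float>::lowest()));
  for (std::size_t i = 0; i < data.size(); ++i) {
    const int j = segment_ids[i];  // the kernel bounds-checks this index
    for (std::size_t c = 0; c < cols; ++c) {
      out[j][c] = std::max(out[j][c], data[i][c]);
    }
  }
  return out;
}
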
@@ -288,27 +318,70 @@ class UnsortedSegmentSumOp : public OpKernel {
auto output_flat = output->flat_outer_dims<T>();
auto data_ptr = data.template flat<T>().data();
- functor::UnsortedSegmentSumFunctor<Device, T, Index>()(
- context, context->template eigen_device<Device>(), output_rows,
- segment_ids.shape(), segment_flat, data.NumElements(), data_ptr,
- output_flat);
+ reduction_functor_(context, context->template eigen_device<Device>(),
+ output_rows, segment_ids.shape(), segment_flat,
+ data.NumElements(), data_ptr, output_flat);
}
+ private:
+ functor::UnsortedSegmentBaseFunctor<Device, T, Index>& reduction_functor_;
};
-#define REGISTER_CPU_UNSORTED_KERNELS(type, index_type) \
+template <typename Device, class T, class Index>
+class UnsortedSegmentSumOp : public UnsortedSegmentBaseOp<Device, T, Index> {
+ public:
+ explicit UnsortedSegmentSumOp(OpKernelConstruction* context)
+ : UnsortedSegmentBaseOp<Device, T, Index>(context, sum_functor_) {}
+ private:
+ functor::UnsortedSegmentSumFunctor<Device, T, Index> sum_functor_;
+};
+
+template <typename Device, class T, class Index>
+class UnsortedSegmentMaxOp : public UnsortedSegmentBaseOp<Device, T, Index> {
+ public:
+ explicit UnsortedSegmentMaxOp(OpKernelConstruction* context)
+ : UnsortedSegmentBaseOp<Device, T, Index>(context, max_functor_) {}
+ private:
+ functor::UnsortedSegmentMaxFunctor<Device, T, Index> max_functor_;
+};
+
+#define REGISTER_REAL_CPU_UNSORTED_KERNELS(type, index_type) \
+ REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentSum") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ UnsortedSegmentSumOp<CPUDevice, type, index_type>); \
+ REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentMax") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ UnsortedSegmentMaxOp<CPUDevice, type, index_type>);
+
+#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, index_type) \
REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentSum") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<index_type>("Tindices"), \
UnsortedSegmentSumOp<CPUDevice, type, index_type>);
-#define REGISTER_CPU_UNSORTED_KERNELS_ALL(type) \
- REGISTER_CPU_UNSORTED_KERNELS(type, int32); \
- REGISTER_CPU_UNSORTED_KERNELS(type, int64);
-
-TF_CALL_NUMBER_TYPES(REGISTER_CPU_UNSORTED_KERNELS_ALL);
-#undef REGISTER_CPU_UNSORTED_KERNELS
-#undef REGISTER_CPU_UNSORTED_KERNELS_ALL
+#define REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL(type) \
+ REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int32); \
+ REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int64)
+
+#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(type) \
+ REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int32); \
+ REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int64)
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL);
+REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex64);
+REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128);
+#undef REGISTER_REAL_CPU_UNSORTED_KERNELS
+#undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS
+#undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL
+#undef REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL
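
Spelled out, REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL(float) above produces four registrations (Sum and Max, each for int32 and int64 indices); the Max/int32 one expands to:

REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentMax")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<int32>("Tindices"),
                        UnsortedSegmentMaxOp<CPUDevice, float, int32>);
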
#if GOOGLE_CUDA
#define REGISTER_GPU_UNSORTED_KERNELS(type, index_type) \