diff options
Diffstat (limited to 'tensorflow/core/kernels/segment_reduction_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/segment_reduction_ops.cc | 113 |
1 files changed, 93 insertions, 20 deletions
diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc index fee16cdb78..5bd4362801 100644 --- a/tensorflow/core/kernels/segment_reduction_ops.cc +++ b/tensorflow/core/kernels/segment_reduction_ops.cc @@ -220,13 +220,15 @@ REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128); namespace functor { // UnsortedSegmentSumFunctor implementation for CPUDevice. +// todo: Remove duplicate code in UnsortedSegmentSumFunctor and UnsortedSegmentMaxFunctor. template <typename T, typename Index> -struct UnsortedSegmentSumFunctor<CPUDevice, T, Index> { +struct UnsortedSegmentSumFunctor<CPUDevice, T, Index> + : UnsortedSegmentBaseFunctor<CPUDevice, T, Index> { void operator()(OpKernelContext* ctx, const CPUDevice& d, const Index output_rows, const TensorShape& segment_ids_shape, typename TTypes<Index>::ConstFlat segment_ids, const Index data_size, const T* data, - typename TTypes<T, 2>::Tensor output) { + typename TTypes<T, 2>::Tensor output) override { output.setZero(); if (data_size == 0) { return; @@ -243,16 +245,44 @@ struct UnsortedSegmentSumFunctor<CPUDevice, T, Index> { } } }; - +// UnsortedSegmentMaxFunctor implementation for CPUDevice. 
+template <typename T, typename Index>
+struct UnsortedSegmentMaxFunctor<CPUDevice, T, Index>
+    : UnsortedSegmentBaseFunctor<CPUDevice, T, Index> {
+  void operator()(OpKernelContext* ctx, const CPUDevice& d,
+                  const Index output_rows, const TensorShape& segment_ids_shape,
+                  typename TTypes<Index>::ConstFlat segment_ids,
+                  const Index data_size, const T* data,
+                  typename TTypes<T, 2>::Tensor output) override {
+    output.setConstant(std::numeric_limits<T>::lowest());  // lowest(), not min(): min() is the smallest *positive* value for floating-point T
+    if (data_size == 0) {
+      return;
+    }
+    const int64 N = segment_ids.dimension(0);
+    auto data_flat = typename TTypes<T, 2>::ConstTensor(data, N, data_size / N);
+    for (int64 i = 0; i < N; ++i) {
+      Index j = internal::SubtleMustCopy(segment_ids(i));
+      OP_REQUIRES(ctx, FastBoundsCheck(j, output_rows),
+                  errors::InvalidArgument(
+                      "segment_ids", SliceDebugString(segment_ids_shape, i),
+                      " = ", j, " is out of range [0, ", output_rows, ")"));
+      output.template chip<0>(j) =
+          data_flat.template chip<0>(i).cwiseMax(output.template chip<0>(j));
+    }
+  }
+};
 }  // namespace functor
 
-// Similar to SegmentReductionOp but can handle unsorted segment definitions and
-// specifying size of output.
+// Base class for SegmentReductionOps that can handle unsorted segment +// definitions and specifying the size of the output; the reduction itself is +// delegated to the functor passed to the constructor. template <typename Device, class T, class Index> -class UnsortedSegmentSumOp : public OpKernel { +class UnsortedSegmentBaseOp : public OpKernel { public: - explicit UnsortedSegmentSumOp(OpKernelConstruction* context) - : OpKernel(context) {} + explicit UnsortedSegmentBaseOp( + OpKernelConstruction* context, + functor::UnsortedSegmentBaseFunctor<Device, T, Index>& functor) + : OpKernel(context), reduction_functor_(functor) {} void Compute(OpKernelContext* context) override { const Tensor& data = context->input(0); @@ -288,27 +318,70 @@ class UnsortedSegmentSumOp : public OpKernel { auto output_flat = output->flat_outer_dims<T>(); auto data_ptr = data.template flat<T>().data(); - functor::UnsortedSegmentSumFunctor<Device, T, Index>()( - context, context->template eigen_device<Device>(), output_rows, - segment_ids.shape(), segment_flat, data.NumElements(), data_ptr, - output_flat); + reduction_functor_(context, context->template eigen_device<Device>(), + output_rows, segment_ids.shape(), segment_flat, + data.NumElements(), data_ptr, output_flat); } + private: + functor::UnsortedSegmentBaseFunctor<Device, T, Index>& reduction_functor_; }; -#define REGISTER_CPU_UNSORTED_KERNELS(type, index_type) \ +template <typename Device, class T, class Index> +class UnsortedSegmentSumOp : public UnsortedSegmentBaseOp<Device, T, Index> { + public: + explicit UnsortedSegmentSumOp(OpKernelConstruction* context) + : UnsortedSegmentBaseOp<Device, T, Index>( + context, + sum_functor_) {} + private: + functor::UnsortedSegmentSumFunctor<Device, T, Index> sum_functor_; +}; + +template <typename Device, class T, class Index> +class UnsortedSegmentMaxOp : public UnsortedSegmentBaseOp<Device, T, Index> { + public: + explicit UnsortedSegmentMaxOp(OpKernelConstruction* context) + : UnsortedSegmentBaseOp<Device, T, Index>( + context, 
+ max_functor_) {} + private: + functor::UnsortedSegmentMaxFunctor<Device, T, Index> max_functor_; +}; + +#define REGISTER_REAL_CPU_UNSORTED_KERNELS(type, index_type) \ + REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentSum") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<index_type>("Tindices"), \ + UnsortedSegmentSumOp<CPUDevice, type, index_type>); \ + REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentMax") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<index_type>("Tindices"), \ + UnsortedSegmentMaxOp<CPUDevice, type, index_type>); + +#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, index_type) \ REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentSum") \ .Device(DEVICE_CPU) \ .TypeConstraint<type>("T") \ .TypeConstraint<index_type>("Tindices"), \ UnsortedSegmentSumOp<CPUDevice, type, index_type>); -#define REGISTER_CPU_UNSORTED_KERNELS_ALL(type) \ - REGISTER_CPU_UNSORTED_KERNELS(type, int32); \ - REGISTER_CPU_UNSORTED_KERNELS(type, int64); - -TF_CALL_NUMBER_TYPES(REGISTER_CPU_UNSORTED_KERNELS_ALL); -#undef REGISTER_CPU_UNSORTED_KERNELS -#undef REGISTER_CPU_UNSORTED_KERNELS_ALL +#define REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL(type) \ + REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int32); \ + REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int64) + +#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(type) \ + REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int32); \ + REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int64) + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL); +REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex64); +REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128); +#undef REGISTER_REAL_CPU_UNSORTED_KERNELS +#undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS +#undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL +#undef REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL #if GOOGLE_CUDA #define REGISTER_GPU_UNSORTED_KERNELS(type, index_type) \ |