| author | Phil <ijund.phil@gmail.com> | 2018-02-07 19:59:59 +0100 |
|---|---|---|
| committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2018-02-07 10:59:59 -0800 |
| commit | 3d86d8ce14989ca65a59ad4cf37f690694bf6267 (patch) | |
| tree | ae2797cd796b292f8303bb58dae23c80489d4749 /tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc | |
| parent | 8aa14cd682053e1e643f0a74ec25cf3b87bf2712 (diff) | |
Add unsortedsegment(prod/min/max/sqrt_n/mean). (#15858)
* Add unsortedsegment(prod/min/max/sqrt_n/mean).
This commit adds CPU/GPU implementations for the prod/min/max ops and Python
implementations for mean/sqrt_n. It also adapts and unifies the corresponding
tests of all unsorted reductions (a sketch of the reduction-functor pattern
these ops share follows this list).
Note: The new gradient of unsorted_segment_max fixes the crash occurring when
negative indices are used on CPU.
* update golden API
* Fix compilation of atomicAdd for cuda_arch < 600.
This commit moves the std::complex specialization of atomicAdd below the double specialization of atomicAdd for cuda_arch < 600 (see the CAS-loop sketch after this list).
* Enable bfloat16, change inline to EIGEN_STRONG_INLINE.
* fix includes of cuda_device_functions; fix typo
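The prod/min/max GPU paths in the first bullet are driven by pairs of initial-value and reduction functors, which the diff below instantiates as `functor::Lowest`/`functor::MaxOpGpu`, `functor::Highest`/`functor::MinOpGpu`, `functor::One`/`functor::ProdOpGpu`, and `functor::Zero`/`functor::SumOpGpu`. Their definitions live in `segment_reduction_ops.h`, which this page does not show; the following is only a sketch of what such definitions plausibly look like, inferred from how the diff uses them, not the actual TensorFlow code:

```cuda
// Hypothetical sketch, NOT the actual TensorFlow definitions: the real
// functors live in tensorflow/core/kernels/segment_reduction_ops.h, which
// this diff does not show. Shapes are inferred from kernel usage below.
#include <limits>

#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace functor {

// Initial-value functors: operator()() returns the reduction identity
// that SetToValue fills the output buffer with before the kernel runs.
template <typename T>
struct Zero {
  T operator()() const { return T(0); }
};

template <typename T>
struct Lowest {  // identity element for max
  T operator()() const { return std::numeric_limits<T>::lowest(); }
};

// Reduction functors: operator()(dest, value) atomically folds `value`
// into `*dest`; UnsortedSegmentCustomKernel applies one per element.
template <typename T>
struct SumOpGpu {
  __device__ void operator()(T* dest, const T& value) const {
    CudaAtomicAdd(dest, value);
  }
};

template <typename T>
struct MaxOpGpu {
  __device__ void operator()(T* dest, const T& value) const {
    CudaAtomicMax(dest, value);
  }
};

}  // namespace functor
```

Separating the identity from the combine step is what lets a single kernel template serve sum, prod, min, and max: the output is first filled with `InitialValueF()()`, and the kernel then folds every input element in through `ReductionF`.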
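The ordering fix in the third bullet matters because devices below compute capability 6.0 have no native `atomicAdd(double*, double)`, so a compare-and-swap emulation is declared for them; the `std::complex` specialization must appear after that declaration to compile against it. Below is a minimal sketch of the standard emulation pattern from the CUDA programming guide, not the exact TensorFlow helper:

```cuda
// Minimal sketch of the standard software emulation of
// atomicAdd(double*, double) for compute capability < 6.0; sm_60+ has a
// native instruction. The function name here is illustrative.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
__device__ double AtomicAddDouble(double* address, double val) {
  unsigned long long int* address_as_ull =
      reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int old = *address_as_ull;
  unsigned long long int assumed;
  do {
    assumed = old;
    // Add in double precision, then try to publish the result with a
    // 64-bit compare-and-swap; retry if another thread got there first.
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);
  return __longlong_as_double(old);
}
#endif
```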
Diffstat (limited to 'tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc')
| -rw-r--r-- | tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc | 138 |

1 file changed, 65 insertions, 73 deletions
```diff
diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
index 39d520698e..ba979e6bb2 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
@@ -18,42 +18,15 @@ limitations under the License.
 
 #define EIGEN_USE_GPU
 
 #include "tensorflow/core/kernels/segment_reduction_ops.h"
-
 #include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/util/cuda_device_functions.h"
 #include "tensorflow/core/util/cuda_kernel_helper.h"
+
 namespace tensorflow {
 
 using GPUDevice = Eigen::GpuDevice;
 
-// Helper for UnusortedSegmentSumCustomKernel that adds value into dest
-// atomically.
-template <typename T>
-static __device__ __forceinline__ void AccumulateInto(T* dest,
-                                                      const T& value) {
-  CudaAtomicAdd(dest, value);
-}
-
-// Specializations of AccumulateInto for complex types, which CudaAtomicAdd
-// does not support. We treat a std::complex<T>* as a T* (the C++ standard
-// section 26.4.4 allows this explicitly) and atomic add the real and
-// imaginary components individually. The operation as a whole is not atomic,
-// but we can safely treat the components independently for the purpose of
-// accumulating.
-template <>
-__device__ __forceinline__ void AccumulateInto(
    std::complex<float>* dest, const std::complex<float>& value) {
-  auto dest_scalar = reinterpret_cast<float*>(dest);
-  CudaAtomicAdd(dest_scalar, value.real());
-  CudaAtomicAdd(dest_scalar + 1, value.imag());
-}
-
-template <>
-__device__ __forceinline__ void AccumulateInto(
    std::complex<double>* dest, const std::complex<double>& value) {
-  auto dest_scalar = reinterpret_cast<double*>(dest);
-  CudaAtomicAdd(dest_scalar, value.real());
-  CudaAtomicAdd(dest_scalar + 1, value.imag());
-}
-
 // SortedSegmentSumFunctor kernel reduces input data just as
 // UnsortedSegmentSumCustomKernel does except that input data
 // is partitioned along the outer reduction dimension. This is
@@ -81,7 +54,7 @@ __global__ void SortedSegmentSumCustomKernel(const Index input_outer_dim_size,
                                              const Index* segment_ids,
                                              const T* input, T* output,
                                              const Index total_stripe_count) {
-  CUDA_1D_KERNEL_LOOP(stripe_index, total_stripe_count) {
+  for (int stripe_index : CudaGridRangeX(total_stripe_count)) {
     const Index segment_offset = stripe_index % inner_dim_size;
     const Index input_outer_dim_index_base =
         stripe_index / inner_dim_size * Index(OuterDimTileSize);
@@ -106,7 +79,7 @@ __global__ void SortedSegmentSumCustomKernel(const Index input_outer_dim_size,
         // decide whether to write result to global memory using atomic
         // operations
         if (last_output_segment_id == first_segment_id) {
-          AccumulateInto<T>(output + output_index, sum);
+          CudaAtomicAdd(output + output_index, sum);
         } else {
           *(output + output_index) = sum;
         }
@@ -121,31 +94,31 @@ __global__ void SortedSegmentSumCustomKernel(const Index input_outer_dim_size,
     // the following strip.
     const Index output_index =
         last_output_segment_id * inner_dim_size + segment_offset;
-    AccumulateInto<T>(output + output_index, sum);
+    CudaAtomicAdd(output + output_index, sum);
   }
 }
 
-// UnsortedSegmentSumFunctor kernel processes 'input_total_size' elements.
+// UnsortedSegmentSumKernel processes 'input_total_size' elements.
 // Each element is mapped from input to output by a combination of its
 // 'segment_ids' mapping and 'inner_dim_size'.
-template <typename T, typename Index>
-__global__ void UnsortedSegmentSumCustomKernel(
-    const Index input_outer_dim_size, const Index inner_dim_size,
-    const Index output_outer_dim_size, const Index* segment_ids, const T* input,
-    T* output) {
+template <typename T, typename Index, typename KernelReductionFunctor>
+__global__ void UnsortedSegmentCustomKernel(const Index input_outer_dim_size,
+                                            const Index inner_dim_size,
+                                            const Index output_outer_dim_size,
+                                            const Index* segment_ids,
+                                            const T* input, T* output) {
   const Index input_total_size = input_outer_dim_size * inner_dim_size;
   const Index output_total_size = output_outer_dim_size * inner_dim_size;
-  CUDA_1D_KERNEL_LOOP(input_index, input_total_size) {
+  for (int input_index : CudaGridRangeX(input_total_size)) {
     const Index input_segment_index = input_index / inner_dim_size;
     const Index segment_offset = input_index % inner_dim_size;
     const Index output_segment_index = segment_ids[input_segment_index];
-
     if (output_segment_index < 0 || output_segment_index >= output_total_size) {
       continue;
     }
     const Index output_index =
         output_segment_index * inner_dim_size + segment_offset;
-    AccumulateInto<T>(output + output_index, ldg(input + input_index));
+    KernelReductionFunctor()(output + output_index, ldg(input + input_index));
   }
 }
 
@@ -190,41 +163,39 @@ void SegmentSumFunctor<T, Index>::operator()(
       <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
           input_outer_dim_size, input_inner_dim_size, output_rows,
           segment_ids.data(), data, output.data(), total_stripe_count);
-};
+}
 
-// UnsortedSegmentSumFunctor implementation for GPUDevice.
-template <typename T, typename Index>
-struct UnsortedSegmentSumFunctor<GPUDevice, T, Index>
-    : UnsortedSegmentBaseFunctor<GPUDevice, T, Index> {
-  void operator()(OpKernelContext* ctx, const GPUDevice& d,
-                  const Index output_rows, const TensorShape& segment_ids_shape,
+template <typename T, typename Index, typename InitialValueF,
+          typename ReductionF>
+struct UnsortedSegmentFunctor<GPUDevice, T, Index, InitialValueF, ReductionF> {
+  void operator()(OpKernelContext* ctx, const Index num_segments,
+                  const TensorShape& segment_ids_shape,
                   typename TTypes<Index>::ConstFlat segment_ids,
                   const Index data_size, const T* data,
-                  typename TTypes<T, 2>::Tensor output) override {
+                  typename TTypes<T, 2>::Tensor output) {
     if (output.size() == 0) {
      return;
    }
-    // Set 'output' to zeros.
+    // Set 'output' to initial value.
+    GPUDevice d = ctx->template eigen_device<GPUDevice>();
     CudaLaunchConfig config = GetCudaLaunchConfig(output.size(), d);
-    SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
-        output.size(), output.data());
+    SetToValue<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+        output.size(), output.data(), InitialValueF()());
     if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
       return;
     }
-
-    // Launch kernel to compute unsorted segment sum.
+    // Launch kernel to compute unsorted segment reduction.
     // Notes:
-    // *) 'input_total_size' is the total number of elements to process.
+    // *) 'data_size' is the total number of elements to process.
    // *) 'segment_ids.shape' is a prefix of data's shape.
    // *) 'input_outer_dim_size' is the total number of segments to process.
-    const Index input_total_size = data_size;
     const Index input_outer_dim_size = segment_ids.dimension(0);
-    const Index input_inner_dim_size = input_total_size / input_outer_dim_size;
-
-    config = GetCudaLaunchConfig(input_total_size, d);
-    UnsortedSegmentSumCustomKernel<T, Index>
+    const Index input_inner_dim_size = data_size / input_outer_dim_size;
+    config = GetCudaLaunchConfig(data_size, d);
+    UnsortedSegmentCustomKernel<T, Index, ReductionF>
         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
-            input_outer_dim_size, input_inner_dim_size, output_rows,
+            input_outer_dim_size, input_inner_dim_size, num_segments,
             segment_ids.data(), data, output.data());
   }
 };
@@ -238,19 +209,40 @@ struct UnsortedSegmentSumFunctor<GPUDevice, T, Index>
 
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_SORTED_GPU_SPECS);
 
-#define DEFINE_GPU_SPECS_INDEX(T, Index) \
-  template struct UnsortedSegmentSumFunctor<GPUDevice, T, Index>
-
-#define DEFINE_GPU_SPECS(T)         \
-  DEFINE_GPU_SPECS_INDEX(T, int32); \
-  DEFINE_GPU_SPECS_INDEX(T, int64);
-
-TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
-TF_CALL_complex64(DEFINE_GPU_SPECS);
-TF_CALL_complex128(DEFINE_GPU_SPECS);
-
-#undef DEFINE_GPU_SPECS
-#undef DEFINE_GPU_SPECS_INDEX
+#define DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, Index)                         \
+  template struct UnsortedSegmentFunctor<                                      \
+      GPUDevice, T, Index, functor::Lowest<T>, functor::MaxOpGpu<T>>;          \
+  template struct UnsortedSegmentFunctor<                                      \
+      GPUDevice, T, Index, functor::Highest<T>, functor::MinOpGpu<T>>;         \
+  template struct UnsortedSegmentFunctor<GPUDevice, T, Index, functor::One<T>, \
+                                         functor::ProdOpGpu<T>>;
+
+// sum is the only op that supports all input types currently
+#define DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, Index) \
+  template struct UnsortedSegmentFunctor<             \
+      GPUDevice, T, Index, functor::Zero<T>, functor::SumOpGpu<T>>;
+
+#define DEFINE_REAL_GPU_SPECS(T)                  \
+  DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, int32); \
+  DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, int64);
+
+#define DEFINE_SUM_GPU_SPECS(T)                  \
+  DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, int32); \
+  DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, int64);
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_REAL_GPU_SPECS);
+TF_CALL_int32(DEFINE_REAL_GPU_SPECS);
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_SUM_GPU_SPECS);
+TF_CALL_int32(DEFINE_SUM_GPU_SPECS);
+TF_CALL_complex64(DEFINE_SUM_GPU_SPECS);
+TF_CALL_complex128(DEFINE_SUM_GPU_SPECS);
+
+#undef DEFINE_SORTED_GPU_SPECS_INDEX
+#undef DEFINE_SORTED_GPU_SPECS
+#undef DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX
+#undef DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX
+#undef DEFINE_REAL_GPU_SPECS
+#undef DEFINE_SUM_GPU_SPECS
 
 }  // namespace functor
 }  // namespace tensorflow
```
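Throughout the diff, the `CUDA_1D_KERNEL_LOOP` macro is replaced by a range-for over `CudaGridRangeX` from the newly included `cuda_device_functions.h`. As a rough, simplified sketch of how such a grid-stride range can be built (the real TensorFlow helper is templated over the index type and has Y/Z variants; the `detail::` names here are illustrative):

```cuda
// Rough sketch only: the actual CudaGridRangeX lives in
// tensorflow/core/util/cuda_device_functions.h. This simplified version
// illustrates the idea for int indices on the X dimension.
namespace detail {

// A half-open strided range [begin, end) with step `stride`, iterable
// from device code via range-for.
class GridStrideRange {
 public:
  class Iterator {
   public:
    __device__ Iterator(int index, int stride)
        : index_(index), stride_(stride) {}
    __device__ int operator*() const { return index_; }
    __device__ Iterator& operator++() {
      index_ += stride_;
      return *this;
    }
    // Compare with `<` rather than `!=` so that strided steps which jump
    // past `end` still terminate the loop.
    __device__ bool operator!=(const Iterator& other) const {
      return index_ < other.index_;
    }

   private:
    int index_;
    int stride_;
  };

  __device__ GridStrideRange(int begin, int end, int stride)
      : begin_(begin), end_(end), stride_(stride) {}
  __device__ Iterator begin() const { return Iterator(begin_, stride_); }
  __device__ Iterator end() const { return Iterator(end_, 0); }

 private:
  int begin_, end_, stride_;
};

}  // namespace detail

// `for (int i : CudaGridRangeX(n))` is then equivalent to the classic
// grid-stride loop:
//   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
//        i += gridDim.x * blockDim.x)
__device__ detail::GridStrideRange CudaGridRangeX(int count) {
  return detail::GridStrideRange(blockIdx.x * blockDim.x + threadIdx.x,
                                 count, gridDim.x * blockDim.x);
}
```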