Diffstat (limited to 'tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc')
-rw-r--r--  tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc  143
1 file changed, 69 insertions(+), 74 deletions(-)
diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
index 39d520698e..3511c85f71 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
@@ -17,43 +17,19 @@ limitations under the License.
#define EIGEN_USE_GPU
-#include "tensorflow/core/kernels/segment_reduction_ops.h"
+// We need to include cuda_kernel_helper.h before segment_reduction_ops.h
+// See comment in segment_reduction_ops.h for more details.
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+#include "tensorflow/core/kernels/segment_reduction_ops.h"
#include "tensorflow/core/framework/register_types.h"
-#include "tensorflow/core/util/cuda_kernel_helper.h"
+#include "tensorflow/core/util/cuda_device_functions.h"
+
namespace tensorflow {
using GPUDevice = Eigen::GpuDevice;
-// Helper for UnusortedSegmentSumCustomKernel that adds value into dest
-// atomically.
-template <typename T>
-static __device__ __forceinline__ void AccumulateInto(T* dest, const T& value) {
- CudaAtomicAdd(dest, value);
-}
-
-// Specializations of AccumulateInto for complex types, which CudaAtomicAdd does
-// not support. We treat a std::complex<T>* as a T* (the C++ standard section
-// 26.4.4 allows this explicitly) and atomic add the real and imaginary
-// components individually. The operation as a whole is not atomic, but we can
-// safely treat the components independently for the purpose of accumulating.
-template <>
-__device__ __forceinline__ void AccumulateInto(
- std::complex<float>* dest, const std::complex<float>& value) {
- auto dest_scalar = reinterpret_cast<float*>(dest);
- CudaAtomicAdd(dest_scalar, value.real());
- CudaAtomicAdd(dest_scalar + 1, value.imag());
-}
-
-template <>
-__device__ __forceinline__ void AccumulateInto(
- std::complex<double>* dest, const std::complex<double>& value) {
- auto dest_scalar = reinterpret_cast<double*>(dest);
- CudaAtomicAdd(dest_scalar, value.real());
- CudaAtomicAdd(dest_scalar + 1, value.imag());
-}
-
// SortedSegmentSumFunctor kernel reduces input data just as
// UnsortedSegmentSumCustomKernel does except that input data
// is partitioned along the outer reduction dimension. This is
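The AccumulateInto helper and its complex specializations removed in this hunk are no longer needed: the kernels below call CudaAtomicAdd directly, and complex accumulation is presumably covered by overloads in the CUDA helper headers included above. As a minimal sketch (the name and body below are illustrative, not the actual helper), the component-wise trick described in the removed comment looks like this:

// Sketch only: component-wise atomic add for complex values, mirroring the
// removed AccumulateInto specializations. A std::complex<float>* may be
// treated as a float* (C++ standard 26.4.4), so the real and imaginary
// parts can be accumulated with independent atomics.
__device__ inline void AtomicAddComplexSketch(
    std::complex<float>* dest, const std::complex<float>& value) {
  auto* scalar = reinterpret_cast<float*>(dest);
  atomicAdd(scalar, value.real());      // real component
  atomicAdd(scalar + 1, value.imag());  // imaginary component
}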
@@ -81,7 +57,7 @@ __global__ void SortedSegmentSumCustomKernel(const Index input_outer_dim_size,
const Index* segment_ids,
const T* input, T* output,
const Index total_stripe_count) {
- CUDA_1D_KERNEL_LOOP(stripe_index, total_stripe_count) {
+ for (int stripe_index : CudaGridRangeX(total_stripe_count)) {
const Index segment_offset = stripe_index % inner_dim_size;
const Index input_outer_dim_index_base =
stripe_index / inner_dim_size * Index(OuterDimTileSize);
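CUDA_1D_KERNEL_LOOP is replaced here by the range-based CudaGridRangeX helper from cuda_device_functions.h. Both express the standard grid-stride loop; a minimal hand-written equivalent (illustrative only, not the helper's actual implementation) is:

// Grid-stride loop equivalent in spirit to `for (int i : CudaGridRangeX(count))`:
// each thread starts at its global index and strides by the total number of
// threads in the grid until 'count' elements are covered.
__global__ void GridStrideSketchKernel(const int count, float* data) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count;
       i += gridDim.x * blockDim.x) {
    data[i] += 1.0f;  // placeholder body
  }
}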
@@ -106,7 +82,7 @@ __global__ void SortedSegmentSumCustomKernel(const Index input_outer_dim_size,
// decide whether to write result to global memory using atomic
// operations
if (last_output_segment_id == first_segment_id) {
- AccumulateInto<T>(output + output_index, sum);
+ CudaAtomicAdd(output + output_index, sum);
} else {
*(output + output_index) = sum;
}
@@ -121,31 +97,31 @@ __global__ void SortedSegmentSumCustomKernel(const Index input_outer_dim_size,
// the following strip.
const Index output_index =
last_output_segment_id * inner_dim_size + segment_offset;
- AccumulateInto<T>(output + output_index, sum);
+ CudaAtomicAdd(output + output_index, sum);
}
}
-// UnsortedSegmentSumFunctor kernel processes 'input_total_size' elements.
+// UnsortedSegmentSumKernel processes 'input_total_size' elements.
// Each element is mapped from input to output by a combination of its
// 'segment_ids' mapping and 'inner_dim_size'.
-template <typename T, typename Index>
-__global__ void UnsortedSegmentSumCustomKernel(
- const Index input_outer_dim_size, const Index inner_dim_size,
- const Index output_outer_dim_size, const Index* segment_ids, const T* input,
- T* output) {
+template <typename T, typename Index, typename KernelReductionFunctor>
+__global__ void UnsortedSegmentCustomKernel(const Index input_outer_dim_size,
+ const Index inner_dim_size,
+ const Index output_outer_dim_size,
+ const Index* segment_ids,
+ const T* input, T* output) {
const Index input_total_size = input_outer_dim_size * inner_dim_size;
const Index output_total_size = output_outer_dim_size * inner_dim_size;
- CUDA_1D_KERNEL_LOOP(input_index, input_total_size) {
+ for (int input_index : CudaGridRangeX(input_total_size)) {
const Index input_segment_index = input_index / inner_dim_size;
const Index segment_offset = input_index % inner_dim_size;
const Index output_segment_index = segment_ids[input_segment_index];
-
if (output_segment_index < 0 || output_segment_index >= output_total_size) {
continue;
}
const Index output_index =
output_segment_index * inner_dim_size + segment_offset;
- AccumulateInto<T>(output + output_index, ldg(input + input_index));
+ KernelReductionFunctor()(output + output_index, ldg(input + input_index));
}
}
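UnsortedSegmentCustomKernel now takes the combine step as a template parameter instead of hard-coding the sum. The functor names used in the instantiations below (functor::SumOpGpu, MaxOpGpu, MinOpGpu, ProdOpGpu) are defined in segment_reduction_ops.h, not in this diff; the expected contract is a stateless functor whose operator()(T* dest, const T& value) atomically folds value into *dest. An illustrative sketch, grounded only in the CudaAtomicAdd calls visible above:

// Sketch of a functor satisfying the KernelReductionFunctor contract.
// The real functor::SumOpGpu<T> lives in segment_reduction_ops.h; this body
// is an assumption based on the CudaAtomicAdd usage in this file.
template <typename T>
struct SumOpGpuSketch {
  __device__ __forceinline__ void operator()(T* dest, const T& value) const {
    CudaAtomicAdd(dest, value);
  }
};

MaxOpGpu, MinOpGpu, and ProdOpGpu would follow the same shape with their respective atomic operations. The index math in the kernel itself is unchanged: flat input element i belongs to outer row i / inner_dim_size and column i % inner_dim_size, and is folded into output row segment_ids[i / inner_dim_size].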
@@ -190,41 +166,39 @@ void SegmentSumFunctor<T, Index>::operator()(
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
input_outer_dim_size, input_inner_dim_size, output_rows,
segment_ids.data(), data, output.data(), total_stripe_count);
-};
+}
-// UnsortedSegmentSumFunctor implementation for GPUDevice.
-template <typename T, typename Index>
-struct UnsortedSegmentSumFunctor<GPUDevice, T, Index>
- : UnsortedSegmentBaseFunctor<GPUDevice, T, Index> {
- void operator()(OpKernelContext* ctx, const GPUDevice& d,
- const Index output_rows, const TensorShape& segment_ids_shape,
+template <typename T, typename Index, typename InitialValueF,
+ typename ReductionF>
+struct UnsortedSegmentFunctor<GPUDevice, T, Index, InitialValueF, ReductionF> {
+ void operator()(OpKernelContext* ctx, const Index num_segments,
+ const TensorShape& segment_ids_shape,
typename TTypes<Index>::ConstFlat segment_ids,
const Index data_size, const T* data,
- typename TTypes<T, 2>::Tensor output) override {
+ typename TTypes<T, 2>::Tensor output) {
if (output.size() == 0) {
return;
}
- // Set 'output' to zeros.
+ // Set 'output' to initial value.
+ GPUDevice d = ctx->template eigen_device<GPUDevice>();
CudaLaunchConfig config = GetCudaLaunchConfig(output.size(), d);
- SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
- output.size(), output.data());
+ SetToValue<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ output.size(), output.data(), InitialValueF()());
if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
return;
}
-
- // Launch kernel to compute unsorted segment sum.
+ // Launch kernel to compute unsorted segment reduction.
// Notes:
- // *) 'input_total_size' is the total number of elements to process.
+ // *) 'data_size' is the total number of elements to process.
// *) 'segment_ids.shape' is a prefix of data's shape.
// *) 'input_outer_dim_size' is the total number of segments to process.
- const Index input_total_size = data_size;
const Index input_outer_dim_size = segment_ids.dimension(0);
- const Index input_inner_dim_size = input_total_size / input_outer_dim_size;
+ const Index input_inner_dim_size = data_size / input_outer_dim_size;
+ config = GetCudaLaunchConfig(data_size, d);
- config = GetCudaLaunchConfig(input_total_size, d);
- UnsortedSegmentSumCustomKernel<T, Index>
+ UnsortedSegmentCustomKernel<T, Index, ReductionF>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
- input_outer_dim_size, input_inner_dim_size, output_rows,
+ input_outer_dim_size, input_inner_dim_size, num_segments,
segment_ids.data(), data, output.data());
}
};
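The functor is also parameterized on an initial value: before the reduction runs, the output is filled with InitialValueF()() by the SetToValue kernel (zero for sum, lowest for max, highest for min, one for prod, matching the instantiations below). A minimal sketch of that pattern; functor::Zero<T> and SetToValue are names used above, but their bodies are not part of this diff, so the following is an assumption about their shape:

// Sketch only: initial-value functor and fill kernel in the style used above.
template <typename T>
struct ZeroSketch {
  __host__ __device__ T operator()() const { return T(0); }
};

template <typename T>
__global__ void SetToValueSketch(const int count, T* ptr, const T value) {
  for (int i : CudaGridRangeX(count)) {
    ptr[i] = value;  // fill every output element with the reduction's identity
  }
}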
@@ -238,19 +212,40 @@ struct UnsortedSegmentSumFunctor<GPUDevice, T, Index>
TF_CALL_GPU_NUMBER_TYPES(DEFINE_SORTED_GPU_SPECS);
-#define DEFINE_GPU_SPECS_INDEX(T, Index) \
- template struct UnsortedSegmentSumFunctor<GPUDevice, T, Index>
-
-#define DEFINE_GPU_SPECS(T) \
- DEFINE_GPU_SPECS_INDEX(T, int32); \
- DEFINE_GPU_SPECS_INDEX(T, int64);
-
-TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
-TF_CALL_complex64(DEFINE_GPU_SPECS);
-TF_CALL_complex128(DEFINE_GPU_SPECS);
-
-#undef DEFINE_GPU_SPECS
-#undef DEFINE_GPU_SPECS_INDEX
+#define DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, Index) \
+ template struct UnsortedSegmentFunctor< \
+ GPUDevice, T, Index, functor::Lowest<T>, functor::MaxOpGpu<T>>; \
+ template struct UnsortedSegmentFunctor< \
+ GPUDevice, T, Index, functor::Highest<T>, functor::MinOpGpu<T>>; \
+ template struct UnsortedSegmentFunctor<GPUDevice, T, Index, functor::One<T>, \
+ functor::ProdOpGpu<T>>;
+
+// sum is the only op that supports all input types currently
+#define DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, Index) \
+ template struct UnsortedSegmentFunctor< \
+ GPUDevice, T, Index, functor::Zero<T>, functor::SumOpGpu<T>>;
+
+#define DEFINE_REAL_GPU_SPECS(T) \
+ DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, int32); \
+ DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, int64);
+
+#define DEFINE_SUM_GPU_SPECS(T) \
+ DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, int32); \
+ DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, int64);
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_REAL_GPU_SPECS);
+TF_CALL_int32(DEFINE_REAL_GPU_SPECS);
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_SUM_GPU_SPECS);
+TF_CALL_int32(DEFINE_SUM_GPU_SPECS);
+TF_CALL_complex64(DEFINE_SUM_GPU_SPECS);
+TF_CALL_complex128(DEFINE_SUM_GPU_SPECS);
+
+#undef DEFINE_SORTED_GPU_SPECS_INDEX
+#undef DEFINE_SORTED_GPU_SPECS
+#undef DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX
+#undef DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX
+#undef DEFINE_REAL_GPU_SPECS
+#undef DEFINE_SUM_GPU_SPECS
} // namespace functor
} // namespace tensorflow
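For reference, the new instantiation macros expand mechanically. For example, DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(float, int32) produces:

template struct UnsortedSegmentFunctor<
    GPUDevice, float, int32, functor::Lowest<float>, functor::MaxOpGpu<float>>;
template struct UnsortedSegmentFunctor<
    GPUDevice, float, int32, functor::Highest<float>, functor::MinOpGpu<float>>;
template struct UnsortedSegmentFunctor<GPUDevice, float, int32,
                                       functor::One<float>,
                                       functor::ProdOpGpu<float>>;

DEFINE_REAL_GPU_SPECS(float) emits the same three instantiations for both int32 and int64 indices, while DEFINE_SUM_GPU_SPECS adds the functor::Zero / functor::SumOpGpu pairing, which is the only combination also instantiated for complex64 and complex128.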