Diffstat (limited to 'tensorflow/core/kernels/relu_op_gpu.cu.cc')
-rw-r--r--   tensorflow/core/kernels/relu_op_gpu.cu.cc   | 35
1 file changed, 34 insertions(+), 1 deletion(-)
diff --git a/tensorflow/core/kernels/relu_op_gpu.cu.cc b/tensorflow/core/kernels/relu_op_gpu.cu.cc
index 089ca8ed27..b9391517c1 100644
--- a/tensorflow/core/kernels/relu_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/relu_op_gpu.cu.cc
@@ -103,7 +103,7 @@ struct ReluGrad<Device, Eigen::half> {
     int32 count = gradient.size();
     if (count == 0) return;
     int32 half2_count = Eigen::divup(count, 2);
-    const int32 kThreadInBlock = 512;
+    constexpr int32 kThreadInBlock = 512;
     CudaLaunchConfig config = GetCudaLaunchConfigFixedBlockSize(
         half2_count, d, ReluGradHalfKernel, 0, kThreadInBlock);
     ReluGradHalfKernel<<<config.block_count, config.thread_per_block, 0,
@@ -111,6 +111,37 @@ struct ReluGrad<Device, Eigen::half> {
                          backprop.data(), count);
   }
 };
+
+__global__ void Relu_int8x4_kernel(int vect_count, const int32* input,
+                                   int32* output) {
+  CUDA_1D_KERNEL_LOOP(index, vect_count) {
+    output[index] = __vmaxs4(input[index], 0);
+  }
+}
+
+// Functor used by ReluOp to do the computations.
+template <typename Device>
+struct Relu<Device, qint8> {
+  // Computes Relu activation of 'input' containing int8 elements, whose buffer
+  // size should be a multiple of 4, and aligned to an int32* boundary.
+  // (Alignment should be guaranteed by the GPU tensor allocator).
+  // 'output' should have the same size as 'input'.
+  void operator()(const Device& d, typename TTypes<qint8>::ConstTensor input,
+                  typename TTypes<qint8>::Tensor output) {
+    int32 count = input.size();
+    if (count == 0) return;
+
+    int32 vect_count = Eigen::divup(count, 4);
+    constexpr int32 kThreadInBlock = 512;
+    CudaLaunchConfig config = GetCudaLaunchConfigFixedBlockSize(
+        vect_count, d, Relu_int8x4_kernel, 0, kThreadInBlock);
+    Relu_int8x4_kernel<<<config.block_count, config.thread_per_block, 0,
+                         d.stream()>>>(
+        vect_count, reinterpret_cast<const int32*>(input.data()),
+        reinterpret_cast<int32*>(output.data()));
+  }
+};
+
 }  // namespace functor
 
 // Definition of the GPU implementations declared in relu_op.cc.
@@ -126,6 +157,8 @@ struct ReluGrad<Device, Eigen::half> {
 
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
 
+template struct functor::Relu<GPUDevice, qint8>;
+
 }  // end namespace tensorflow
 
 #endif  // GOOGLE_CUDA
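
Note on the technique: the added kernel relies on the CUDA SIMD intrinsic __vmaxs4, which treats a 32-bit word as four packed signed 8-bit lanes and computes a per-byte signed maximum, so one instruction applies max(x, 0) to four qint8 elements at once. The standalone sketch below is not part of the commit; it is a minimal illustration of the same idea outside TensorFlow, with made-up names (ReluInt8x4, n_words) and a plain bounds check in place of CUDA_1D_KERNEL_LOOP.

// relu_int8x4_demo.cu -- minimal sketch of vectorized int8 ReLU via __vmaxs4.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Each thread handles one 32-bit word, i.e. four packed int8 values.
__global__ void ReluInt8x4(int n_words, const int32_t* in, int32_t* out) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n_words) {
    out[i] = __vmaxs4(in[i], 0);  // per-byte signed max against zero
  }
}

int main() {
  // 8 int8 values = 2 packed 32-bit words (buffer size is a multiple of 4).
  int8_t host_in[8] = {-3, 5, -120, 0, 7, -1, 100, -50};
  int8_t host_out[8];

  int32_t *dev_in, *dev_out;
  cudaMalloc(&dev_in, sizeof(host_in));
  cudaMalloc(&dev_out, sizeof(host_out));
  cudaMemcpy(dev_in, host_in, sizeof(host_in), cudaMemcpyHostToDevice);

  ReluInt8x4<<<1, 32>>>(2, dev_in, dev_out);
  cudaMemcpy(host_out, dev_out, sizeof(host_out), cudaMemcpyDeviceToHost);

  // Negative inputs come back as 0, non-negative inputs are unchanged.
  for (int i = 0; i < 8; ++i) printf("%d -> %d\n", host_in[i], host_out[i]);

  cudaFree(dev_in);
  cudaFree(dev_out);
  return 0;
}

Built with something like "nvcc relu_int8x4_demo.cu -o relu_int8x4_demo", this prints each input next to its ReLU output. The reinterpret_cast in the TensorFlow functor is what turns the qint8 tensor buffer into the packed int32 view the kernel consumes, which is why the comment in the diff requires 4-element padding and int32 alignment.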