/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/lib/strings/str_util.h"
#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

#include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h"
#include "tensorflow/core/kernels/reduction_ops_common.h"

namespace tensorflow {

namespace {

// Converts a value of type T to type U.
//
// Only the identity conversion (T == U) and the explicit Eigen::half <-> float
// specializations below have definitions. The T != U template is declared but
// never defined, so requesting any other conversion fails to build instead of
// silently going through an implicit conversion.
template <typename U, typename T>
__device__ __host__ EIGEN_STRONG_INLINE
    typename std::enable_if<!std::is_same<T, U>::value, U>::type
    strict_cast(T t);

// Identity conversion: T and U are the same type.
template <typename U, typename T>
__device__ __host__ EIGEN_STRONG_INLINE
    typename std::enable_if<std::is_same<T, U>::value, U>::type
    strict_cast(T t) {
  return t;
}

// Eigen::half -> float.
template <>
__device__ __host__ EIGEN_STRONG_INLINE float strict_cast<float, Eigen::half>(
    Eigen::half t) {
  return functor::HalfToFloat()(t);
}

// float -> Eigen::half.
template <>
__device__ __host__ EIGEN_STRONG_INLINE Eigen::half
strict_cast<Eigen::half, float>(float t) {
  return functor::FloatToHalf()(t);
}

// Selects the type used to accumulate intermediate softmax values for input
// type T. By default this is T itself; Eigen::half inputs accumulate in float.
template <typename T>
struct softmax_traits {
  using accumulator_type = T;
};

template <>
struct softmax_traits<Eigen::half> {
  using accumulator_type = float;
};

// Writes the final softmax (or log-softmax) value for one element.
//
// Launch layout: 1-D grid/blocks. Thread `tid` handles flat element `tid` of
// the row-major [num_rows, num_cols] `logits` / `output` arrays. The grid may
// be rounded up to whole blocks; surplus threads exit without touching memory.
//
//   logits:       num_rows * num_cols input values.
//   sum_probs:    per-row sum of exp(logit - row max), in accumulator type U.
//   max_logits:   per-row maximum logit.
//   output:       num_rows * num_cols results; may alias `logits`, which is
//                 safe because each thread reads logits[tid] before writing
//                 output[tid] and touches no other element of either array.
//   in_log_space: if true, writes log-softmax instead of softmax.
template <typename T, typename U>
__global__ void GenerateNormalizedProb(const T* logits, const U* sum_probs,
                                       const T* max_logits, T* output,
                                       const int num_rows, const int num_cols,
                                       const bool in_log_space) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int row = tid / num_cols;
  const int col = tid % num_cols;

  // Trailing threads of the rounded-up grid have no element to process. Exit
  // *before* loading from `logits` / `max_logits`: loading first (as an
  // earlier version did) reads past the end of both buffers.
  if (row >= num_rows || col >= num_cols) return;

  // TODO(jamesqin): change to half2 load when inputs are Eigen::half.
  const U input = strict_cast<U>(logits[tid]);
  const U max_val = strict_cast<U>(ldg(max_logits + row));
  const U row_sum = ldg(sum_probs + row);

  U result;
  if (in_log_space) {
    result = input - max_val - log(row_sum);
  } else {
    result = exp(input - max_val) / row_sum;
  }
  output[tid] = strict_cast<T>(result);
}

// Maps a flat element index `gid` of the row-major [rows, num_cols] logits to
// exp(logits[gid] - max_logits[gid / num_cols]) in accumulator type U.
// Used as the transform of a cub::TransformInputIterator so the per-row sum of
// exponentials can be reduced without materializing the exponentials.
template <typename T, typename U>
struct SubtractAndExpFunctor {
  __host__ __device__ SubtractAndExpFunctor(const T* logits,
                                            const T* max_logits,
                                            const int num_cols)
      : logits_(logits), max_logits_(max_logits), num_cols_(num_cols) {}

  __host__ __device__ U operator()(const int gid) const {
    // Widen both operands to U before subtracting, so the difference is
    // computed at the same precision as in GenerateNormalizedProb (this
    // matters when T is Eigen::half and U is float).
    // TODO(jamesqin): change to half2 load when inputs are Eigen::half.
    const U diff = strict_cast<U>(logits_[gid]) -
                   strict_cast<U>(ldg(max_logits_ + gid / num_cols_));
    return exp(diff);
  }

  const T* logits_;
  const T* max_logits_;
  const int num_cols_;
};

// Reduces each row of a row-major [rows, cols] input with `Op` (e.g. cub::Max,
// cub::Sum), writing one value per row to `output`.
// Delegates to functor::ReduceImpl with a rank-2 input and reduction axes
// `constants.kOne` (assumed to be {1}, the column axis; see
// reduction_ops_common.h).
template <typename T, typename Op, typename InputIter>
void DoRowReduction(OpKernelContext* context, T* output, InputIter input,
                    int rows, int cols) {
  typedef const Eigen::array<TTypes<float>::Tensor::Index, 1>& ReductionAxes;
  Constants<GPUDevice> constants;

  Op op;

  functor::ReduceImpl<T, Op, T*, InputIter, ReductionAxes>(
      context, output, input, 2, rows, cols, 1, 1, constants.kOne, op);
}
}  // namespace

// GPU kernel for Softmax / LogSoftmax along the innermost dimension of the
// input `logits`. Log-space output is selected when the op's type string
// starts with "Log".
//
// Computation:
//   1. max_logits[r] = max_c logits[r, c]
//   2. sum_probs[r]  = sum_c exp(logits[r, c] - max_logits[r])
//   3. output[r, c]  = exp(logits[r, c] - max_logits[r]) / sum_probs[r]
//      (or logits[r, c] - max_logits[r] - log(sum_probs[r]) in log space)
template <typename T>
class SoftmaxOpGPU : public OpKernel {
 public:
  explicit SoftmaxOpGPU(OpKernelConstruction* context) : OpKernel(context) {
    log_ = str_util::StartsWith(type_string(), "Log");
  }

  void Compute(OpKernelContext* context) override {
    // Threads per block for GenerateNormalizedProb.
    constexpr int kNumThreads = 128;

    const Tensor& logits_in_ = context->input(0);
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(logits_in_.shape()),
                errors::InvalidArgument("logits must have >= 1 dimension, got ",
                                        logits_in_.shape().DebugString()));
    // rows, cols, rows * cols and the thread ids of the rounded-up grid are all
    // 32-bit ints below. Reject inputs that would overflow them rather than
    // silently computing garbage or indexing out of bounds.
    OP_REQUIRES(context, logits_in_.NumElements() <= kint32max - kNumThreads,
                errors::InvalidArgument("logits has too many elements (",
                                        logits_in_.NumElements(),
                                        ") for the GPU softmax kernel"));

    auto logits_in = logits_in_.flat_inner_dims<T>();
    const int rows = logits_in.dimension(0);
    const int cols = logits_in.dimension(1);

    // The output may reuse the input buffer when the runtime allows it.
    Tensor* softmax_out = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, logits_in_.shape(), &softmax_out));

    if (logits_in_.NumElements() == 0) return;

    typedef typename softmax_traits<T>::accumulator_type acc_type;

    // Per-row scratch buffers; only `rows` entries of each are used.
    Tensor max_logits;
    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
                                                   TensorShape({rows}),
                                                   &max_logits));
    Tensor sum_probs;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<acc_type>::value,
                                          TensorShape({rows}), &sum_probs));

    const T* logits_ptr = logits_in_.flat<T>().data();
    T* max_logits_ptr = max_logits.flat<T>().data();
    acc_type* sum_probs_ptr = sum_probs.flat<acc_type>().data();
    T* output_ptr = softmax_out->flat<T>().data();

    // Pass 1: per-row maximum.
    DoRowReduction<T, cub::Max, const T*>(context, max_logits_ptr, logits_ptr,
                                          rows, cols);

    // Pass 2: per-row sum of exp(logit - row max). The exponentials are
    // produced on the fly from a counting iterator over flat element indices.
    cub::CountingInputIterator<int> counting_iterator(0);
    typedef cub::TransformInputIterator<acc_type,
                                        SubtractAndExpFunctor<T, acc_type>,
                                        cub::CountingInputIterator<int>>
        InputIterType;
    InputIterType input_itr(
        counting_iterator,
        SubtractAndExpFunctor<T, acc_type>(logits_ptr, max_logits_ptr, cols));
    DoRowReduction<acc_type, cub::Sum, InputIterType>(
        context, sum_probs_ptr, input_itr, rows, cols);

    // Pass 3: one thread per element writes the normalized result.
    const cudaStream_t& cu_stream = GetCudaStream(context);
    const int num_blocks = Eigen::divup(rows * cols, kNumThreads);
    GenerateNormalizedProb<T, acc_type>
        <<<num_blocks, kNumThreads, 0, cu_stream>>>(
            logits_ptr, sum_probs_ptr, max_logits_ptr, output_ptr, rows, cols,
            log_);
  }

 private:
  // True for LogSoftmax, false for Softmax.
  bool log_;
};

REGISTER_KERNEL_BUILDER(
    Name("Softmax").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
    SoftmaxOpGPU<Eigen::half>);
REGISTER_KERNEL_BUILDER(
    Name("Softmax").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    SoftmaxOpGPU<float>);
REGISTER_KERNEL_BUILDER(
    Name("Softmax").Device(DEVICE_GPU).TypeConstraint<double>("T"),
    SoftmaxOpGPU<double>);
REGISTER_KERNEL_BUILDER(
    Name("LogSoftmax").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
    SoftmaxOpGPU<Eigen::half>);
REGISTER_KERNEL_BUILDER(
    Name("LogSoftmax").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    SoftmaxOpGPU<float>);

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA