From 305712c02e70bc860812e7c151a3842f028cacb1 Mon Sep 17 00:00:00 2001
From: Eugene Brevdo
Date: Fri, 7 Jul 2017 16:35:46 -0700
Subject: More WhereOp/TopK GPU bugfixes: use the direct cuda stream for CUB
 GPU kernel.

Turns out using the StreamInterface objects leads to "invalid resource
handle" errors, so we have to use the cudaStream_t directly. This change
is based on similar code in cuda_solvers.cc.

PiperOrigin-RevId: 161261085
---
 tensorflow/core/kernels/topk_op_gpu.cu.cc | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/tensorflow/core/kernels/topk_op_gpu.cu.cc b/tensorflow/core/kernels/topk_op_gpu.cu.cc
index e4b4a3cb49..4b3f5ccacc 100644
--- a/tensorflow/core/kernels/topk_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/topk_op_gpu.cu.cc
@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/top_n.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
 
 // Required for sorting Eigen::half
 namespace cub {
@@ -365,9 +366,9 @@ __global__ void TopKKernel(const T* input, int length, int k, bool sorted,
 }
 
 template <typename T>
-cudaError LaunchTopKKernel(cudaStream_t stream, int num_shards, const T* input,
-                           int batch_size, int length, int k, bool sorted,
-                           T* output, int* indices) {
+cudaError LaunchTopKKernel(const cudaStream_t& stream, int num_shards,
+                           const T* input, int batch_size, int length, int k,
+                           bool sorted, T* output, int* indices) {
   // This code assumes that k is small enough that the computation
   // fits inside shared memory (hard coded to 48KB). In practice this
   // means k <= 3072 for T=float/int32 and k <= 2048 for T=double/int64.
@@ -428,7 +429,7 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
                         typename TTypes<T, 2>::Tensor values,
                         TTypes<int, 2>::Tensor indices) {
   const GPUDevice& d = ctx->eigen_device<GPUDevice>();
-  auto stream = ctx->eigen_gpu_device().stream();
+  const cudaStream_t& cu_stream = GetCudaStream(ctx);
   size_t temp_storage_bytes = -1;
 
   // TODO(ebrevdo): Once cub supports iterators for the ValueT and
@@ -480,7 +481,7 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
         /* d_end_offsets */ segment_offsets_t.data() + 1,
         /* begin_bit */ 0,
         /* end_bit */ sizeof(T) * 8,
-        /* stream */ stream);
+        /* stream */ cu_stream);
     if (err != cudaSuccess) {
       return errors::Internal(
           "TopKOp: Could not launch "
@@ -505,7 +506,7 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
         /* d_end_offsets */ segment_offsets_t.data() + 1,
         /* begin_bit */ 0,
         /* end_bit */ sizeof(T) * 8,
-        /* stream */ stream);
+        /* stream */ cu_stream);
     if (err != cudaSuccess) {
       return errors::Internal(
           "TopKOp: Could not launch "
@@ -545,8 +546,8 @@ struct TopKFunctor<GPUDevice, T> {
       return impl::LaunchSortKernel(context, input.data(), num_rows, num_cols,
                                     k, values, indices);
     } else {
-      auto stream = context->eigen_gpu_device().stream();
-      auto err = impl::LaunchTopKKernel(stream, /* num_shards */ 0,
+      const cudaStream_t& cu_stream = GetCudaStream(context);
+      auto err = impl::LaunchTopKKernel(cu_stream, /* num_shards */ 0,
                                         input.data(), num_rows, num_cols, k,
                                         sorted, values.data(), indices.data());
       if (err != cudaSuccess) {
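
Note on GetCudaStream(): the helper this patch switches to is declared in the
newly included tensorflow/core/util/cuda_kernel_helper.h, whose body is not
shown in this single-file view. Based on the cuda_solvers.cc pattern the
commit message cites, a minimal sketch of such a helper might look like the
code below. The CudaStreamMemberHack() accessor and the exact include set are
assumptions about the 2017-era StreamExecutor API, not verbatim from this
commit:

    // Sketch (assumed, not verbatim from this commit): unwrap the raw
    // cudaStream_t backing the StreamExecutor stream that TensorFlow runs
    // this op on, so CUB calls launch on the same stream as the op itself.
    #include <cuda_runtime.h>  // cudaStream_t

    #include "tensorflow/core/framework/op_kernel.h"
    #include "tensorflow/core/platform/logging.h"  // CHECK_NOTNULL
    #include "tensorflow/core/platform/stream_executor.h"

    namespace tensorflow {

    inline const cudaStream_t& GetCudaStream(OpKernelContext* context) {
      // op_device_context()->stream() is the StreamExecutor stream assigned
      // to this op; implementation()->CudaStreamMemberHack() exposes a
      // pointer to the underlying cudaStream_t member (hence the "hack").
      const cudaStream_t* ptr = CHECK_NOTNULL(
          reinterpret_cast<const cudaStream_t*>(context->op_device_context()
                                                    ->stream()
                                                    ->implementation()
                                                    ->CudaStreamMemberHack()));
      return *ptr;
    }

    }  // namespace tensorflow

Handing this raw cudaStream_t to cub::DeviceSegmentedRadixSort and to
LaunchTopKKernel keeps all work on TensorFlow's compute stream, which is what
avoids the "invalid resource handle" errors the commit message describes.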