about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels/topk_op_gpu.cu.cc
diff options
context:
space:
mode:
author: Eugene Brevdo <ebrevdo@google.com>, 2017-07-07 16:35:46 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2017-07-07 16:40:38 -0700
commit: 305712c02e70bc860812e7c151a3842f028cacb1 (patch)
tree: 78d30c0273970eb93cad4fe68620cdc72170eda3 /tensorflow/core/kernels/topk_op_gpu.cu.cc
parent: d9c4732a1dd7466c5c98e09ed034876c45724f8b (diff)
More WhereOp/TopK GPU bugfixes: use the direct cuda stream for CUB GPU kernel.
It turns out that using the StreamInterface objects leads to "invalid resource handle" errors, so we have to use the cudaStream_t directly. This change is based on similar code in cuda_solvers.cc.
PiperOrigin-RevId: 161261085
Diffstat (limited to 'tensorflow/core/kernels/topk_op_gpu.cu.cc')
-rw-r--r-- tensorflow/core/kernels/topk_op_gpu.cu.cc | 17
1 files changed, 9 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/topk_op_gpu.cu.cc b/tensorflow/core/kernels/topk_op_gpu.cu.cc
index e4b4a3cb49..4b3f5ccacc 100644
--- a/tensorflow/core/kernels/topk_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/topk_op_gpu.cu.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
// Required for sorting Eigen::half
namespace cub {
@@ -365,9 +366,9 @@ __global__ void TopKKernel(const T* input, int length, int k, bool sorted,
}
template <typename T>
-cudaError LaunchTopKKernel(cudaStream_t stream, int num_shards, const T* input,
- int batch_size, int length, int k, bool sorted,
- T* output, int* indices) {
+cudaError LaunchTopKKernel(const cudaStream_t& stream, int num_shards,
+ const T* input, int batch_size, int length, int k,
+ bool sorted, T* output, int* indices) {
// This code assumes that k is small enough that the computation
// fits inside shared memory (hard coded to 48KB). In practice this
// means k <= 3072 for T=float/int32 and k <= 2048 for T=double/int64.
@@ -428,7 +429,7 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
typename TTypes<T, 2>::Tensor values,
TTypes<int, 2>::Tensor indices) {
const GPUDevice& d = ctx->eigen_device<GPUDevice>();
- auto stream = ctx->eigen_gpu_device().stream();
+ const cudaStream_t& cu_stream = GetCudaStream(ctx);
size_t temp_storage_bytes = -1;
// TODO(ebrevdo): Once cub supports iterators for the ValueT and
@@ -480,7 +481,7 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
/* d_end_offsets */ segment_offsets_t.data() + 1,
/* begin_bit */ 0,
/* end_bit */ sizeof(T) * 8,
- /* stream */ stream);
+ /* stream */ cu_stream);
if (err != cudaSuccess) {
return errors::Internal(
"TopKOp: Could not launch "
@@ -505,7 +506,7 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
/* d_end_offsets */ segment_offsets_t.data() + 1,
/* begin_bit */ 0,
/* end_bit */ sizeof(T) * 8,
- /* stream */ stream);
+ /* stream */ cu_stream);
if (err != cudaSuccess) {
return errors::Internal(
"TopKOp: Could not launch "
@@ -545,8 +546,8 @@ struct TopKFunctor<GPUDevice, T> {
return impl::LaunchSortKernel(context, input.data(), num_rows, num_cols,
k, values, indices);
} else {
- auto stream = context->eigen_gpu_device().stream();
- auto err = impl::LaunchTopKKernel(stream, /* num_shards */ 0,
+ const cudaStream_t& cu_stream = GetCudaStream(context);
+ auto err = impl::LaunchTopKKernel(cu_stream, /* num_shards */ 0,
input.data(), num_rows, num_cols, k,
sorted, values.data(), indices.data());
if (err != cudaSuccess) {