diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-08-17 00:51:51 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-17 00:55:36 -0700 |
commit | 8535a56825361a80127da9165d27d6298819182c (patch) | |
tree | 5dba7ab8654882ccff6b0f1c7a55d6844e54fd53 /tensorflow/core/kernels/topk_op_gpu.cu.cc | |
parent | 337b12f11e9ef8710a326f105e3b26bb4d8fc0bf (diff) |
Update top_k implementation to use iterators for segment offsets
PiperOrigin-RevId: 165549078
Diffstat (limited to 'tensorflow/core/kernels/topk_op_gpu.cu.cc')
-rw-r--r-- | tensorflow/core/kernels/topk_op_gpu.cu.cc | 32 |
1 files changed, 16 insertions, 16 deletions
diff --git a/tensorflow/core/kernels/topk_op_gpu.cu.cc b/tensorflow/core/kernels/topk_op_gpu.cu.cc index 4b3f5ccacc..10a7602dc4 100644 --- a/tensorflow/core/kernels/topk_op_gpu.cu.cc +++ b/tensorflow/core/kernels/topk_op_gpu.cu.cc @@ -22,6 +22,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "external/cub_archive/cub/device/device_segmented_radix_sort.cuh" #include "external/cub_archive/cub/iterator/counting_input_iterator.cuh" +#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -404,11 +405,13 @@ cudaError LaunchTopKKernel(const cudaStream_t& stream, int num_shards, } struct SegmentOffsetCreator { + EIGEN_DEVICE_FUNC SegmentOffsetCreator(int num_cols) : num_cols_(num_cols) {} - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()( - const Eigen::array<int, 1>& ix) const { - return ix[0] * num_cols_; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const { + return idx * num_cols_; }; + int num_cols_; }; @@ -432,9 +435,8 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows, const cudaStream_t& cu_stream = GetCudaStream(ctx); size_t temp_storage_bytes = -1; - // TODO(ebrevdo): Once cub supports iterators for the ValueT and - // segment_offsets, replace these tensors with iterators that - // directly return the correct value. + // TODO(ebrevdo): Once cub supports iterators for ValueT replace that tensor + // with an iterator that directly returns the correct value. Tensor input_indices; TF_RETURN_IF_ERROR(ctx->allocate_temp( DT_INT32, TensorShape({num_rows, num_cols}), &input_indices)); @@ -442,12 +444,10 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows, input_indices_t.device(d) = input_indices_t.generate(ColumnIndexCreator(num_cols)); - Tensor segment_offsets; - TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_INT32, TensorShape({num_rows + 1}), - &segment_offsets)); - auto segment_offsets_t = To32Bit(segment_offsets.flat<int32>()); - segment_offsets_t.device(d) = - segment_offsets_t.generate(SegmentOffsetCreator(num_cols)); + cub::CountingInputIterator<int> counting_iter(0); + cub::TransformInputIterator<int, SegmentOffsetCreator, + cub::CountingInputIterator<int>> + segment_offsets_t(counting_iter, SegmentOffsetCreator(num_cols)); Tensor temp_values; Tensor temp_indices; @@ -477,8 +477,8 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows, /* d_values_out */ sorted_indices_ptr, /* num_items */ num_cols * num_rows, /* num_segments */ num_rows, - /* d_begin_offsets */ segment_offsets_t.data(), - /* d_end_offsets */ segment_offsets_t.data() + 1, + /* d_begin_offsets */ segment_offsets_t, + /* d_end_offsets */ segment_offsets_t + 1, /* begin_bit */ 0, /* end_bit */ sizeof(T) * 8, /* stream */ cu_stream); @@ -502,8 +502,8 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows, /* d_values_out */ sorted_indices_ptr, /* num_items */ num_cols * num_rows, /* num_segments */ num_rows, - /* d_begin_offsets */ segment_offsets_t.data(), - /* d_end_offsets */ segment_offsets_t.data() + 1, + /* d_begin_offsets */ segment_offsets_t, + /* d_end_offsets */ segment_offsets_t + 1, /* begin_bit */ 0, /* end_bit */ sizeof(T) * 8, /* stream */ cu_stream); |