aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/topk_op_gpu.cu.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-08-17 00:51:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-17 00:55:36 -0700
commit8535a56825361a80127da9165d27d6298819182c (patch)
tree5dba7ab8654882ccff6b0f1c7a55d6844e54fd53 /tensorflow/core/kernels/topk_op_gpu.cu.cc
parent337b12f11e9ef8710a326f105e3b26bb4d8fc0bf (diff)
Update top_k implementation to use iterators for segment offsets
PiperOrigin-RevId: 165549078
Diffstat (limited to 'tensorflow/core/kernels/topk_op_gpu.cu.cc')
-rw-r--r--tensorflow/core/kernels/topk_op_gpu.cu.cc32
1 files changed, 16 insertions, 16 deletions
diff --git a/tensorflow/core/kernels/topk_op_gpu.cu.cc b/tensorflow/core/kernels/topk_op_gpu.cu.cc
index 4b3f5ccacc..10a7602dc4 100644
--- a/tensorflow/core/kernels/topk_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/topk_op_gpu.cu.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "external/cub_archive/cub/device/device_segmented_radix_sort.cuh"
#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
+#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -404,11 +405,13 @@ cudaError LaunchTopKKernel(const cudaStream_t& stream, int num_shards,
}
struct SegmentOffsetCreator {
+ EIGEN_DEVICE_FUNC
SegmentOffsetCreator(int num_cols) : num_cols_(num_cols) {}
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(
- const Eigen::array<int, 1>& ix) const {
- return ix[0] * num_cols_;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
+ return idx * num_cols_;
};
+
int num_cols_;
};
@@ -432,9 +435,8 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
const cudaStream_t& cu_stream = GetCudaStream(ctx);
size_t temp_storage_bytes = -1;
- // TODO(ebrevdo): Once cub supports iterators for the ValueT and
- // segment_offsets, replace these tensors with iterators that
- // directly return the correct value.
+ // TODO(ebrevdo): Once cub supports iterators for ValueT replace that tensor
+ // with an iterator that directly returns the correct value.
Tensor input_indices;
TF_RETURN_IF_ERROR(ctx->allocate_temp(
DT_INT32, TensorShape({num_rows, num_cols}), &input_indices));
@@ -442,12 +444,10 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
input_indices_t.device(d) =
input_indices_t.generate(ColumnIndexCreator(num_cols));
- Tensor segment_offsets;
- TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_INT32, TensorShape({num_rows + 1}),
- &segment_offsets));
- auto segment_offsets_t = To32Bit(segment_offsets.flat<int32>());
- segment_offsets_t.device(d) =
- segment_offsets_t.generate(SegmentOffsetCreator(num_cols));
+ cub::CountingInputIterator<int> counting_iter(0);
+ cub::TransformInputIterator<int, SegmentOffsetCreator,
+ cub::CountingInputIterator<int>>
+ segment_offsets_t(counting_iter, SegmentOffsetCreator(num_cols));
Tensor temp_values;
Tensor temp_indices;
@@ -477,8 +477,8 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
/* d_values_out */ sorted_indices_ptr,
/* num_items */ num_cols * num_rows,
/* num_segments */ num_rows,
- /* d_begin_offsets */ segment_offsets_t.data(),
- /* d_end_offsets */ segment_offsets_t.data() + 1,
+ /* d_begin_offsets */ segment_offsets_t,
+ /* d_end_offsets */ segment_offsets_t + 1,
/* begin_bit */ 0,
/* end_bit */ sizeof(T) * 8,
/* stream */ cu_stream);
@@ -502,8 +502,8 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
/* d_values_out */ sorted_indices_ptr,
/* num_items */ num_cols * num_rows,
/* num_segments */ num_rows,
- /* d_begin_offsets */ segment_offsets_t.data(),
- /* d_end_offsets */ segment_offsets_t.data() + 1,
+ /* d_begin_offsets */ segment_offsets_t,
+ /* d_end_offsets */ segment_offsets_t + 1,
/* begin_bit */ 0,
/* end_bit */ sizeof(T) * 8,
/* stream */ cu_stream);