about summary refs log tree commit diff homepage
path: root/tensorflow/core/kernels/bincount_op_gpu.cu.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/bincount_op_gpu.cu.cc')
-rw-r--r--  tensorflow/core/kernels/bincount_op_gpu.cu.cc | 114
1 file changed, 114 insertions(+), 0 deletions(-)
diff --git a/tensorflow/core/kernels/bincount_op_gpu.cu.cc b/tensorflow/core/kernels/bincount_op_gpu.cu.cc
new file mode 100644
index 0000000000..6074b3e1f6
--- /dev/null
+++ b/tensorflow/core/kernels/bincount_op_gpu.cu.cc
@@ -0,0 +1,114 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "external/cub_archive/cub/device/device_histogram.cuh"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/bincount_op.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+template <typename T>
+struct BincountFunctor<GPUDevice, T> {
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<int32, 1>::ConstTensor& arr,
+ const typename TTypes<T, 1>::ConstTensor& weights,
+ typename TTypes<T, 1>::Tensor& output) {
+ if (weights.size() != 0) {
+ return errors::InvalidArgument(
+ "Weights should not be passed as it should be "
+ "handled by unsorted_segment_sum");
+ }
+ if (output.size() == 0) {
+ return Status::OK();
+ }
+ // In case weight.size() == 0, use CUB
+ size_t temp_storage_bytes = 0;
+ const int32* d_samples = arr.data();
+ T* d_histogram = output.data();
+ int num_levels = output.size() + 1;
+ int32 lower_level = 0;
+ int32 upper_level = output.size();
+ int num_samples = arr.size();
+ const cudaStream_t& stream = GetCudaStream(context);
+
+ // The first HistogramEven is to obtain the temp storage size required
+ // with d_temp_storage = NULL passed to the call.
+ auto err = cub::DeviceHistogram::HistogramEven(
+ /* d_temp_storage */ NULL,
+ /* temp_storage_bytes */ temp_storage_bytes,
+ /* d_samples */ d_samples,
+ /* d_histogram */ d_histogram,
+ /* num_levels */ num_levels,
+ /* lower_level */ lower_level,
+ /* upper_level */ upper_level,
+ /* num_samples */ num_samples,
+ /* stream */ stream);
+ if (err != cudaSuccess) {
+ return errors::Internal(
+ "Could not launch HistogramEven to get temp storage: ",
+ cudaGetErrorString(err), ".");
+ }
+ Tensor temp_storage;
+ TF_RETURN_IF_ERROR(context->allocate_temp(
+ DataTypeToEnum<int8>::value,
+ TensorShape({static_cast<int64>(temp_storage_bytes)}), &temp_storage));
+
+ void* d_temp_storage = temp_storage.flat<int8>().data();
+ // The second HistogramEven is to actual run with d_temp_storage
+ // allocated with temp_storage_bytes.
+ err = cub::DeviceHistogram::HistogramEven(
+ /* d_temp_storage */ d_temp_storage,
+ /* temp_storage_bytes */ temp_storage_bytes,
+ /* d_samples */ d_samples,
+ /* d_histogram */ d_histogram,
+ /* num_levels */ num_levels,
+ /* lower_level */ lower_level,
+ /* upper_level */ upper_level,
+ /* num_samples */ num_samples,
+ /* stream */ stream);
+ if (err != cudaSuccess) {
+ return errors::Internal(
+ "Could not launch HistogramEven: ", cudaGetErrorString(err), ".");
+ }
+ return Status::OK();
+ }
+};
+
+} // end namespace functor
+
+#define REGISTER_GPU_SPEC(type) \
+ template struct functor::BincountFunctor<GPUDevice, type>;
+
+TF_CALL_int32(REGISTER_GPU_SPEC);
+TF_CALL_float(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA