Diffstat (limited to 'tensorflow/core/kernels/bincount_op.cc')
-rw-r--r--  tensorflow/core/kernels/bincount_op.cc | 115
1 file changed, 72 insertions(+), 43 deletions(-)
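
For context, the patch below moves the CPU computation out of BincountOp::Compute and into a device-specialized functor::BincountFunctor<Device, T>, so a GPU specialization (compiled separately under GOOGLE_CUDA) can share the same op class. A minimal standalone sketch of that dispatch pattern, using illustrative device tags rather than the real Eigen devices:

    // Sketch of the device-dispatch pattern (CpuDevice/GpuDevice here are
    // illustrative tags; the real code specializes on Eigen::ThreadPoolDevice
    // and Eigen::GpuDevice).
    #include <iostream>
    #include <vector>

    struct CpuDevice {};
    struct GpuDevice {};

    namespace functor {

    // Primary template is only declared; each device supplies a definition.
    template <typename Device, typename T>
    struct BincountFunctor;

    // CPU specialization: a serial stand-in for the threaded kernel in the
    // patch below.
    template <typename T>
    struct BincountFunctor<CpuDevice, T> {
      static void Compute(const std::vector<int>& arr, std::vector<T>* output) {
        for (int v : arr) {
          if (v >= 0 && v < static_cast<int>(output->size())) {
            (*output)[v] += T(1);
          }
        }
      }
    };

    }  // namespace functor

    int main() {
      std::vector<float> bins(4, 0.f);
      functor::BincountFunctor<CpuDevice, float>::Compute({1, 1, 3}, &bins);
      for (float b : bins) std::cout << b << " ";  // prints: 0 2 0 1
    }

The op class stays device-agnostic and merely forwards to the specialization, which is why the GPU port only needs a new functor definition plus a DEVICE_GPU kernel registration.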
diff --git a/tensorflow/core/kernels/bincount_op.cc b/tensorflow/core/kernels/bincount_op.cc
index 1cd5943ef3..766d63e3be 100644
--- a/tensorflow/core/kernels/bincount_op.cc
+++ b/tensorflow/core/kernels/bincount_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
+#include "tensorflow/core/kernels/bincount_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
@@ -27,46 +28,37 @@ namespace tensorflow {
using thread::ThreadPool;
-template <typename T>
-class BincountOp : public OpKernel {
- public:
- explicit BincountOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
- void Compute(OpKernelContext* ctx) override {
- const Tensor& arr_t = ctx->input(0);
- const Tensor& size_tensor = ctx->input(1);
- const Tensor& weights_t = ctx->input(2);
- int32 size = size_tensor.scalar<int32>()();
- OP_REQUIRES(
- ctx, size >= 0,
- errors::InvalidArgument("size (", size, ") must be non-negative"));
- const bool has_weights = weights_t.NumElements() > 0;
- OP_REQUIRES(ctx, !(has_weights && arr_t.shape() != weights_t.shape()),
- errors::InvalidArgument(
- "If weights are passed, they must have the same shape (" +
- weights_t.shape().DebugString() + ") as arr (" +
- arr_t.shape().DebugString() + ")"));
- const auto arr = arr_t.flat<int32>();
- const auto weights = weights_t.flat<T>();
+namespace functor {
+
+template <typename T>
+struct BincountFunctor<CPUDevice, T> {
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<int32, 1>::ConstTensor& arr,
+ const typename TTypes<T, 1>::ConstTensor& weights,
+ typename TTypes<T, 1>::Tensor& output) {
+ int size = output.size();
Tensor all_nonneg_t;
- OP_REQUIRES_OK(ctx,
- ctx->allocate_temp(DT_BOOL, TensorShape({}), &all_nonneg_t,
- AllocatorAttributes()));
- all_nonneg_t.scalar<bool>().device(ctx->eigen_cpu_device()) =
+ TF_RETURN_IF_ERROR(context->allocate_temp(
+ DT_BOOL, TensorShape({}), &all_nonneg_t, AllocatorAttributes()));
+ all_nonneg_t.scalar<bool>().device(context->eigen_cpu_device()) =
(arr >= 0).all();
- OP_REQUIRES(ctx, all_nonneg_t.scalar<bool>()(),
- errors::InvalidArgument("Input arr must be non-negative!"));
+ if (!all_nonneg_t.scalar<bool>()()) {
+ return errors::InvalidArgument("Input arr must be non-negative!");
+ }
// Allocate partial output bin sums for each worker thread. Worker ids in
// ParallelForWithWorkerId range from 0 to NumThreads() inclusive.
ThreadPool* thread_pool =
- ctx->device()->tensorflow_cpu_worker_threads()->workers;
+ context->device()->tensorflow_cpu_worker_threads()->workers;
const int64 num_threads = thread_pool->NumThreads() + 1;
Tensor partial_bins_t;
- OP_REQUIRES_OK(ctx, ctx->allocate_temp(weights_t.dtype(),
- TensorShape({num_threads, size}),
- &partial_bins_t));
+ TF_RETURN_IF_ERROR(context->allocate_temp(DataTypeToEnum<T>::value,
+ TensorShape({num_threads, size}),
+ &partial_bins_t));
auto partial_bins = partial_bins_t.matrix<T>();
partial_bins.setZero();
thread_pool->ParallelForWithWorkerId(
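
The hunk above is the heart of the CPU path: each worker thread accumulates into its own row of partial_bins (indexed by worker_id), and the rows are summed afterwards, so no atomics or locks guard the bins. A standalone sketch of that strategy using std::thread rather than TF's ThreadPool::ParallelForWithWorkerId (simplified chunking, unweighted):

    // Same partial-bins strategy with plain std::thread: each worker owns one
    // row, so no synchronization is needed while binning.
    #include <algorithm>
    #include <iostream>
    #include <thread>
    #include <vector>

    int main() {
      const std::vector<int> arr = {0, 1, 1, 2, 2, 2, 7};  // 7 >= size: dropped
      const int size = 4;
      const int num_threads = 2;
      std::vector<std::vector<float>> partial(num_threads,
                                              std::vector<float>(size, 0.f));
      const size_t chunk = (arr.size() + num_threads - 1) / num_threads;
      std::vector<std::thread> workers;
      for (int w = 0; w < num_threads; ++w) {
        workers.emplace_back([&, w] {
          const size_t begin = w * chunk;
          const size_t end = std::min(arr.size(), begin + chunk);
          for (size_t i = begin; i < end; ++i) {
            if (arr[i] < size) partial[w][arr[i]] += 1.f;  // unweighted case
          }
        });
      }
      for (auto& t : workers) t.join();
      std::vector<float> bins(size, 0.f);
      for (int w = 0; w < num_threads; ++w)    // reduce along the worker axis,
        for (int b = 0; b < size; ++b)         // like partial_bins.sum({0})
          bins[b] += partial[w][b];
      for (float b : bins) std::cout << b << " ";  // prints: 1 2 3 0
    }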
@@ -75,7 +67,7 @@ class BincountOp : public OpKernel {
for (int64 i = start_ind; i < limit_ind; i++) {
int32 value = arr(i);
if (value < size) {
- if (has_weights) {
+ if (weights.size()) {
partial_bins(worker_id, value) += weights(i);
} else {
// Complex numbers don't support "++".
@@ -84,25 +76,62 @@ class BincountOp : public OpKernel {
}
}
});
- TensorShape output_shape({size});
- Tensor* output_t;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output_t));
+
// Sum the partial bins along the 0th axis.
Eigen::array<int, 1> reduce_dims({0});
- output_t->flat<T>().device(ctx->eigen_cpu_device()) =
- partial_bins.sum(reduce_dims);
+ output.device(context->eigen_cpu_device()) = partial_bins.sum(reduce_dims);
+ return Status::OK();
+ }
+};
+
+} // namespace functor
+
+template <typename Device, typename T>
+class BincountOp : public OpKernel {
+ public:
+ explicit BincountOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& arr_t = ctx->input(0);
+ const Tensor& size_tensor = ctx->input(1);
+ const Tensor& weights_t = ctx->input(2);
+
+ int32 size = size_tensor.scalar<int32>()();
+ OP_REQUIRES(ctx, size >= 0, errors::InvalidArgument(
+ "size (", size, ") must be non-negative"));
+
+    const bool has_weights = weights_t.NumElements() > 0;
+    OP_REQUIRES(ctx, !(has_weights && arr_t.shape() != weights_t.shape()),
+                errors::InvalidArgument(
+                    "If weights are passed, they must have the same shape (" +
+                    weights_t.shape().DebugString() + ") as arr (" +
+                    arr_t.shape().DebugString() + ")"));
+
+ const auto arr = arr_t.flat<int32>();
+ const auto weights = weights_t.flat<T>();
+ Tensor* output_t;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({size}), &output_t));
+ auto output = output_t->flat<T>();
+ OP_REQUIRES_OK(ctx, functor::BincountFunctor<Device, T>::Compute(
+ ctx, arr, weights, output));
}
};
-#define REGISTER(TYPE) \
+#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
- Name("Bincount").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
- BincountOp<TYPE>)
+ Name("Bincount").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ BincountOp<CPUDevice, type>)
+
+TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#if GOOGLE_CUDA
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("Bincount") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("size") \
+ .TypeConstraint<type>("T"), \
+ BincountOp<GPUDevice, type>)
-TF_CALL_NUMBER_TYPES(REGISTER);
+TF_CALL_int32(REGISTER_KERNELS);
+TF_CALL_float(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
-// TODO(ringwalt): Add a GPU implementation. We probably want to take a
-// different approach, e.g. threads in a warp each taking a pass over the same
-// data, and each thread summing a single bin.
+#endif // GOOGLE_CUDA
} // end namespace tensorflow
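
For reference, the end-to-end semantics the registered kernel implements, as a serial sketch (BincountRef is a hypothetical name, and error handling is collapsed to a bool; the real op reports errors::InvalidArgument): each arr[i] < size accumulates weights[i], or 1 when weights are empty; values >= size are silently dropped; negative values are an error.

    #include <iostream>
    #include <vector>

    // Serial reference for Bincount: output[v] sums weights[i] (or 1) over all
    // i with arr[i] == v; out-of-range values are dropped, negatives rejected.
    template <typename T>
    bool BincountRef(const std::vector<int>& arr, const std::vector<T>& weights,
                     int size, std::vector<T>* output) {
      if (size < 0) return false;  // "size must be non-negative"
      const bool has_weights = !weights.empty();
      if (has_weights && weights.size() != arr.size()) return false;
      output->assign(size, T(0));
      for (size_t i = 0; i < arr.size(); ++i) {
        if (arr[i] < 0) return false;  // "Input arr must be non-negative!"
        if (arr[i] < size) (*output)[arr[i]] += has_weights ? weights[i] : T(1);
      }
      return true;
    }

    int main() {
      std::vector<float> out;
      if (!BincountRef<float>({1, 1, 2}, {0.5f, 0.5f, 2.f}, 3, &out)) return 1;
      for (float v : out) std::cout << v << " ";  // prints: 0 1 2
    }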