diff options
Diffstat (limited to 'tensorflow/core/kernels/bincount_op.cc')
-rw-r--r-- | tensorflow/core/kernels/bincount_op.cc | 115 |
1 files changed, 72 insertions, 43 deletions
diff --git a/tensorflow/core/kernels/bincount_op.cc b/tensorflow/core/kernels/bincount_op.cc index 1cd5943ef3..766d63e3be 100644 --- a/tensorflow/core/kernels/bincount_op.cc +++ b/tensorflow/core/kernels/bincount_op.cc @@ -17,6 +17,7 @@ limitations under the License. #define EIGEN_USE_THREADS +#include "tensorflow/core/kernels/bincount_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/types.h" @@ -27,46 +28,37 @@ namespace tensorflow { using thread::ThreadPool; -template <typename T> -class BincountOp : public OpKernel { - public: - explicit BincountOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; - void Compute(OpKernelContext* ctx) override { - const Tensor& arr_t = ctx->input(0); - const Tensor& size_tensor = ctx->input(1); - const Tensor& weights_t = ctx->input(2); - int32 size = size_tensor.scalar<int32>()(); - OP_REQUIRES( - ctx, size >= 0, - errors::InvalidArgument("size (", size, ") must be non-negative")); - const bool has_weights = weights_t.NumElements() > 0; - OP_REQUIRES(ctx, !(has_weights && arr_t.shape() != weights_t.shape()), - errors::InvalidArgument( - "If weights are passed, they must have the same shape (" + - weights_t.shape().DebugString() + ") as arr (" + - arr_t.shape().DebugString() + ")")); - const auto arr = arr_t.flat<int32>(); - const auto weights = weights_t.flat<T>(); +namespace functor { + +template <typename T> +struct BincountFunctor<CPUDevice, T> { + static Status Compute(OpKernelContext* context, + const typename TTypes<int32, 1>::ConstTensor& arr, + const typename TTypes<T, 1>::ConstTensor& weights, + typename TTypes<T, 1>::Tensor& output) { + int size = output.size(); Tensor all_nonneg_t; - OP_REQUIRES_OK(ctx, - ctx->allocate_temp(DT_BOOL, TensorShape({}), &all_nonneg_t, - AllocatorAttributes())); - all_nonneg_t.scalar<bool>().device(ctx->eigen_cpu_device()) = + TF_RETURN_IF_ERROR(context->allocate_temp( + DT_BOOL, TensorShape({}), &all_nonneg_t, AllocatorAttributes())); + all_nonneg_t.scalar<bool>().device(context->eigen_cpu_device()) = (arr >= 0).all(); - OP_REQUIRES(ctx, all_nonneg_t.scalar<bool>()(), - errors::InvalidArgument("Input arr must be non-negative!")); + if (!all_nonneg_t.scalar<bool>()()) { + return errors::InvalidArgument("Input arr must be non-negative!"); + } // Allocate partial output bin sums for each worker thread. Worker ids in // ParallelForWithWorkerId range from 0 to NumThreads() inclusive. ThreadPool* thread_pool = - ctx->device()->tensorflow_cpu_worker_threads()->workers; + context->device()->tensorflow_cpu_worker_threads()->workers; const int64 num_threads = thread_pool->NumThreads() + 1; Tensor partial_bins_t; - OP_REQUIRES_OK(ctx, ctx->allocate_temp(weights_t.dtype(), - TensorShape({num_threads, size}), - &partial_bins_t)); + TF_RETURN_IF_ERROR(context->allocate_temp(DataTypeToEnum<T>::value, + TensorShape({num_threads, size}), + &partial_bins_t)); auto partial_bins = partial_bins_t.matrix<T>(); partial_bins.setZero(); thread_pool->ParallelForWithWorkerId( @@ -75,7 +67,7 @@ class BincountOp : public OpKernel { for (int64 i = start_ind; i < limit_ind; i++) { int32 value = arr(i); if (value < size) { - if (has_weights) { + if (weights.size()) { partial_bins(worker_id, value) += weights(i); } else { // Complex numbers don't support "++". @@ -84,25 +76,62 @@ class BincountOp : public OpKernel { } } }); - TensorShape output_shape({size}); - Tensor* output_t; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output_t)); + // Sum the partial bins along the 0th axis. Eigen::array<int, 1> reduce_dims({0}); - output_t->flat<T>().device(ctx->eigen_cpu_device()) = - partial_bins.sum(reduce_dims); + output.device(context->eigen_cpu_device()) = partial_bins.sum(reduce_dims); + return Status::OK(); + } +}; + +} // namespace functor + +template <typename Device, typename T> +class BincountOp : public OpKernel { + public: + explicit BincountOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& arr_t = ctx->input(0); + const Tensor& size_tensor = ctx->input(1); + const Tensor& weights_t = ctx->input(2); + + int32 size = size_tensor.scalar<int32>()(); + OP_REQUIRES(ctx, size >= 0, errors::InvalidArgument( + "size (", size, ") must be non-negative")); + + const auto arr = arr_t.flat<int32>(); + const auto weights = weights_t.flat<T>(); + Tensor* output_t; + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, TensorShape({size}), &output_t)); + auto output = output_t->flat<T>(); + OP_REQUIRES_OK(ctx, functor::BincountFunctor<Device, T>::Compute( + ctx, arr, weights, output)); } }; -#define REGISTER(TYPE) \ +#define REGISTER_KERNELS(type) \ REGISTER_KERNEL_BUILDER( \ - Name("Bincount").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \ - BincountOp<TYPE>) + Name("Bincount").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + BincountOp<CPUDevice, type>) + +TF_CALL_NUMBER_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +#if GOOGLE_CUDA + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("Bincount") \ + .Device(DEVICE_GPU) \ + .HostMemory("size") \ + .TypeConstraint<type>("T"), \ + BincountOp<GPUDevice, type>) -TF_CALL_NUMBER_TYPES(REGISTER); +TF_CALL_int32(REGISTER_KERNELS); +TF_CALL_float(REGISTER_KERNELS); +#undef REGISTER_KERNELS -// TODO(ringwalt): Add a GPU implementation. We probably want to take a -// different approach, e.g. threads in a warp each taking a pass over the same -// data, and each thread summing a single bin. +#endif // GOOGLE_CUDA } // end namespace tensorflow |