diff options
Diffstat (limited to 'tensorflow/core/kernels/bincount_op.cc')
-rw-r--r-- | tensorflow/core/kernels/bincount_op.cc | 115 |
1 file changed, 43 insertions, 72 deletions
diff --git a/tensorflow/core/kernels/bincount_op.cc b/tensorflow/core/kernels/bincount_op.cc index 766d63e3be..1cd5943ef3 100644 --- a/tensorflow/core/kernels/bincount_op.cc +++ b/tensorflow/core/kernels/bincount_op.cc @@ -17,7 +17,6 @@ limitations under the License. #define EIGEN_USE_THREADS -#include "tensorflow/core/kernels/bincount_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/types.h" @@ -28,37 +27,46 @@ namespace tensorflow { using thread::ThreadPool; -typedef Eigen::ThreadPoolDevice CPUDevice; -typedef Eigen::GpuDevice GPUDevice; - -namespace functor { - template <typename T> -struct BincountFunctor<CPUDevice, T> { - static Status Compute(OpKernelContext* context, - const typename TTypes<int32, 1>::ConstTensor& arr, - const typename TTypes<T, 1>::ConstTensor& weights, - typename TTypes<T, 1>::Tensor& output) { - int size = output.size(); +class BincountOp : public OpKernel { + public: + explicit BincountOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& arr_t = ctx->input(0); + const Tensor& size_tensor = ctx->input(1); + const Tensor& weights_t = ctx->input(2); + int32 size = size_tensor.scalar<int32>()(); + OP_REQUIRES( + ctx, size >= 0, + errors::InvalidArgument("size (", size, ") must be non-negative")); + const bool has_weights = weights_t.NumElements() > 0; + OP_REQUIRES(ctx, !(has_weights && arr_t.shape() != weights_t.shape()), + errors::InvalidArgument( + "If weights are passed, they must have the same shape (" + + weights_t.shape().DebugString() + ") as arr (" + + arr_t.shape().DebugString() + ")")); + const auto arr = arr_t.flat<int32>(); + const auto weights = weights_t.flat<T>(); Tensor all_nonneg_t; - TF_RETURN_IF_ERROR(context->allocate_temp( - DT_BOOL, TensorShape({}), &all_nonneg_t, AllocatorAttributes())); - all_nonneg_t.scalar<bool>().device(context->eigen_cpu_device()) = + 
OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DT_BOOL, TensorShape({}), &all_nonneg_t, + AllocatorAttributes())); + all_nonneg_t.scalar<bool>().device(ctx->eigen_cpu_device()) = (arr >= 0).all(); - if (!all_nonneg_t.scalar<bool>()()) { - return errors::InvalidArgument("Input arr must be non-negative!"); - } + OP_REQUIRES(ctx, all_nonneg_t.scalar<bool>()(), + errors::InvalidArgument("Input arr must be non-negative!")); // Allocate partial output bin sums for each worker thread. Worker ids in // ParallelForWithWorkerId range from 0 to NumThreads() inclusive. ThreadPool* thread_pool = - context->device()->tensorflow_cpu_worker_threads()->workers; + ctx->device()->tensorflow_cpu_worker_threads()->workers; const int64 num_threads = thread_pool->NumThreads() + 1; Tensor partial_bins_t; - TF_RETURN_IF_ERROR(context->allocate_temp(DataTypeToEnum<T>::value, - TensorShape({num_threads, size}), - &partial_bins_t)); + OP_REQUIRES_OK(ctx, ctx->allocate_temp(weights_t.dtype(), + TensorShape({num_threads, size}), + &partial_bins_t)); auto partial_bins = partial_bins_t.matrix<T>(); partial_bins.setZero(); thread_pool->ParallelForWithWorkerId( @@ -67,7 +75,7 @@ struct BincountFunctor<CPUDevice, T> { for (int64 i = start_ind; i < limit_ind; i++) { int32 value = arr(i); if (value < size) { - if (weights.size()) { + if (has_weights) { partial_bins(worker_id, value) += weights(i); } else { // Complex numbers don't support "++". @@ -76,62 +84,25 @@ struct BincountFunctor<CPUDevice, T> { } } }); - + TensorShape output_shape({size}); + Tensor* output_t; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output_t)); // Sum the partial bins along the 0th axis. 
Eigen::array<int, 1> reduce_dims({0}); - output.device(context->eigen_cpu_device()) = partial_bins.sum(reduce_dims); - return Status::OK(); - } -}; - -} // namespace functor - -template <typename Device, typename T> -class BincountOp : public OpKernel { - public: - explicit BincountOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - const Tensor& arr_t = ctx->input(0); - const Tensor& size_tensor = ctx->input(1); - const Tensor& weights_t = ctx->input(2); - - int32 size = size_tensor.scalar<int32>()(); - OP_REQUIRES(ctx, size >= 0, errors::InvalidArgument( - "size (", size, ") must be non-negative")); - - const auto arr = arr_t.flat<int32>(); - const auto weights = weights_t.flat<T>(); - Tensor* output_t; - OP_REQUIRES_OK(ctx, - ctx->allocate_output(0, TensorShape({size}), &output_t)); - auto output = output_t->flat<T>(); - OP_REQUIRES_OK(ctx, functor::BincountFunctor<Device, T>::Compute( - ctx, arr, weights, output)); + output_t->flat<T>().device(ctx->eigen_cpu_device()) = + partial_bins.sum(reduce_dims); } }; -#define REGISTER_KERNELS(type) \ +#define REGISTER(TYPE) \ REGISTER_KERNEL_BUILDER( \ - Name("Bincount").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ - BincountOp<CPUDevice, type>) - -TF_CALL_NUMBER_TYPES(REGISTER_KERNELS); -#undef REGISTER_KERNELS - -#if GOOGLE_CUDA - -#define REGISTER_KERNELS(type) \ - REGISTER_KERNEL_BUILDER(Name("Bincount") \ - .Device(DEVICE_GPU) \ - .HostMemory("size") \ - .TypeConstraint<type>("T"), \ - BincountOp<GPUDevice, type>) + Name("Bincount").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \ + BincountOp<TYPE>) -TF_CALL_int32(REGISTER_KERNELS); -TF_CALL_float(REGISTER_KERNELS); -#undef REGISTER_KERNELS +TF_CALL_NUMBER_TYPES(REGISTER); -#endif // GOOGLE_CUDA +// TODO(ringwalt): Add a GPU implementation. We probably want to take a +// different approach, e.g. threads in a warp each taking a pass over the same +// data, and each thread summing a single bin. 
} // end namespace tensorflow |