Diffstat (limited to 'tensorflow/core/kernels/bincount_op.cc')
-rw-r--r--  tensorflow/core/kernels/bincount_op.cc | 115
1 file changed, 43 insertions(+), 72 deletions(-)
diff --git a/tensorflow/core/kernels/bincount_op.cc b/tensorflow/core/kernels/bincount_op.cc
index 766d63e3be..1cd5943ef3 100644
--- a/tensorflow/core/kernels/bincount_op.cc
+++ b/tensorflow/core/kernels/bincount_op.cc
@@ -17,7 +17,6 @@ limitations under the License.
#define EIGEN_USE_THREADS
-#include "tensorflow/core/kernels/bincount_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
@@ -28,37 +27,46 @@ namespace tensorflow {
using thread::ThreadPool;
-typedef Eigen::ThreadPoolDevice CPUDevice;
-typedef Eigen::GpuDevice GPUDevice;
-
-namespace functor {
-
template <typename T>
-struct BincountFunctor<CPUDevice, T> {
- static Status Compute(OpKernelContext* context,
- const typename TTypes<int32, 1>::ConstTensor& arr,
- const typename TTypes<T, 1>::ConstTensor& weights,
- typename TTypes<T, 1>::Tensor& output) {
- int size = output.size();
+class BincountOp : public OpKernel {
+ public:
+ explicit BincountOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& arr_t = ctx->input(0);
+ const Tensor& size_tensor = ctx->input(1);
+ const Tensor& weights_t = ctx->input(2);
+ int32 size = size_tensor.scalar<int32>()();
+ OP_REQUIRES(
+ ctx, size >= 0,
+ errors::InvalidArgument("size (", size, ") must be non-negative"));
+ const bool has_weights = weights_t.NumElements() > 0;
+ OP_REQUIRES(ctx, !(has_weights && arr_t.shape() != weights_t.shape()),
+ errors::InvalidArgument(
+ "If weights are passed, they must have the same shape (" +
+ weights_t.shape().DebugString() + ") as arr (" +
+ arr_t.shape().DebugString() + ")"));
+ const auto arr = arr_t.flat<int32>();
+ const auto weights = weights_t.flat<T>();
Tensor all_nonneg_t;
- TF_RETURN_IF_ERROR(context->allocate_temp(
- DT_BOOL, TensorShape({}), &all_nonneg_t, AllocatorAttributes()));
- all_nonneg_t.scalar<bool>().device(context->eigen_cpu_device()) =
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_temp(DT_BOOL, TensorShape({}), &all_nonneg_t,
+ AllocatorAttributes()));
+ all_nonneg_t.scalar<bool>().device(ctx->eigen_cpu_device()) =
(arr >= 0).all();
- if (!all_nonneg_t.scalar<bool>()()) {
- return errors::InvalidArgument("Input arr must be non-negative!");
- }
+ OP_REQUIRES(ctx, all_nonneg_t.scalar<bool>()(),
+ errors::InvalidArgument("Input arr must be non-negative!"));
// Allocate partial output bin sums for each worker thread. Worker ids in
// ParallelForWithWorkerId range from 0 to NumThreads() inclusive.
ThreadPool* thread_pool =
- context->device()->tensorflow_cpu_worker_threads()->workers;
+ ctx->device()->tensorflow_cpu_worker_threads()->workers;
const int64 num_threads = thread_pool->NumThreads() + 1;
Tensor partial_bins_t;
- TF_RETURN_IF_ERROR(context->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({num_threads, size}),
- &partial_bins_t));
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(weights_t.dtype(),
+ TensorShape({num_threads, size}),
+ &partial_bins_t));
auto partial_bins = partial_bins_t.matrix<T>();
partial_bins.setZero();
thread_pool->ParallelForWithWorkerId(
@@ -67,7 +75,7 @@ struct BincountFunctor<CPUDevice, T> {
for (int64 i = start_ind; i < limit_ind; i++) {
int32 value = arr(i);
if (value < size) {
- if (weights.size()) {
+ if (has_weights) {
partial_bins(worker_id, value) += weights(i);
} else {
// Complex numbers don't support "++".
@@ -76,62 +84,25 @@ struct BincountFunctor<CPUDevice, T> {
}
}
});
-
+ TensorShape output_shape({size});
+ Tensor* output_t;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output_t));
// Sum the partial bins along the 0th axis.
Eigen::array<int, 1> reduce_dims({0});
- output.device(context->eigen_cpu_device()) = partial_bins.sum(reduce_dims);
- return Status::OK();
- }
-};
-
-} // namespace functor
-
-template <typename Device, typename T>
-class BincountOp : public OpKernel {
- public:
- explicit BincountOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
- void Compute(OpKernelContext* ctx) override {
- const Tensor& arr_t = ctx->input(0);
- const Tensor& size_tensor = ctx->input(1);
- const Tensor& weights_t = ctx->input(2);
-
- int32 size = size_tensor.scalar<int32>()();
- OP_REQUIRES(ctx, size >= 0, errors::InvalidArgument(
- "size (", size, ") must be non-negative"));
-
- const auto arr = arr_t.flat<int32>();
- const auto weights = weights_t.flat<T>();
- Tensor* output_t;
- OP_REQUIRES_OK(ctx,
- ctx->allocate_output(0, TensorShape({size}), &output_t));
- auto output = output_t->flat<T>();
- OP_REQUIRES_OK(ctx, functor::BincountFunctor<Device, T>::Compute(
- ctx, arr, weights, output));
+ output_t->flat<T>().device(ctx->eigen_cpu_device()) =
+ partial_bins.sum(reduce_dims);
}
};
-#define REGISTER_KERNELS(type) \
+#define REGISTER(TYPE) \
REGISTER_KERNEL_BUILDER( \
- Name("Bincount").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- BincountOp<CPUDevice, type>)
-
-TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
-#undef REGISTER_KERNELS
-
-#if GOOGLE_CUDA
-
-#define REGISTER_KERNELS(type) \
- REGISTER_KERNEL_BUILDER(Name("Bincount") \
- .Device(DEVICE_GPU) \
- .HostMemory("size") \
- .TypeConstraint<type>("T"), \
- BincountOp<GPUDevice, type>)
+ Name("Bincount").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
+ BincountOp<TYPE>)
-TF_CALL_int32(REGISTER_KERNELS);
-TF_CALL_float(REGISTER_KERNELS);
-#undef REGISTER_KERNELS
+TF_CALL_NUMBER_TYPES(REGISTER);
-#endif // GOOGLE_CUDA
+// TODO(ringwalt): Add a GPU implementation. We probably want to take a
+// different approach, e.g. threads in a warp each taking a pass over the same
+// data, and each thread summing a single bin.
} // end namespace tensorflow
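
The heart of the new CPU kernel above is a lock-free histogram: each worker in ParallelForWithWorkerId accumulates into its own row of partial_bins, and the rows are summed along the 0th axis afterwards, so no atomics or mutexes are needed. Below is a minimal standalone C++ sketch of the same technique. It is not TensorFlow's code: the function name Bincount, the fixed kNumThreads, and the contiguous chunking are illustrative assumptions (the kernel sizes its partial bins from the device thread pool and lets Eigen do the final reduction).

  #include <algorithm>
  #include <cstddef>
  #include <cstdint>
  #include <thread>
  #include <vector>

  // Sketch: per-thread partial histograms merged by a final sum, mirroring
  // the partial_bins_t / partial_bins.sum(reduce_dims) pattern in the diff.
  std::vector<float> Bincount(const std::vector<int32_t>& arr,
                              const std::vector<float>& weights, int size) {
    const bool has_weights = !weights.empty();
    const int kNumThreads = 4;  // Illustrative; TF uses its worker pool size.
    // One private row of bins per worker, so writes never race.
    std::vector<std::vector<float>> partial_bins(
        kNumThreads, std::vector<float>(size, 0.0f));
    const std::size_t chunk = (arr.size() + kNumThreads - 1) / kNumThreads;
    std::vector<std::thread> workers;
    for (int w = 0; w < kNumThreads; ++w) {
      workers.emplace_back([&, w] {
        const std::size_t start = w * chunk;
        const std::size_t limit = std::min(arr.size(), start + chunk);
        for (std::size_t i = start; i < limit; ++i) {
          const int32_t value = arr[i];
          // The kernel rejects negative values up front; here we guard
          // inline. Like the kernel, values >= size are silently dropped.
          if (value >= 0 && value < size) {
            partial_bins[w][value] += has_weights ? weights[i] : 1.0f;
          }
        }
      });
    }
    for (std::thread& t : workers) t.join();
    // Sum the partial bins along the 0th axis.
    std::vector<float> output(size, 0.0f);
    for (int w = 0; w < kNumThreads; ++w) {
      for (int b = 0; b < size; ++b) output[b] += partial_bins[w][b];
    }
    return output;
  }

For arr = {1, 2, 2, 4}, size = 5, and empty weights, this returns {0, 1, 2, 0, 1}; with weights it returns per-bin weight sums instead of counts. (The kernel's "+= T(1)" in the unweighted path exists because complex types don't support "++", as its comment notes.)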