diff options
Diffstat (limited to 'tensorflow/core/kernels/bucketize_op.cc')
-rw-r--r-- | tensorflow/core/kernels/bucketize_op.cc | 66 |
1 file changed, 50 insertions, 16 deletions
diff --git a/tensorflow/core/kernels/bucketize_op.cc b/tensorflow/core/kernels/bucketize_op.cc
index 93c2d01221..c1693de538 100644
--- a/tensorflow/core/kernels/bucketize_op.cc
+++ b/tensorflow/core/kernels/bucketize_op.cc
@@ -15,15 +15,43 @@ limitations under the License.
 // See docs in ../ops/math_ops.cc.
-#include <algorithm>
-#include <vector>
-
+#include "tensorflow/core/kernels/bucketize_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
 namespace tensorflow {
+using thread::ThreadPool;
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+template <typename T>
+struct BucketizeFunctor<CPUDevice, T> {
+  // PRECONDITION: boundaries_vector must be sorted.
+  static Status Compute(OpKernelContext* context,
+                        const typename TTypes<T, 1>::ConstTensor& input,
+                        const std::vector<float>& boundaries_vector,
+                        typename TTypes<int32, 1>::Tensor& output) {
+    const int N = input.size();
+    for (int i = 0; i < N; i++) {
+      auto first_bigger_it = std::upper_bound(
+          boundaries_vector.begin(), boundaries_vector.end(), input(i));
+      output(i) = first_bigger_it - boundaries_vector.begin();
+    }
+
+    return Status::OK();
+  }
+};
+}  // namespace functor
+
+template <typename Device, typename T>
 class BucketizeOp : public OpKernel {
  public:
   explicit BucketizeOp(OpKernelConstruction* context) : OpKernel(context) {
@@ -34,36 +62,42 @@ class BucketizeOp : public OpKernel {
   void Compute(OpKernelContext* context) override {
     const Tensor& input_tensor = context->input(0);
-    auto input = input_tensor.flat<T>();
+    const auto input = input_tensor.flat<T>();
+
     Tensor* output_tensor = nullptr;
     OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                      &output_tensor));
     auto output = output_tensor->template flat<int32>();
-
-    const int N = input.size();
-    for (int i = 0; i < N; i++) {
-      output(i) = CalculateBucketIndex(input(i));
-    }
+    OP_REQUIRES_OK(context, functor::BucketizeFunctor<Device, T>::Compute(
+                                context, input, boundaries_, output));
   }
  private:
-  int32 CalculateBucketIndex(const T value) {
-    auto first_bigger_it =
-        std::upper_bound(boundaries_.begin(), boundaries_.end(), value);
-    return first_bigger_it - boundaries_.begin();
-  }
   std::vector<float> boundaries_;
 };
 #define REGISTER_KERNEL(T)                                            \
   REGISTER_KERNEL_BUILDER(                                            \
       Name("Bucketize").Device(DEVICE_CPU).TypeConstraint<T>("T"),    \
-      BucketizeOp<T>);
+      BucketizeOp<CPUDevice, T>);
+
+REGISTER_KERNEL(int32);
+REGISTER_KERNEL(int64);
+REGISTER_KERNEL(float);
+REGISTER_KERNEL(double);
+#undef REGISTER_KERNEL
+
+#if GOOGLE_CUDA
+#define REGISTER_KERNEL(T)                                            \
+  REGISTER_KERNEL_BUILDER(                                            \
+      Name("Bucketize").Device(DEVICE_GPU).TypeConstraint<T>("T"),    \
+      BucketizeOp<GPUDevice, T>);
 REGISTER_KERNEL(int32);
 REGISTER_KERNEL(int64);
 REGISTER_KERNEL(float);
 REGISTER_KERNEL(double);
 #undef REGISTER_KERNEL
+#endif  // GOOGLE_CUDA
 }  // namespace tensorflow