diff options
Diffstat (limited to 'tensorflow/core/kernels/bucketize_op.cc')
-rw-r--r-- | tensorflow/core/kernels/bucketize_op.cc | 66 |
1 files changed, 16 insertions, 50 deletions
diff --git a/tensorflow/core/kernels/bucketize_op.cc b/tensorflow/core/kernels/bucketize_op.cc index c1693de538..93c2d01221 100644 --- a/tensorflow/core/kernels/bucketize_op.cc +++ b/tensorflow/core/kernels/bucketize_op.cc @@ -15,43 +15,15 @@ limitations under the License. // See docs in ../ops/math_ops.cc. -#include "tensorflow/core/kernels/bucketize_op.h" +#include <algorithm> +#include <vector> + #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { -using thread::ThreadPool; - -typedef Eigen::ThreadPoolDevice CPUDevice; -typedef Eigen::GpuDevice GPUDevice; - -namespace functor { - template <typename T> -struct BucketizeFunctor<CPUDevice, T> { - // PRECONDITION: boundaries_vector must be sorted. - static Status Compute(OpKernelContext* context, - const typename TTypes<T, 1>::ConstTensor& input, - const std::vector<float>& boundaries_vector, - typename TTypes<int32, 1>::Tensor& output) { - const int N = input.size(); - for (int i = 0; i < N; i++) { - auto first_bigger_it = std::upper_bound( - boundaries_vector.begin(), boundaries_vector.end(), input(i)); - output(i) = first_bigger_it - boundaries_vector.begin(); - } - - return Status::OK(); - } -}; -} // namespace functor - -template <typename Device, typename T> class BucketizeOp : public OpKernel { public: explicit BucketizeOp(OpKernelConstruction* context) : OpKernel(context) { @@ -62,42 +34,36 @@ class BucketizeOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input_tensor = context->input(0); - const auto input = input_tensor.flat<T>(); - + auto input = input_tensor.flat<T>(); Tensor* output_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); auto output = output_tensor->template flat<int32>(); - OP_REQUIRES_OK(context, functor::BucketizeFunctor<Device, T>::Compute( - context, input, boundaries_, output)); + + const int N = input.size(); + for (int i = 0; i < N; i++) { + output(i) = CalculateBucketIndex(input(i)); + } } private: + int32 CalculateBucketIndex(const T value) { + auto first_bigger_it = + std::upper_bound(boundaries_.begin(), boundaries_.end(), value); + return first_bigger_it - boundaries_.begin(); + } std::vector<float> boundaries_; }; #define REGISTER_KERNEL(T) \ REGISTER_KERNEL_BUILDER( \ Name("Bucketize").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ - BucketizeOp<CPUDevice, T>); - -REGISTER_KERNEL(int32); -REGISTER_KERNEL(int64); -REGISTER_KERNEL(float); -REGISTER_KERNEL(double); -#undef REGISTER_KERNEL - -#if GOOGLE_CUDA -#define REGISTER_KERNEL(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("Bucketize").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ - BucketizeOp<GPUDevice, T>); + BucketizeOp<T>); REGISTER_KERNEL(int32); REGISTER_KERNEL(int64); REGISTER_KERNEL(float); REGISTER_KERNEL(double); #undef REGISTER_KERNEL -#endif // GOOGLE_CUDA } // namespace tensorflow |