Diffstat (limited to 'tensorflow/core/kernels/bucketize_op.cc')
-rw-r--r--  tensorflow/core/kernels/bucketize_op.cc  66
1 file changed, 50 insertions(+), 16 deletions(-)
diff --git a/tensorflow/core/kernels/bucketize_op.cc b/tensorflow/core/kernels/bucketize_op.cc
index 93c2d01221..c1693de538 100644
--- a/tensorflow/core/kernels/bucketize_op.cc
+++ b/tensorflow/core/kernels/bucketize_op.cc
@@ -15,15 +15,43 @@ limitations under the License.
// See docs in ../ops/math_ops.cc.
-#include <algorithm>
-#include <vector>
-
+#include "tensorflow/core/kernels/bucketize_op.h"
#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
+using thread::ThreadPool;
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
template <typename T>
+struct BucketizeFunctor<CPUDevice, T> {
+ // PRECONDITION: boundaries_vector must be sorted.
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& input,
+ const std::vector<float>& boundaries_vector,
+ typename TTypes<int32, 1>::Tensor& output) {
+ const int N = input.size();
+ for (int i = 0; i < N; i++) {
+ auto first_bigger_it = std::upper_bound(
+ boundaries_vector.begin(), boundaries_vector.end(), input(i));
+ output(i) = first_bigger_it - boundaries_vector.begin();
+ }
+
+ return Status::OK();
+ }
+};
+} // namespace functor
+
+template <typename Device, typename T>
class BucketizeOp : public OpKernel {
public:
explicit BucketizeOp(OpKernelConstruction* context) : OpKernel(context) {
@@ -34,36 +62,42 @@ class BucketizeOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input_tensor = context->input(0);
- auto input = input_tensor.flat<T>();
+ const auto input = input_tensor.flat<T>();
+
Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
&output_tensor));
auto output = output_tensor->template flat<int32>();
-
- const int N = input.size();
- for (int i = 0; i < N; i++) {
- output(i) = CalculateBucketIndex(input(i));
- }
+ OP_REQUIRES_OK(context, functor::BucketizeFunctor<Device, T>::Compute(
+ context, input, boundaries_, output));
}
private:
- int32 CalculateBucketIndex(const T value) {
- auto first_bigger_it =
- std::upper_bound(boundaries_.begin(), boundaries_.end(), value);
- return first_bigger_it - boundaries_.begin();
- }
std::vector<float> boundaries_;
};
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("Bucketize").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
- BucketizeOp<T>);
+ BucketizeOp<CPUDevice, T>);
+
+REGISTER_KERNEL(int32);
+REGISTER_KERNEL(int64);
+REGISTER_KERNEL(float);
+REGISTER_KERNEL(double);
+#undef REGISTER_KERNEL
+
+#if GOOGLE_CUDA
+#define REGISTER_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Bucketize").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+ BucketizeOp<GPUDevice, T>);
REGISTER_KERNEL(int32);
REGISTER_KERNEL(int64);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
#undef REGISTER_KERNEL
+#endif // GOOGLE_CUDA
} // namespace tensorflow
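
Not part of the patch: a minimal standalone C++ sketch of the bucket-index computation that the new CPU functor performs with std::upper_bound, shown here outside the TensorFlow kernel machinery. The boundary and input values are made up for illustration only.

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // Bucket boundaries must be sorted, matching the functor's precondition.
  const std::vector<float> boundaries = {0.0f, 10.0f, 100.0f};
  const float inputs[] = {-5.0f, 3.0f, 10.0f, 250.0f};

  for (float value : inputs) {
    // upper_bound returns the first boundary strictly greater than value;
    // its distance from begin() is the bucket index. Values below the first
    // boundary land in bucket 0, and values at or past the last boundary
    // land in bucket boundaries.size().
    auto first_bigger_it =
        std::upper_bound(boundaries.begin(), boundaries.end(), value);
    const int bucket = static_cast<int>(first_bigger_it - boundaries.begin());
    std::printf("%.1f -> bucket %d\n", value, bucket);
  }
  return 0;
}

With these illustrative boundaries the sketch prints -5.0 -> bucket 0, 3.0 -> bucket 1, 10.0 -> bucket 2, and 250.0 -> bucket 3, which is the same per-element mapping BucketizeFunctor<CPUDevice, T>::Compute applies across the flattened input tensor.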