author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-09-20 12:37:30 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-09-20 12:44:24 -0700
commit    4aa639c0cbb47f4707f735e0cc80f4c39506d928
tree      ee87e1b72e0c6a9482e1005f88706f7c193e7a33  /tensorflow/core/kernels/searchsorted_op.cc
parent    350effcc2fd95c723c92267cf13fcd38777a2a98
Add searchsorted (i.e. lower/upper bound) op.
PiperOrigin-RevId: 213863392
Diffstat (limited to 'tensorflow/core/kernels/searchsorted_op.cc')
-rw-r--r--  tensorflow/core/kernels/searchsorted_op.cc  249
1 file changed, 249 insertions, 0 deletions
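
The new ops implement a batched searchsorted: for each row of a sorted rank-2 tensor, they return the insertion index of each query value, where UpperBound yields the index of the first element strictly greater than the value and LowerBound the index of the first element not less than it. The CPU functors in the diff below delegate directly to std::upper_bound / std::lower_bound per batch row. The following standalone sketch mirrors that logic on plain vectors; the helper name and layout here are illustrative only, not part of the kernel:

// Minimal sketch of the batched searchsorted semantics implemented by the
// CPU functors in this file. sorted_inputs is a flattened [batch_size,
// num_inputs] array whose rows are individually sorted; values is a
// flattened [batch_size, num_values] array of queries.
#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical helper, not part of the TensorFlow kernel.
std::vector<int> BatchedUpperBound(const std::vector<float>& sorted_inputs,
                                   const std::vector<float>& values,
                                   int batch_size, int num_inputs,
                                   int num_values) {
  std::vector<int> output(batch_size * num_values);
  for (int b = 0; b < batch_size; ++b) {
    const float* row = sorted_inputs.data() + b * num_inputs;
    for (int i = 0; i < num_values; ++i) {
      // Index of the first element in the row strictly greater than the query.
      output[b * num_values + i] =
          std::upper_bound(row, row + num_inputs, values[b * num_values + i]) -
          row;
    }
  }
  return output;
}

int main() {
  // One batch row [1, 3, 3, 7] queried with 3 and 8.
  std::vector<float> sorted_inputs = {1.f, 3.f, 3.f, 7.f};
  std::vector<float> values = {3.f, 8.f};
  std::vector<int> out =
      BatchedUpperBound(sorted_inputs, values, /*batch_size=*/1,
                        /*num_inputs=*/4, /*num_values=*/2);
  printf("%d %d\n", out[0], out[1]);  // Prints "3 4".
  return 0;
}

Swapping std::upper_bound for std::lower_bound would print "1 4" instead (index of the first element not less than each query), which is precisely the difference between the two functors registered in the diff below.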
diff --git a/tensorflow/core/kernels/searchsorted_op.cc b/tensorflow/core/kernels/searchsorted_op.cc
new file mode 100644
index 0000000000..dc627ac77a
--- /dev/null
+++ b/tensorflow/core/kernels/searchsorted_op.cc
@@ -0,0 +1,249 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/searchsorted_op.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+template <typename T, typename OutType>
+struct UpperBoundFunctor<CPUDevice, T, OutType> {
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output) {
+ // TODO(eriche): If anyone ever needs this to be faster, we can multithread.
+ for (int b = 0; b < batch_size; ++b) {
+ const T* sorted_inputs_ptr = sorted_inputs.data() + b * num_inputs;
+ OutType* output_ptr = output->data() + b * num_values;
+ for (int i = 0; i < num_values; ++i) {
+ output_ptr[i] =
+ std::upper_bound(sorted_inputs_ptr, sorted_inputs_ptr + num_inputs,
+ values(i + b * num_values)) -
+ sorted_inputs_ptr;
+ }
+ }
+
+ return Status::OK();
+ }
+};
+
+template <typename T, typename OutType>
+struct LowerBoundFunctor<CPUDevice, T, OutType> {
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output) {
+ // TODO(eriche): If anyone ever needs this to be faster, we can multithread.
+ for (int b = 0; b < batch_size; ++b) {
+ const T* sorted_inputs_ptr = sorted_inputs.data() + b * num_inputs;
+ OutType* output_ptr = output->data() + b * num_values;
+ for (int i = 0; i < num_values; ++i) {
+ output_ptr[i] =
+ std::lower_bound(sorted_inputs_ptr, sorted_inputs_ptr + num_inputs,
+ values(i + b * num_values)) -
+ sorted_inputs_ptr;
+ }
+ }
+
+ return Status::OK();
+ }
+};
+} // namespace functor
+
+template <typename Device, typename T, typename OutType>
+class UpperBoundOp : public OpKernel {
+ public:
+ explicit UpperBoundOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& sorted_inputs_t = ctx->input(0);
+ const Tensor& values_t = ctx->input(1);
+
+ // Both tensors must have the same leading (batch) dim_size.
+ OP_REQUIRES(ctx, sorted_inputs_t.dim_size(0) == values_t.dim_size(0),
+ Status(error::INVALID_ARGUMENT,
+ "Leading dim_size of both tensors must match."));
+
+ // Required because the GPU kernel indexes with int32.
+ OP_REQUIRES(ctx, values_t.NumElements() < std::numeric_limits<int>::max(),
+ Status(error::INVALID_ARGUMENT,
+ "values tensor size must less than INT_MAX"));
+
+ Tensor* output_t;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, values_t.shape(), &output_t));
+
+ if (output_t->dtype() == DT_INT32) {
+ OP_REQUIRES(ctx,
+ FastBoundsCheck(sorted_inputs_t.dim_size(1),
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("trailing dim_size must be less than "
+ "INT_MAX for int32 output type, was ",
+ sorted_inputs_t.dim_size(1)));
+ }
+
+ auto output = output_t->template flat<OutType>();
+ const auto sorted_inputs = sorted_inputs_t.template flat<T>();
+ const auto values = values_t.template flat<T>();
+ OP_REQUIRES_OK(
+ ctx, functor::UpperBoundFunctor<Device, T, OutType>::Compute(
+ ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0),
+ sorted_inputs_t.dim_size(1), values_t.dim_size(1), &output));
+ }
+};
+
+template <typename Device, typename T, typename OutType>
+class LowerBoundOp : public OpKernel {
+ public:
+ explicit LowerBoundOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& sorted_inputs_t = ctx->input(0);
+ const Tensor& values_t = ctx->input(1);
+
+ // Both tensors must have the same leading (batch) dim_size.
+ OP_REQUIRES(ctx, sorted_inputs_t.dim_size(0) == values_t.dim_size(0),
+ Status(error::INVALID_ARGUMENT,
+ "Leading dim_size of both tensors must match."));
+
+ // Required because the GPU kernel indexes with int32.
+ OP_REQUIRES(ctx, values_t.NumElements() < std::numeric_limits<int>::max(),
+ Status(error::INVALID_ARGUMENT,
+ "values tensor size must less than INT_MAX"));
+
+ Tensor* output_t;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, values_t.shape(), &output_t));
+
+ if (output_t->dtype() == DT_INT32) {
+ OP_REQUIRES(ctx,
+ FastBoundsCheck(sorted_inputs_t.dim_size(1),
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("trailing dim_size must be less than "
+ "INT_MAX for int32 output type, was ",
+ sorted_inputs_t.dim_size(1)));
+ }
+
+ auto output = output_t->template flat<OutType>();
+ const auto sorted_inputs = sorted_inputs_t.template flat<T>();
+ const auto values = values_t.template flat<T>();
+ OP_REQUIRES_OK(
+ ctx, functor::LowerBoundFunctor<Device, T, OutType>::Compute(
+ ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0),
+ sorted_inputs_t.dim_size(1), values_t.dim_size(1), &output));
+ }
+};
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("UpperBound") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("out_type"), \
+ UpperBoundOp<CPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("UpperBound") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_type"), \
+ UpperBoundOp<CPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#if GOOGLE_CUDA
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("UpperBound") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("out_type"), \
+ UpperBoundOp<GPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("UpperBound") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_type"), \
+ UpperBoundOp<GPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#endif // GOOGLE_CUDA
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("LowerBound") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("out_type"), \
+ LowerBoundOp<CPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("LowerBound") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_type"), \
+ LowerBoundOp<CPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#if GOOGLE_CUDA
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("LowerBound") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("out_type"), \
+ LowerBoundOp<GPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("LowerBound") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_type"), \
+ LowerBoundOp<GPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#endif // GOOGLE_CUDA
+} // namespace tensorflow
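
The registration blocks above use the standard X-macro pattern from tensorflow/core/framework/register_types.h: TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS) expands REGISTER_KERNELS(type) once for every supported real numeric dtype, so each short macro definition yields one kernel registration per (device, T, out_type) combination. A simplified standalone illustration of the mechanism (the macro name and type list below are stand-ins, much shorter than the real ones):

#include <cstdio>

// Stand-in for TF_CALL_REAL_NUMBER_TYPES: applies the macro m to each type in
// a fixed list. The real list in register_types.h is considerably longer.
#define CALL_SOME_REAL_TYPES(m) m(float) m(double) m(int) m(long long)

// Stand-in for REGISTER_KERNELS: here it just prints the type name, whereas
// the real macro expands to a REGISTER_KERNEL_BUILDER call.
#define FAKE_REGISTER_KERNELS(type) \
  printf("registering kernel for %s\n", #type);

int main() {
  CALL_SOME_REAL_TYPES(FAKE_REGISTER_KERNELS)  // One line of output per type.
  return 0;
}

#undef FAKE_REGISTER_KERNELS
#undef CALL_SOME_REAL_TYPES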