aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-20 12:37:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-20 12:44:24 -0700
commit4aa639c0cbb47f4707f735e0cc80f4c39506d928 (patch)
treeee87e1b72e0c6a9482e1005f88706f7c193e7a33 /tensorflow/core/util
parent350effcc2fd95c723c92267cf13fcd38777a2a98 (diff)
Add searchsorted (ie lower/upper bound) op.
PiperOrigin-RevId: 213863392
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r--tensorflow/core/util/cuda_kernel_helper.h31
1 files changed, 26 insertions, 5 deletions
diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h
index 540adb58d4..f6f0408ccc 100644
--- a/tensorflow/core/util/cuda_kernel_helper.h
+++ b/tensorflow/core/util/cuda_kernel_helper.h
@@ -93,11 +93,11 @@ __device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleXorSync(
}
namespace cuda_helper {
-template <typename IntType>
-__device__ IntType upper_bound(IntType* first, IntType count, IntType val) {
- IntType* orig = first;
- IntType* it = nullptr;
- IntType step = 0;
+template <typename T, typename OutType = int32>
+__device__ OutType upper_bound(const T* first, OutType count, T val) {
+ const T* orig = first;
+ const T* it = nullptr;
+ OutType step = 0;
while (count > 0) {
it = first;
step = count / 2;
@@ -112,6 +112,27 @@ __device__ IntType upper_bound(IntType* first, IntType count, IntType val) {
return first - orig;
}
+
+template <typename T, typename OutType = int32>
+__device__ OutType lower_bound(const T* first, OutType count, T val) {
+ const T* orig = first;
+ const T* it = nullptr;
+ OutType step = 0;
+ while (count > 0) {
+ it = first;
+ step = count / 2;
+ it += step;
+ if (*it < val) {
+ first = ++it;
+ count -= step + 1;
+ } else {
+ count = step;
+ }
+ }
+
+ return first - orig;
+}
+
} // namespace cuda_helper
} // namespace tensorflow