aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/array_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/array_ops.py')
-rw-r--r--tensorflow/python/ops/array_ops.py61
1 files changed, 61 insertions, 0 deletions
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index c8b883350d..a7f57e94e3 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -2787,4 +2787,65 @@ def quantize(input, # pylint: disable=redefined-builtin
name=name)
+@tf_export("searchsorted")
+def searchsorted(sorted_sequence,
+ values,
+ side="left",
+ out_type=dtypes.int32,
+ name=None):
+ """Searches input tensor for values on the innermost dimension.
+
+ A 2-D example:
+
+ ```
+ sorted_sequence = [[0, 3, 9, 9, 10],
+ [1, 2, 3, 4, 5]]
+ values = [[2, 4, 9],
+ [0, 2, 6]]
+
+ result = searchsorted(sorted_sequence, values, side="left")
+
+ result == [[1, 2, 2],
+ [0, 1, 5]]
+
+ result = searchsorted(sorted_sequence, values, side="right")
+
+ result == [[1, 2, 4],
+ [0, 2, 5]]
+ ```
+
+ Args:
+ sorted_sequence: N-D `Tensor` containing a sorted sequence.
+ values: N-D `Tensor` containing the search values.
+ side: 'left' or 'right'; 'left' corresponds to lower_bound and 'right' to
+ upper_bound.
+ out_type: The output type (`int32` or `int64`). Default is `tf.int32`.
+ name: Optional name for the operation.
+
+ Returns:
+ An N-D `Tensor` the size of values containing the result of applying either
+ lower_bound or upper_bound (depending on side) to each value. The result
+ is not a global index to the entire `Tensor`, but the index in the last
+ dimension.
+
+ Raises:
+ ValueError: If the last dimension of `sorted_sequence >= 2^31-1` elements.
+ If the total size of values exceeds `2^31 - 1` elements.
+ If the first `N-1` dimensions of the two tensors don't match.
+ """
+ sequence_size = shape_internal(sorted_sequence)[-1]
+ values_size = shape_internal(values)[-1]
+ sorted_sequence_2d = reshape(sorted_sequence, [-1, sequence_size])
+ values_2d = reshape(values, [-1, values_size])
+ if side == "right":
+ output = gen_array_ops.upper_bound(sorted_sequence_2d, values_2d, out_type,
+ name)
+ elif side == "left":
+ output = gen_array_ops.lower_bound(sorted_sequence_2d, values_2d, out_type,
+ name)
+ else:
+ raise ValueError("side must be either 'right' or 'left'. Saw: %s." % side)
+ return reshape(output, shape_internal(values))
+
+
quantize.__doc__ = gen_array_ops.quantize_v2.__doc__