diff options
Diffstat (limited to 'tensorflow/python/ops/array_ops.py')
-rw-r--r-- | tensorflow/python/ops/array_ops.py | 61 |
1 files changed, 61 insertions, 0 deletions
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index c8b883350d..a7f57e94e3 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -2787,4 +2787,65 @@ def quantize(input, # pylint: disable=redefined-builtin name=name) +@tf_export("searchsorted") +def searchsorted(sorted_sequence, + values, + side="left", + out_type=dtypes.int32, + name=None): + """Searches input tensor for values on the innermost dimension. + + A 2-D example: + + ``` + sorted_sequence = [[0, 3, 9, 9, 10], + [1, 2, 3, 4, 5]] + values = [[2, 4, 9], + [0, 2, 6]] + + result = searchsorted(sorted_sequence, values, side="left") + + result == [[1, 2, 2], + [0, 1, 5]] + + result = searchsorted(sorted_sequence, values, side="right") + + result == [[1, 2, 4], + [0, 2, 5]] + ``` + + Args: + sorted_sequence: N-D `Tensor` containing a sorted sequence. + values: N-D `Tensor` containing the search values. + side: 'left' or 'right'; 'left' corresponds to lower_bound and 'right' to + upper_bound. + out_type: The output type (`int32` or `int64`). Default is `tf.int32`. + name: Optional name for the operation. + + Returns: + An N-D `Tensor` the size of values containing the result of applying either + lower_bound or upper_bound (depending on side) to each value. The result + is not a global index to the entire `Tensor`, but the index in the last + dimension. + + Raises: + ValueError: If the last dimension of `sorted_sequence >= 2^31-1` elements. + If the total size of values exceeds `2^31 - 1` elements. + If the first `N-1` dimensions of the two tensors don't match. + """ + sequence_size = shape_internal(sorted_sequence)[-1] + values_size = shape_internal(values)[-1] + sorted_sequence_2d = reshape(sorted_sequence, [-1, sequence_size]) + values_2d = reshape(values, [-1, values_size]) + if side == "right": + output = gen_array_ops.upper_bound(sorted_sequence_2d, values_2d, out_type, + name) + elif side == "left": + output = gen_array_ops.lower_bound(sorted_sequence_2d, values_2d, out_type, + name) + else: + raise ValueError("side must be either 'right' or 'left'. Saw: %s." % side) + return reshape(output, shape_internal(values)) + + quantize.__doc__ = gen_array_ops.quantize_v2.__doc__ |