aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/framework
diff options
context:
space:
mode:
authorGravatar Dan Ringwalt <ringwalt@google.com>2018-03-22 12:45:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-22 12:48:29 -0700
commit830fc390b76b5eb138a7f59d0e13e83add653870 (patch)
treebc6af9e62d3d4eea0f96a173abaed2d042bf921e /tensorflow/contrib/framework
parentc7d11e1601d5045f5421c465a438a1d9632df78d (diff)
Add tf.contrib.framework.argsort, wrapping tf.nn.top_k (#288).
Comparable to np.argsort. PiperOrigin-RevId: 190109968
Diffstat (limited to 'tensorflow/contrib/framework')
-rw-r--r--tensorflow/contrib/framework/__init__.py1
-rw-r--r--tensorflow/contrib/framework/python/ops/sort_ops.py161
-rw-r--r--tensorflow/contrib/framework/python/ops/sort_ops_test.py34
3 files changed, 156 insertions, 40 deletions
diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py
index 3398b3fd1c..cbb68bd3eb 100644
--- a/tensorflow/contrib/framework/__init__.py
+++ b/tensorflow/contrib/framework/__init__.py
@@ -83,6 +83,7 @@ See the @{$python/contrib.framework} guide.
@@load_linear_multiclass_bias_initializer
@@load_variable_slot_initializer
+@@argsort
@@py_func
@@sort
diff --git a/tensorflow/contrib/framework/python/ops/sort_ops.py b/tensorflow/contrib/framework/python/ops/sort_ops.py
index 8f62f0ea7b..1921a77c1e 100644
--- a/tensorflow/contrib/framework/python/ops/sort_ops.py
+++ b/tensorflow/contrib/framework/python/ops/sort_ops.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""Support for sorting tensors.
+@@argsort
@@sort
"""
@@ -21,6 +22,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops as framework_ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
@@ -47,64 +51,141 @@ def sort(values, axis=-1, direction='ASCENDING', name=None):
ValueError: If axis is not a constant scalar, or the direction is invalid.
"""
with framework_ops.name_scope(name, 'sort'):
- if direction not in _SORT_IMPL:
- raise ValueError('%s should be one of %s' %
- (direction, ', '.join(sorted(_SORT_IMPL.keys()))))
- # Axis must be an integer, not a Tensor.
- axis = framework_ops.convert_to_tensor(axis, name='axis')
- axis_static = tensor_util.constant_value(axis)
- if axis.shape.ndims != 0 or axis_static is None:
- raise ValueError('axis must be a constant scalar')
- axis_static = int(axis_static) # Avoids NumPy casting error
+ return _sort_or_argsort(values, axis, direction, return_argsort=False)
+
+
+def argsort(values, axis=-1, direction='ASCENDING', stable=False, name=None):
+ """Returns the indices of a tensor that give its sorted order along an axis.
+
+ For a 1D tensor, `tf.gather(values, tf.argsort(values))` is equivalent to
+ `tf.sort(values)`. For higher dimensions, the output has the same shape as
+ `values`, but along the given axis, values represent the index of the sorted
+ element in that slice of the tensor at the given position.
+
+ Args:
+ values: 1-D or higher numeric `Tensor`.
+ axis: The axis along which to sort. The default is -1, which sorts the last
+ axis.
+ direction: The direction in which to sort the values (`'ASCENDING'` or
+ `'DESCENDING'`).
+ stable: If True, equal elements in the original tensor will not be
+ re-ordered in the returned order. Unstable sort is not yet implemented,
+ but will eventually be the default for performance reasons. If you
+ require a stable order, pass `stable=True` for forwards compatibility.
+ name: Optional name for the operation.
+
+ Returns:
+ An int32 `Tensor` with the same shape as `values`. The indices that would
+ sort each slice of the given `values` along the given `axis`.
+
+ Raises:
+ ValueError: If axis is not a constant scalar, or the direction is invalid.
+ """
+ del stable # Unused.
+ with framework_ops.name_scope(name, 'argsort'):
+ return _sort_or_argsort(values, axis, direction, return_argsort=True)
+
+
+def _sort_or_argsort(values, axis, direction, return_argsort):
+ """Internal sort/argsort implementation.
+
+ Args:
+ values: The input values.
+ axis: The axis along which to sort.
+ direction: 'ASCENDING' or 'DESCENDING'.
+ return_argsort: Whether to return the argsort result.
+
+ Returns:
+ Either the sorted values, or the indices of the sorted values in the
+ original tensor. See the `sort` and `argsort` docstrings.
+
+ Raises:
+ ValueError: If axis is not a constant scalar, or the direction is invalid.
+ """
+ if direction not in _SORT_IMPL:
+ raise ValueError('%s should be one of %s' %
+ (direction, ', '.join(sorted(_SORT_IMPL.keys()))))
+ # Axis must be an integer, not a Tensor.
+ axis = framework_ops.convert_to_tensor(axis, name='axis')
+ axis_static = tensor_util.constant_value(axis)
+ if axis.shape.ndims != 0 or axis_static is None:
+ raise ValueError('axis must be a constant scalar')
+ axis_static = int(axis_static) # Avoids NumPy casting error
- values = framework_ops.convert_to_tensor(values, name='values')
+ values = framework_ops.convert_to_tensor(values, name='values')
- return _SORT_IMPL[direction](values, axis_static)
+ return _SORT_IMPL[direction](values, axis_static, return_argsort)
-def _descending_sort(values, axis):
+def _descending_sort(values, axis, return_argsort=False):
"""Sorts values in reverse using `top_k`.
Args:
values: Tensor of numeric values.
axis: Index of the axis which values should be sorted along.
+ return_argsort: If False, return the sorted values. If True, return the
+ indices that would sort the values.
Returns:
The sorted values.
"""
k = array_ops.shape(values)[axis]
rank = array_ops.rank(values)
+ static_rank = values.shape.ndims
# Fast path: sorting the last axis.
if axis == -1 or axis + 1 == values.get_shape().ndims:
- return nn_ops.top_k(values, k)[0]
-
- # Otherwise, transpose the array. Swap axes `axis` and `rank - 1`.
- if axis < 0:
- # Make axis a Tensor with the real axis index if needed.
- axis += rank
- transposition = array_ops.concat(
- [
- # Axes up to axis are unchanged.
- math_ops.range(axis),
- # Swap axis and rank - 1.
- [rank - 1],
- # Axes in [axis + 1, rank - 1) are unchanged.
- math_ops.range(axis + 1, rank - 1),
- # Swap axis and rank - 1.
- [axis]
- ],
- axis=0)
- top_k_input = array_ops.transpose(values, transposition)
- values, unused_indices = nn_ops.top_k(top_k_input, k)
- # transposition contains a single cycle of length 2 (swapping 2 elements),
- # so it is an involution (it is its own inverse).
- return array_ops.transpose(values, transposition)
-
-
-def _ascending_sort(values, axis):
+ top_k_input = values
+ transposition = None
+ else:
+ # Otherwise, transpose the array. Swap axes `axis` and `rank - 1`.
+ if axis < 0:
+ # Calculate the actual axis index if counting from the end. Use the static
+ # rank if available, or else make the axis back into a tensor.
+ axis += static_rank or rank
+ if static_rank is not None:
+ # Prefer to calculate the transposition array in NumPy and make it a
+ # constant.
+ transposition = constant_op.constant(
+ np.r_[
+ # Axes up to axis are unchanged.
+ np.arange(axis),
+ # Swap axis and rank - 1.
+ [static_rank - 1],
+ # Axes in [axis + 1, rank - 1) are unchanged.
+ np.arange(axis + 1, static_rank - 1),
+ # Swap axis and rank - 1.
+ [axis]],
+ name='transposition')
+ else:
+ # Generate the transposition array from the tensors.
+ transposition = array_ops.concat(
+ [
+ # Axes up to axis are unchanged.
+ math_ops.range(axis),
+ # Swap axis and rank - 1.
+ [rank - 1],
+ # Axes in [axis + 1, rank - 1) are unchanged.
+ math_ops.range(axis + 1, rank - 1),
+ # Swap axis and rank - 1.
+ [axis]
+ ],
+ axis=0)
+ top_k_input = array_ops.transpose(values, transposition)
+
+ values, indices = nn_ops.top_k(top_k_input, k)
+ return_value = indices if return_argsort else values
+ if transposition is not None:
+ # transposition contains a single cycle of length 2 (swapping 2 elements),
+ # so it is an involution (it is its own inverse).
+ return_value = array_ops.transpose(return_value, transposition)
+ return return_value
+
+
+def _ascending_sort(values, axis, return_argsort=False):
# Negate the values to get the ascending order from descending sort.
- values_or_indices = _descending_sort(-values, axis)
- return -values_or_indices
+ values_or_indices = _descending_sort(-values, axis, return_argsort)
+ # If not argsort, negate the values again.
+ return values_or_indices if return_argsort else -values_or_indices
_SORT_IMPL = {
diff --git a/tensorflow/contrib/framework/python/ops/sort_ops_test.py b/tensorflow/contrib/framework/python/ops/sort_ops_test.py
index d08ae502f1..a8fb94b245 100644
--- a/tensorflow/contrib/framework/python/ops/sort_ops_test.py
+++ b/tensorflow/contrib/framework/python/ops/sort_ops_test.py
@@ -24,6 +24,8 @@ from tensorflow.contrib.framework.python.ops import sort_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
@@ -90,6 +92,38 @@ class SortTest(test.TestCase):
axis=0,
direction='DESCENDING').eval())
+ def testSort_staticallyKnownRank_constantTransposition(self):
+ # The transposition array should be a constant if the rank of "values" is
+ # statically known.
+ tensor = random_ops.random_uniform(
+ # Rank is statically known to be 5, but the dimension lengths are not
+ # known.
+ random_ops.random_uniform(
+ shape=(5,), minval=0, maxval=10, dtype=dtypes.int32))
+ sort_ops.sort(tensor, axis=1)
+ transposition = (
+ ops.get_default_graph().get_tensor_by_name('sort/transposition:0'))
+ self.assertFalse(tensor_util.constant_value(transposition) is None)
+ self.assertAllEqual(
+ # Swaps "1" and "4" to put "1" at the end.
+ tensor_util.constant_value(transposition),
+ [0, 4, 2, 3, 1])
+
+ def testArgsort_1d(self):
+ arr = np.random.random(42)
+ with self.test_session():
+ self.assertAllEqual(
+ np.sort(arr),
+ array_ops.gather(arr, sort_ops.argsort(arr)).eval())
+
+ def testArgsort(self):
+ arr = np.random.random((5, 6, 7, 8))
+ for axis in range(4):
+ with self.test_session():
+ self.assertAllEqual(
+ np.argsort(arr, axis=axis),
+ sort_ops.argsort(arr, axis=axis).eval())
+
if __name__ == '__main__':
test.main()