aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/nn_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/nn_ops.py')
-rw-r--r--tensorflow/python/ops/nn_ops.py57
1 files changed, 48 insertions, 9 deletions
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index b6e459c27f..ad05f823fb 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -291,16 +291,20 @@ def _InTopKShape(op):
@ops.RegisterShape("TopK")
+@ops.RegisterShape("TopKV2")
def _TopKShape(op):
- """Shape function for TopK op."""
- input_shape = op.inputs[0].get_shape().with_rank(2)
- k = op.get_attr("k")
- num_rows = input_shape[0]
- num_cols = input_shape[1]
- if num_cols.value is not None and num_cols.value < k:
- raise ValueError("input must have at least k (%d) columns" % k)
- return [tensor_shape.TensorShape([num_rows, k]),
- tensor_shape.TensorShape([num_rows, k])]
+ """Shape function for TopK and TopKV2 ops."""
+ input_shape = op.inputs[0].get_shape().with_rank_at_least(1)
+ if len(op.inputs) >= 2:
+ k = tensor_util.ConstantValue(op.inputs[1])
+ else:
+ k = op.get_attr("k")
+ last = input_shape[-1].value
+ if last is not None and last < k:
+ raise ValueError("input.shape %s must have last dimension >= k = %d" %
+ (input_shape, k))
+ output_shape = input_shape[:-1].concatenate([k])
+ return [output_shape, output_shape]
@ops.RegisterShape("BatchNormWithGlobalNormalization")
@@ -470,4 +474,39 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):
ret.set_shape(x.get_shape())
return ret
+
+def top_k(input, k=1, sorted=True, name=None):
+ """Finds values and indices of the `k` largest entries for the last dimension.
+
+ If the input is a vector (rank-1), finds the `k` largest entries in the vector
+ and outputs their values and indices as vectors. Thus `values[j]` is the
+ `j`-th largest entry in `input`, and its index is `indices[j]`.
+
+ For matrices (resp. higher rank input), computes the top `k` entries in each
+ row (resp. vector along the last dimension). Thus,
+
+ values.shape = indices.shape = input.shape[:-1] + [k]
+
+ If two elements are equal, the lower-index element appears first.
+
+ Args:
+ input: 1-D or higher `Tensor` with last dimension at least `k`.
+ k: 0-D `int32` `Tensor`. Number of top elements to look for along the last
+ dimension (along each row for matrices).
+ sorted: If true the resulting `k` elements will be sorted by the values in
+ descending order.
+ name: Optional name for the operation.
+
+ Returns:
+ values: The `k` largest elements along each last dimensional slice.
+ indices: The indices of `values` within the last dimension of `input`.
+ """
+ # TODO(irving): Always use v2 once the GraphDef mechanism is unstuck.
+ if isinstance(k, ops.Tensor):
+ op = gen_nn_ops._top_kv2
+ else:
+ op = gen_nn_ops._top_k
+ return op(input, k=k, sorted=sorted, name=name)
+
+
# pylint: enable=invalid-name