diff options
Diffstat (limited to 'tensorflow/python/ops/nn_ops.py')
-rw-r--r-- | tensorflow/python/ops/nn_ops.py | 57 |
1 files changed, 48 insertions, 9 deletions
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index b6e459c27f..ad05f823fb 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -291,16 +291,20 @@ def _InTopKShape(op): @ops.RegisterShape("TopK") +@ops.RegisterShape("TopKV2") def _TopKShape(op): - """Shape function for TopK op.""" - input_shape = op.inputs[0].get_shape().with_rank(2) - k = op.get_attr("k") - num_rows = input_shape[0] - num_cols = input_shape[1] - if num_cols.value is not None and num_cols.value < k: - raise ValueError("input must have at least k (%d) columns" % k) - return [tensor_shape.TensorShape([num_rows, k]), - tensor_shape.TensorShape([num_rows, k])] + """Shape function for TopK and TopKV2 ops.""" + input_shape = op.inputs[0].get_shape().with_rank_at_least(1) + if len(op.inputs) >= 2: + k = tensor_util.ConstantValue(op.inputs[1]) + else: + k = op.get_attr("k") + last = input_shape[-1].value + if last is not None and last < k: + raise ValueError("input.shape %s must have last dimension >= k = %d" % + (input_shape, k)) + output_shape = input_shape[:-1].concatenate([k]) + return [output_shape, output_shape] @ops.RegisterShape("BatchNormWithGlobalNormalization") @@ -470,4 +474,39 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): ret.set_shape(x.get_shape()) return ret + +def top_k(input, k=1, sorted=True, name=None): + """Finds values and indices of the `k` largest entries for the last dimension. + + If the input is a vector (rank-1), finds the `k` largest entries in the vector + and outputs their values and indices as vectors. Thus `values[j]` is the + `j`-th largest entry in `input`, and its index is `indices[j]`. + + For matrices (resp. higher rank input), computes the top `k` entries in each + row (resp. vector along the last dimension). Thus, + + values.shape = indices.shape = input.shape[:-1] + [k] + + If two elements are equal, the lower-index element appears first. + + Args: + input: 1-D or higher `Tensor` with last dimension at least `k`. + k: 0-D `int32` `Tensor`. Number of top elements to look for along the last + dimension (along each row for matrices). + sorted: If true the resulting `k` elements will be sorted by the values in + descending order. + name: Optional name for the operation. + + Returns: + values: The `k` largest elements along each last dimensional slice. + indices: The indices of `values` within the last dimension of `input`. + """ + # TODO(irving): Always use v2 once the GraphDef mechanism is unstuck. + if isinstance(k, ops.Tensor): + op = gen_nn_ops._top_kv2 + else: + op = gen_nn_ops._top_k + return op(input, k=k, sorted=sorted, name=name) + + # pylint: enable=invalid-name |