author     Geoffrey Irving <geoffreyi@google.com>    2016-01-15 17:06:09 -0800
committer  Manjunath Kudlur <keveman@gmail.com>      2016-01-15 18:24:48 -0800
commit     2302df7c73dac7bcc3dc3a69baa00498819fb23d (patch)
tree       655acb4c6db2c63886e82bffda68397fbbe10748
parent     bd3af957185f2835c638a045c80d6c11b7f71df1 (diff)
Generalize top_k to any rank >= 1 and Tensor k
Previously, top_k only worked for matrices and required the k value to be known at graph construction time. This CL lifts both restrictions. Since changing an attr to an input is a backwards-incompatible change, TopK still has an attr and there is a new TopKV2 that takes an input. Since the GraphDef versioning mechanism is in the middle of a redesign, I haven't used OP_DEPRECATE on the old TopK version just yet. Instead, the Python wrapper invokes the new one only if k is passed as a Tensor. This is temporary.

Change: 112297740
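As a rough illustration of the two lifted restrictions (not part of the CL itself; the Session/placeholder boilerplate assumes the Python API of this era):

import numpy as np
import tensorflow as tf

# Rank-3 input: top_k now operates along the last dimension of any rank >= 1 tensor.
x = tf.constant(np.random.rand(2, 3, 5).astype(np.float32))
values, indices = tf.nn.top_k(x, k=2)          # k as a Python int -> TopK (attr)

# k supplied as a Tensor -> the wrapper dispatches to the new TopKV2 op.
k = tf.placeholder(tf.int32)
v2_values, v2_indices = tf.nn.top_k(x, k)

with tf.Session() as sess:
    print(sess.run(values).shape)                         # (2, 3, 2)
    print(sess.run(v2_values, feed_dict={k: 4}).shape)    # (2, 3, 4)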
-rw-r--r--  tensorflow/core/kernels/topk_op.cc              | 63
-rw-r--r--  tensorflow/core/ops/nn_ops.cc                   | 63
-rw-r--r--  tensorflow/core/ops/ops.pbtxt                   | 61
-rw-r--r--  tensorflow/core/public/version.h                |  2
-rw-r--r--  tensorflow/python/BUILD                         |  2
-rw-r--r--  tensorflow/python/kernel_tests/topk_op_test.py  | 25
-rw-r--r--  tensorflow/python/ops/nn_ops.py                 | 57
7 files changed, 216 insertions(+), 57 deletions(-)
diff --git a/tensorflow/core/kernels/topk_op.cc b/tensorflow/core/kernels/topk_op.cc
index a96a633800..374f67afa7 100644
--- a/tensorflow/core/kernels/topk_op.cc
+++ b/tensorflow/core/kernels/topk_op.cc
@@ -30,36 +30,52 @@ template <typename T>
class TopK : public OpKernel {
public:
explicit TopK(OpKernelConstruction* context) : OpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_));
-
- if (k_ == 1) {
- sorted_ = false;
+ if (num_inputs() < 2) { // k is an attr (TopK).
+ OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
+ } else { // k is an input (TopKV2), so we won't know it until Compute.
+ k_ = -1;
}
}
void Compute(OpKernelContext* context) override {
+ int k = k_;
+ if (num_inputs() >= 2) {
+ const auto& k_in = context->input(1);
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()),
+ errors::InvalidArgument("k must be scalar, got shape ",
+ k_in.shape().ShortDebugString()));
+ k = k_in.scalar<int32>()();
+ }
+ OP_REQUIRES(context, k >= 0,
+ errors::InvalidArgument("Need k >= 0, got ", k));
const auto& input_in = context->input(0);
- OP_REQUIRES(context, input_in.dims() == 2,
- errors::InvalidArgument("input must be 2-dimensional"));
- OP_REQUIRES(context, input_in.dim_size(1) >= k_,
+ OP_REQUIRES(context, input_in.dims() >= 1,
+ errors::InvalidArgument("input must be >= 1-D, got shape ",
+ input_in.shape().ShortDebugString()));
+ OP_REQUIRES(context, input_in.dim_size(input_in.dims() - 1) >= k,
errors::InvalidArgument("input must have at least k columns"));
- const auto& input = input_in.matrix<T>();
+ const auto& input = input_in.flat_inner_dims<T>();
- const auto num_rows = input_in.dim_size(0); // generally batch_size
- const auto num_cols = input_in.dim_size(1);
+ const auto num_rows = input.dimension(0); // generally batch_size
+ const auto num_cols = input.dimension(1);
+ TensorShape output_shape = input_in.shape();
+ output_shape.set_dim(input_in.dims() - 1, k);
Tensor* values_out = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(
- 0, TensorShape({num_rows, k_}), &values_out));
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, output_shape, &values_out));
Tensor* indices_out = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(
- 1, TensorShape({num_rows, k_}), &indices_out));
- auto values = values_out->matrix<T>();
- auto indices = indices_out->matrix<int32>();
+ OP_REQUIRES_OK(context,
+ context->allocate_output(1, output_shape, &indices_out));
- gtl::TopN<std::pair<T, int32>> filter(k_);
+ // Nothing to do for top-nothing.
+ if (k == 0) return;
+
+ auto values = values_out->flat_inner_dims<T>();
+ auto indices = indices_out->flat_inner_dims<int32>();
+ gtl::TopN<std::pair<T, int32>> filter(k);
for (int r = 0; r < num_rows; r++) {
for (int32 c = 0; c < num_cols; ++c) {
// The second element is the negated index, so that lower-index elements
@@ -68,7 +84,7 @@ class TopK : public OpKernel {
}
int32 i = 0;
- if (sorted_) {
+ if (sorted_ && k > 1) {
std::unique_ptr<std::vector<std::pair<T, int32>>> top_k(
filter.Extract());
for (auto top_k_it = top_k->begin(); top_k_it != top_k->end();
@@ -92,11 +108,16 @@ class TopK : public OpKernel {
bool sorted_;
};
-#define REGISTER_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("TopK").Device(DEVICE_CPU).TypeConstraint<type>("T"), TopK<type>)
+#define REGISTER_KERNELS_NAME(name, type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name(#name).Device(DEVICE_CPU).TypeConstraint<type>("T"), TopK<type>)
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNELS_NAME(TopK, type); \
+ REGISTER_KERNELS_NAME(TopKV2, type)
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS_NAME
#undef REGISTER_KERNELS
} // namespace tensorflow
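A rough numpy analogue of the flat_inner_dims<T>() view the kernel now uses (an analogy only, not the actual Eigen code path):

import numpy as np

# Collapse every leading dimension into one "row" axis, keeping the last axis:
x = np.arange(24).reshape(2, 3, 4)
rows = x.reshape(-1, x.shape[-1])   # shape (6, 4)
# num_rows (6) plays the role of input.dimension(0) in the kernel above,
# num_cols (4) the role of input.dimension(1); top-k then runs once per row.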
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 3cf2748c07..8c6a274cbb 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -585,30 +585,67 @@ precision: Computed Precision at `k` as a `bool Tensor`.
)doc");
REGISTER_OP("TopK")
- .Attr("k: int >= 1")
- .Attr("sorted: bool = true")
.Input("input: T")
.Output("values: T")
.Output("indices: int32")
+ .Attr("k: int >= 0")
+ .Attr("sorted: bool = true")
.Attr("T: realnumbertype")
.Doc(R"doc(
-Returns the values and indices of the `k` largest elements for each row.
+Finds values and indices of the `k` largest elements for the last dimension.
+
+If the input is a vector (rank-1), finds the `k` largest entries in the vector
+and outputs their values and indices as vectors. Thus `values[j]` is the
+`j`-th largest entry in `input`, and its index is `indices[j]`.
+
+For matrices (resp. higher rank input), computes the top `k` entries in each
+row (resp. vector along the last dimension). Thus,
+
+ values.shape = indices.shape = input.shape[:-1] + [k]
-\\(values_{i, j}\\) represents the j-th largest element in \\(input_i\\).
+If two elements are equal, the lower-index element appears first.
-\\(indices_{i, j}\\) gives the column index of the corresponding element,
-such that \\(input_{i, indices_{i, j}} = values_{i, j}\\). If two
-elements are equal, the lower-index element appears first.
+If `k` varies dynamically, use `TopKV2` below.
-k: Number of top elements to look for within each row.
+input: 1-D or higher with last dimension at least `k`.
+k: Number of top elements to look for along the last dimension (along each
+ row for matrices).
sorted: If true the resulting `k` elements will be sorted by the values in
descending order.
-input: A `batch_size` x `classes` tensor.
-values: A `batch_size` x `k` tensor with the `k` largest elements for
- each row.
-indices: A `batch_size` x `k` tensor with the index of each value within
- each row.
+values: The `k` largest elements along each last dimensional slice.
+indices: The indices of `values` within the last dimension of `input`.
+)doc");
+
+REGISTER_OP("TopKV2")
+ .Input("input: T")
+ .Input("k: int32")
+ .Output("values: T")
+ .Output("indices: int32")
+ .Attr("sorted: bool = true")
+ .Attr("T: realnumbertype")
+ .Doc(R"doc(
+Finds values and indices of the `k` largest elements for the last dimension.
+
+If the input is a vector (rank-1), finds the `k` largest entries in the vector
+and outputs their values and indices as vectors. Thus `values[j]` is the
+`j`-th largest entry in `input`, and its index is `indices[j]`.
+
+For matrices (resp. higher rank input), computes the top `k` entries in each
+row (resp. vector along the last dimension). Thus,
+ values.shape = indices.shape = input.shape[:-1] + [k]
+
+If two elements are equal, the lower-index element appears first.
+
+This is the same as `TopK`, but takes `k` as an input rather than an attr.
+
+input: 1-D or higher with last dimension at least `k`.
+k: 0-D. Number of top elements to look for along the last dimension (along each
+ row for matrices).
+sorted: If true the resulting `k` elements will be sorted by the values in
+ descending order.
+values: The `k` largest elements along each last dimensional slice.
+indices: The indices of `values` within the last dimension of `input`.
)doc");
} // namespace tensorflow
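The documented shape contract and tie-breaking rule can be checked against a small reference model (reference_top_k is a hypothetical helper written with current numpy, not anything in TensorFlow):

import numpy as np

def reference_top_k(x, k):
    # Stable sort on negated values keeps lower-index elements first on ties.
    x = np.asarray(x)
    order = np.argsort(-x, axis=-1, kind="stable")[..., :k]
    return np.take_along_axis(x, order, axis=-1), order

vals, idx = reference_top_k([[0.1, 0.3, 0.3, 0.2]], k=2)
print(vals)   # [[0.3 0.3]]
print(idx)    # [[1 2]]  -- the lower-index 0.3 comes first
# values.shape == indices.shape == input.shape[:-1] + [k] -> (1, 2)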
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 90275c8afc..263bd225ed 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -8894,25 +8894,24 @@ op {
name: "TopK"
input_arg {
name: "input"
- description: "A `batch_size` x `classes` tensor."
+ description: "1-D or higher with last dimension at least `k`."
type_attr: "T"
}
output_arg {
name: "values"
- description: "A `batch_size` x `k` tensor with the `k` largest elements for\neach row."
+ description: "The `k` largest elements along each last dimensional slice."
type_attr: "T"
}
output_arg {
name: "indices"
- description: "A `batch_size` x `k` tensor with the index of each value within\neach row."
+ description: "The indices of `values` within the last dimension of `input`."
type: DT_INT32
}
attr {
name: "k"
type: "int"
- description: "Number of top elements to look for within each row."
+ description: "Number of top elements to look for along the last dimension (along each\nrow for matrices)."
has_minimum: true
- minimum: 1
}
attr {
name: "sorted"
@@ -8937,8 +8936,56 @@ op {
}
}
}
- summary: "Returns the values and indices of the `k` largest elements for each row."
- description: "\\\\(values_{i, j}\\\\) represents the j-th largest element in \\\\(input_i\\\\).\n\n\\\\(indices_{i, j}\\\\) gives the column index of the corresponding element,\nsuch that \\\\(input_{i, indices_{i, j}} = values_{i, j}\\\\). If two\nelements are equal, the lower-index element appears first."
+ summary: "Finds values and indices of the `k` largest elements for the last dimension."
+ description: "If the input is a vector (rank-1), finds the `k` largest entries in the vector\nand outputs their values and indices as vectors. Thus `values[j]` is the\n`j`-th largest entry in `input`, and its index is `indices[j]`.\n\nFor matrices (resp. higher rank input), computes the top `k` entries in each\nrow (resp. vector along the last dimension). Thus,\n\n values.shape = indices.shape = input.shape[:-1] + [k]\n\nIf two elements are equal, the lower-index element appears first.\n\nIf `k` varies dynamically, use `TopKV2` below."
+}
+op {
+ name: "TopKV2"
+ input_arg {
+ name: "input"
+ description: "1-D or higher with last dimension at least `k`."
+ type_attr: "T"
+ }
+ input_arg {
+ name: "k"
+ description: "0-D. Number of top elements to look for along the last dimension (along each\nrow for matrices)."
+ type: DT_INT32
+ }
+ output_arg {
+ name: "values"
+ description: "The `k` largest elements along each last dimensional slice."
+ type_attr: "T"
+ }
+ output_arg {
+ name: "indices"
+ description: "The indices of `values` within the last dimension of `input`."
+ type: DT_INT32
+ }
+ attr {
+ name: "sorted"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ description: "If true the resulting `k` elements will be sorted by the values in\ndescending order."
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ }
+ }
+ }
+ summary: "Finds values and indices of the `k` largest elements for the last dimension."
+  description: "If the input is a vector (rank-1), finds the `k` largest entries in the vector\nand outputs their values and indices as vectors. Thus `values[j]` is the\n`j`-th largest entry in `input`, and its index is `indices[j]`.\n\nFor matrices (resp. higher rank input), computes the top `k` entries in each\nrow (resp. vector along the last dimension). Thus,\n\n  values.shape = indices.shape = input.shape[:-1] + [k]\n\nIf two elements are equal, the lower-index element appears first.\n\nThis is the same as `TopK`, but takes `k` as an input rather than an attr."
}
op {
name: "Transpose"
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index d00085951a..a967af468e 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -53,7 +53,7 @@ limitations under the License.
// 5. Graphs are wholly-validated during Session::Create() (7jan2016).
// 6. TensorFlow is scalar strict within Google (current on or after 1feb2016).
#define TF_GRAPH_DEF_VERSION_MIN 0
-#define TF_GRAPH_DEF_VERSION_MAX 6
+#define TF_GRAPH_DEF_VERSION_MAX 7
#define TF_GRAPH_DEF_VERSION 5
#endif // THIRD_PARTY_TENSORFLOW_CORE_PUBLIC_VERSION_H_
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 7e4cf0c913..be5f788b31 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -561,6 +561,8 @@ tf_gen_op_wrapper_py(
"EluGrad",
"SoftplusGrad",
"SoftsignGrad",
+ "TopK",
+ "TopKV2",
"BiasAdd",
"Relu6",
"AvgPool",
diff --git a/tensorflow/python/kernel_tests/topk_op_test.py b/tensorflow/python/kernel_tests/topk_op_test.py
index 184c15cc66..6a25edd48e 100644
--- a/tensorflow/python/kernel_tests/topk_op_test.py
+++ b/tensorflow/python/kernel_tests/topk_op_test.py
@@ -27,11 +27,11 @@ import tensorflow as tf
class TopKTest(tf.test.TestCase):
def _validateTopK(
- self, inputs, k, expected_values, expected_indices, sorted_output=True):
+ self, inputs, k, expected_values, expected_indices, sorted=True):
np_values = np.array(expected_values)
np_indices = np.array(expected_indices)
with self.test_session():
- values_op, indices_op = tf.nn.top_k(inputs, k, sorted=sorted_output)
+ values_op, indices_op = tf.nn.top_k(inputs, k, sorted=sorted)
values = values_op.eval()
indices = indices_op.eval()
self.assertAllClose(np_values, values)
@@ -61,16 +61,29 @@ class TopKTest(tf.test.TestCase):
inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.3, 0.2]]
self._validateTopK(inputs, 3,
[[0.2, 0.3, 0.4], [0.2, 0.3, 0.3]],
- [[2, 1, 3], [3, 1, 2]], sorted_output=False)
+ [[2, 1, 3], [3, 1, 2]], sorted=False)
+
+ def testTop3Vector(self):
+ inputs = [3, 6, 15, 18, 6, 12, 1, 17, 3, 0, 4, 19, 1, 6]
+ self._validateTopK(inputs, 3, [19, 18, 17], [11, 3, 7])
+
+ def testTensorK(self):
+ inputs = [3, 6, 15, 18, 6, 12, 1, 17, 3, 0, 4, 19, 1, 6]
+ k = tf.constant(3)
+ self._validateTopK(inputs, k, [19, 18, 17], [11, 3, 7])
def testKNegative(self):
inputs = [[0.1, 0.2], [0.3, 0.4]]
- with self.assertRaisesRegexp(ValueError, "less than minimum 1"):
- tf.nn.top_k(inputs, -1)
+ with self.test_session():
+ k = tf.placeholder(tf.int32)
+ values, _ = tf.nn.top_k(inputs, k)
+ with self.assertRaisesOpError("Need k >= 0, got -7"):
+ values.eval(feed_dict={k: -7})
def testKTooLarge(self):
inputs = [[0.1, 0.2], [0.3, 0.4]]
- with self.assertRaisesRegexp(ValueError, "input must have at least k"):
+ with self.assertRaisesRegexp(
+ ValueError, r"input.shape \(2, 2\) must have last dimension >= k = 4"):
tf.nn.top_k(inputs, 4)
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index b6e459c27f..ad05f823fb 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -291,16 +291,20 @@ def _InTopKShape(op):
@ops.RegisterShape("TopK")
+@ops.RegisterShape("TopKV2")
def _TopKShape(op):
- """Shape function for TopK op."""
- input_shape = op.inputs[0].get_shape().with_rank(2)
- k = op.get_attr("k")
- num_rows = input_shape[0]
- num_cols = input_shape[1]
- if num_cols.value is not None and num_cols.value < k:
- raise ValueError("input must have at least k (%d) columns" % k)
- return [tensor_shape.TensorShape([num_rows, k]),
- tensor_shape.TensorShape([num_rows, k])]
+ """Shape function for TopK and TopKV2 ops."""
+ input_shape = op.inputs[0].get_shape().with_rank_at_least(1)
+ if len(op.inputs) >= 2:
+ k = tensor_util.ConstantValue(op.inputs[1])
+ else:
+ k = op.get_attr("k")
+ last = input_shape[-1].value
+ if last is not None and last < k:
+ raise ValueError("input.shape %s must have last dimension >= k = %d" %
+ (input_shape, k))
+ output_shape = input_shape[:-1].concatenate([k])
+ return [output_shape, output_shape]
@ops.RegisterShape("BatchNormWithGlobalNormalization")
@@ -470,4 +474,39 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):
ret.set_shape(x.get_shape())
return ret
+
+def top_k(input, k=1, sorted=True, name=None):
+ """Finds values and indices of the `k` largest entries for the last dimension.
+
+ If the input is a vector (rank-1), finds the `k` largest entries in the vector
+ and outputs their values and indices as vectors. Thus `values[j]` is the
+ `j`-th largest entry in `input`, and its index is `indices[j]`.
+
+ For matrices (resp. higher rank input), computes the top `k` entries in each
+ row (resp. vector along the last dimension). Thus,
+
+ values.shape = indices.shape = input.shape[:-1] + [k]
+
+ If two elements are equal, the lower-index element appears first.
+
+ Args:
+ input: 1-D or higher `Tensor` with last dimension at least `k`.
+ k: 0-D `int32` `Tensor`. Number of top elements to look for along the last
+ dimension (along each row for matrices).
+ sorted: If true the resulting `k` elements will be sorted by the values in
+ descending order.
+ name: Optional name for the operation.
+
+ Returns:
+ values: The `k` largest elements along each last dimensional slice.
+ indices: The indices of `values` within the last dimension of `input`.
+ """
+ # TODO(irving): Always use v2 once the GraphDef mechanism is unstuck.
+ if isinstance(k, ops.Tensor):
+ op = gen_nn_ops._top_kv2
+ else:
+ op = gen_nn_ops._top_k
+ return op(input, k=k, sorted=sorted, name=name)
+
+
# pylint: enable=invalid-name
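
The TODO above describes the temporary dispatch: a Python int k keeps emitting the old TopK node, while a Tensor k emits the new TopKV2 node. A quick, illustrative way to observe this (not part of the CL):

import tensorflow as tf

x = tf.constant([[0.1, 0.3, 0.2, 0.4]])
v_attr, _ = tf.nn.top_k(x, 2)                  # Python int k
v_input, _ = tf.nn.top_k(x, tf.constant(2))    # Tensor k
print(v_attr.op.type)    # "TopK"
print(v_input.op.type)   # "TopKV2"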