diff options
author | 2016-01-15 17:06:09 -0800 | |
---|---|---|
committer | 2016-01-15 18:24:48 -0800 | |
commit | 2302df7c73dac7bcc3dc3a69baa00498819fb23d (patch) | |
tree | 655acb4c6db2c63886e82bffda68397fbbe10748 | |
parent | bd3af957185f2835c638a045c80d6c11b7f71df1 (diff) |
Generalize top_k to any rank >= 1 and Tensor k
Previously, top_k only worked for matrices, and required the k value to be
known at Graph construction time. This CL lifts both restrictions. Since
changing an attr to an input is a backwards incompatible change, TopK still has
an attr and there is a new TopKV2 that takes an input.
Since the GraphDef versioning mechanism is in the middle of a redesign, I
haven't used OP_DEPRECATE on the old TopK version just yet. Instead, the
Python wrapper invokes the new one only if `k` is a Tensor. This is
temporary.
Change: 112297740
-rw-r--r-- | tensorflow/core/kernels/topk_op.cc | 63 | ||||
-rw-r--r-- | tensorflow/core/ops/nn_ops.cc | 63 | ||||
-rw-r--r-- | tensorflow/core/ops/ops.pbtxt | 61 | ||||
-rw-r--r-- | tensorflow/core/public/version.h | 2 | ||||
-rw-r--r-- | tensorflow/python/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/topk_op_test.py | 25 | ||||
-rw-r--r-- | tensorflow/python/ops/nn_ops.py | 57 |
7 files changed, 216 insertions, 57 deletions
diff --git a/tensorflow/core/kernels/topk_op.cc b/tensorflow/core/kernels/topk_op.cc index a96a633800..374f67afa7 100644 --- a/tensorflow/core/kernels/topk_op.cc +++ b/tensorflow/core/kernels/topk_op.cc @@ -30,36 +30,52 @@ template <typename T> class TopK : public OpKernel { public: explicit TopK(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("k", &k_)); OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_)); - - if (k_ == 1) { - sorted_ = false; + if (num_inputs() < 2) { // k is an attr (TopK). + OP_REQUIRES_OK(context, context->GetAttr("k", &k_)); + } else { // k is an input (TopKV2), so we won't know it until Compute. + k_ = -1; } } void Compute(OpKernelContext* context) override { + int k = k_; + if (num_inputs() >= 2) { + const auto& k_in = context->input(1); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()), + errors::InvalidArgument("k must be scalar, got shape ", + k_in.shape().ShortDebugString())); + k = k_in.scalar<int32>()(); + } + OP_REQUIRES(context, k >= 0, + errors::InvalidArgument("Need k >= 0, got ", k)); const auto& input_in = context->input(0); - OP_REQUIRES(context, input_in.dims() == 2, - errors::InvalidArgument("input must be 2-dimensional")); - OP_REQUIRES(context, input_in.dim_size(1) >= k_, + OP_REQUIRES(context, input_in.dims() >= 1, + errors::InvalidArgument("input must be >= 1-D, got shape ", + input_in.shape().ShortDebugString())); + OP_REQUIRES(context, input_in.dim_size(input_in.dims() - 1) >= k, errors::InvalidArgument("input must have at least k columns")); - const auto& input = input_in.matrix<T>(); + const auto& input = input_in.flat_inner_dims<T>(); - const auto num_rows = input_in.dim_size(0); // generally batch_size - const auto num_cols = input_in.dim_size(1); + const auto num_rows = input.dimension(0); // generally batch_size + const auto num_cols = input.dimension(1); + TensorShape output_shape = input_in.shape(); + output_shape.set_dim(input_in.dims() - 
1, k); Tensor* values_out = nullptr; - OP_REQUIRES_OK(context, context->allocate_output( - 0, TensorShape({num_rows, k_}), &values_out)); + OP_REQUIRES_OK(context, + context->allocate_output(0, output_shape, &values_out)); Tensor* indices_out = nullptr; - OP_REQUIRES_OK(context, context->allocate_output( - 1, TensorShape({num_rows, k_}), &indices_out)); - auto values = values_out->matrix<T>(); - auto indices = indices_out->matrix<int32>(); + OP_REQUIRES_OK(context, + context->allocate_output(1, output_shape, &indices_out)); - gtl::TopN<std::pair<T, int32>> filter(k_); + // Nothing to do for top-nothing. + if (k == 0) return; + + auto values = values_out->flat_inner_dims<T>(); + auto indices = indices_out->flat_inner_dims<int32>(); + gtl::TopN<std::pair<T, int32>> filter(k); for (int r = 0; r < num_rows; r++) { for (int32 c = 0; c < num_cols; ++c) { // The second element is the negated index, so that lower-index elements @@ -68,7 +84,7 @@ class TopK : public OpKernel { } int32 i = 0; - if (sorted_) { + if (sorted_ && k > 1) { std::unique_ptr<std::vector<std::pair<T, int32>>> top_k( filter.Extract()); for (auto top_k_it = top_k->begin(); top_k_it != top_k->end(); @@ -92,11 +108,16 @@ class TopK : public OpKernel { bool sorted_; }; -#define REGISTER_KERNELS(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("TopK").Device(DEVICE_CPU).TypeConstraint<type>("T"), TopK<type>) +#define REGISTER_KERNELS_NAME(name, type) \ + REGISTER_KERNEL_BUILDER( \ + Name(#name).Device(DEVICE_CPU).TypeConstraint<type>("T"), TopK<type>) + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNELS_NAME(TopK, type); \ + REGISTER_KERNELS_NAME(TopKV2, type) TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS_TO_NAME #undef REGISTER_KERNELS } // namespace tensorflow diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 3cf2748c07..8c6a274cbb 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -585,30 +585,67 @@ precision: Computed 
Precision at `k` as a `bool Tensor`. )doc"); REGISTER_OP("TopK") - .Attr("k: int >= 1") - .Attr("sorted: bool = true") .Input("input: T") .Output("values: T") .Output("indices: int32") + .Attr("k: int >= 0") + .Attr("sorted: bool = true") .Attr("T: realnumbertype") .Doc(R"doc( -Returns the values and indices of the `k` largest elements for each row. +Finds values and indices of the `k` largest elements for the last dimension. + +If the input is a vector (rank-1), finds the `k` largest entries in the vector +and outputs their values and indices as vectors. Thus `values[j]` is the +`j`-th largest entry in `input`, and its index is `indices[j]`. + +For matrices (resp. higher rank input), computes the top `k` entries in each +row (resp. vector along the last dimension). Thus, + + values.shape = indices.shape = input.shape[:-1] + [k] -\\(values_{i, j}\\) represents the j-th largest element in \\(input_i\\). +If two elements are equal, the lower-index element appears first. -\\(indices_{i, j}\\) gives the column index of the corresponding element, -such that \\(input_{i, indices_{i, j}} = values_{i, j}\\). If two -elements are equal, the lower-index element appears first. +If `k` varies dynamically, use `TopKV2` below. -k: Number of top elements to look for within each row. +input: 1-D or higher with last dimension at least `k`. +k: Number of top elements to look for along the last dimension (along each + row for matrices). sorted: If true the resulting `k` elements will be sorted by the values in descending order. -input: A `batch_size` x `classes` tensor. -values: A `batch_size` x `k` tensor with the `k` largest elements for - each row. -indices: A `batch_size` x `k` tensor with the index of each value within - each row. +values: The `k` largest elements along each last dimensional slice. +indices: The indices of `values` within the last dimension of `input`. 
+)doc"); + +REGISTER_OP("TopKV2") + .Input("input: T") + .Input("k: int32") + .Output("values: T") + .Output("indices: int32") + .Attr("sorted: bool = true") + .Attr("T: realnumbertype") + .Doc(R"doc( +Finds values and indices of the `k` largest elements for the last dimension. + +If the input is a vector (rank-1), finds the `k` largest entries in the vector +and outputs their values and indices as vectors. Thus `values[j]` is the +`j`-th largest entry in `input`, and its index is `indices[j]`. + +For matrices (resp. higher rank input), computes the top `k` entries in each +row (resp. vector along the last dimension). Thus, + values.shape = indices.shape = input.shape[:-1] + [k] + +If two elements are equal, the lower-index element appears first. + +This is the same as `TopK`, but takes `k` as in input rather than an attr. + +input: 1-D or higher with last dimension at least `k`. +k: 0-D. Number of top elements to look for along the last dimension (along each + row for matrices). +sorted: If true the resulting `k` elements will be sorted by the values in + descending order. +values: The `k` largest elements along each last dimensional slice. +indices: The indices of `values` within the last dimension of `input`. )doc"); } // namespace tensorflow diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 90275c8afc..263bd225ed 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -8894,25 +8894,24 @@ op { name: "TopK" input_arg { name: "input" - description: "A `batch_size` x `classes` tensor." + description: "1-D or higher with last dimension at least `k`." type_attr: "T" } output_arg { name: "values" - description: "A `batch_size` x `k` tensor with the `k` largest elements for\neach row." + description: "The `k` largest elements along each last dimensional slice." type_attr: "T" } output_arg { name: "indices" - description: "A `batch_size` x `k` tensor with the index of each value within\neach row." 
+ description: "The indices of `values` within the last dimension of `input`." type: DT_INT32 } attr { name: "k" type: "int" - description: "Number of top elements to look for within each row." + description: "Number of top elements to look for along the last dimension (along each\nrow for matrices)." has_minimum: true - minimum: 1 } attr { name: "sorted" @@ -8937,8 +8936,56 @@ op { } } } - summary: "Returns the values and indices of the `k` largest elements for each row." - description: "\\\\(values_{i, j}\\\\) represents the j-th largest element in \\\\(input_i\\\\).\n\n\\\\(indices_{i, j}\\\\) gives the column index of the corresponding element,\nsuch that \\\\(input_{i, indices_{i, j}} = values_{i, j}\\\\). If two\nelements are equal, the lower-index element appears first." + summary: "Finds values and indices of the `k` largest elements for the last dimension." + description: "If the input is a vector (rank-1), finds the `k` largest entries in the vector\nand outputs their values and indices as vectors. Thus `values[j]` is the\n`j`-th largest entry in `input`, and its index is `indices[j]`.\n\nFor matrices (resp. higher rank input), computes the top `k` entries in each\nrow (resp. vector along the last dimension). Thus,\n\n values.shape = indices.shape = input.shape[:-1] + [k]\n\nIf two elements are equal, the lower-index element appears first.\n\nIf `k` varies dynamically, use `TopKV2` below." +} +op { + name: "TopKV2" + input_arg { + name: "input" + description: "1-D or higher with last dimension at least `k`." + type_attr: "T" + } + input_arg { + name: "k" + description: "0-D. Number of top elements to look for along the last dimension (along each\nrow for matrices)." + type: DT_INT32 + } + output_arg { + name: "values" + description: "The `k` largest elements along each last dimensional slice." + type_attr: "T" + } + output_arg { + name: "indices" + description: "The indices of `values` within the last dimension of `input`." 
+ type: DT_INT32 + } + attr { + name: "sorted" + type: "bool" + default_value { + b: true + } + description: "If true the resulting `k` elements will be sorted by the values in\ndescending order." + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + } + } + } + summary: "Finds values and indices of the `k` largest elements for the last dimension." + description: "If the input is a vector (rank-1), finds the `k` largest entries in the vector\nand outputs their values and indices as vectors. Thus `values[j]` is the\n`j`-th largest entry in `input`, and its index is `indices[j]`.\n\nFor matrices (resp. higher rank input), computes the top `k` entries in each\nrow (resp. vector along the last dimension). Thus,\n\n values.shape = indices.shape = input.shape[:-1] + [k]\n\nIf two elements are equal, the lower-index element appears first.\n\nThis is the same as `TopK`, but takes `k` as in input rather than an attr." } op { name: "Transpose" diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index d00085951a..a967af468e 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -53,7 +53,7 @@ limitations under the License. // 5. Graphs are wholly-validated during Session::Create() (7jan2016). // 6. TensorFlow is scalar strict within Google (current on or after 1feb2016). 
#define TF_GRAPH_DEF_VERSION_MIN 0 -#define TF_GRAPH_DEF_VERSION_MAX 6 +#define TF_GRAPH_DEF_VERSION_MAX 7 #define TF_GRAPH_DEF_VERSION 5 #endif // THIRD_PARTY_TENSORFLOW_CORE_PUBLIC_VERSION_H_ diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 7e4cf0c913..be5f788b31 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -561,6 +561,8 @@ tf_gen_op_wrapper_py( "EluGrad", "SoftplusGrad", "SoftsignGrad", + "TopK", + "TopKV2", "BiasAdd", "Relu6", "AvgPool", diff --git a/tensorflow/python/kernel_tests/topk_op_test.py b/tensorflow/python/kernel_tests/topk_op_test.py index 184c15cc66..6a25edd48e 100644 --- a/tensorflow/python/kernel_tests/topk_op_test.py +++ b/tensorflow/python/kernel_tests/topk_op_test.py @@ -27,11 +27,11 @@ import tensorflow as tf class TopKTest(tf.test.TestCase): def _validateTopK( - self, inputs, k, expected_values, expected_indices, sorted_output=True): + self, inputs, k, expected_values, expected_indices, sorted=True): np_values = np.array(expected_values) np_indices = np.array(expected_indices) with self.test_session(): - values_op, indices_op = tf.nn.top_k(inputs, k, sorted=sorted_output) + values_op, indices_op = tf.nn.top_k(inputs, k, sorted=sorted) values = values_op.eval() indices = indices_op.eval() self.assertAllClose(np_values, values) @@ -61,16 +61,29 @@ class TopKTest(tf.test.TestCase): inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.3, 0.2]] self._validateTopK(inputs, 3, [[0.2, 0.3, 0.4], [0.2, 0.3, 0.3]], - [[2, 1, 3], [3, 1, 2]], sorted_output=False) + [[2, 1, 3], [3, 1, 2]], sorted=False) + + def testTop3Vector(self): + inputs = [3, 6, 15, 18, 6, 12, 1, 17, 3, 0, 4, 19, 1, 6] + self._validateTopK(inputs, 3, [19, 18, 17], [11, 3, 7]) + + def testTensorK(self): + inputs = [3, 6, 15, 18, 6, 12, 1, 17, 3, 0, 4, 19, 1, 6] + k = tf.constant(3) + self._validateTopK(inputs, k, [19, 18, 17], [11, 3, 7]) def testKNegative(self): inputs = [[0.1, 0.2], [0.3, 0.4]] - with self.assertRaisesRegexp(ValueError, "less than 
minimum 1"): - tf.nn.top_k(inputs, -1) + with self.test_session(): + k = tf.placeholder(tf.int32) + values, _ = tf.nn.top_k(inputs, k) + with self.assertRaisesOpError("Need k >= 0, got -7"): + values.eval(feed_dict={k: -7}) def testKTooLarge(self): inputs = [[0.1, 0.2], [0.3, 0.4]] - with self.assertRaisesRegexp(ValueError, "input must have at least k"): + with self.assertRaisesRegexp( + ValueError, r"input.shape \(2, 2\) must have last dimension >= k = 4"): tf.nn.top_k(inputs, 4) diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index b6e459c27f..ad05f823fb 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -291,16 +291,20 @@ def _InTopKShape(op): @ops.RegisterShape("TopK") +@ops.RegisterShape("TopKV2") def _TopKShape(op): - """Shape function for TopK op.""" - input_shape = op.inputs[0].get_shape().with_rank(2) - k = op.get_attr("k") - num_rows = input_shape[0] - num_cols = input_shape[1] - if num_cols.value is not None and num_cols.value < k: - raise ValueError("input must have at least k (%d) columns" % k) - return [tensor_shape.TensorShape([num_rows, k]), - tensor_shape.TensorShape([num_rows, k])] + """Shape function for TopK and TopKV2 ops.""" + input_shape = op.inputs[0].get_shape().with_rank_at_least(1) + if len(op.inputs) >= 2: + k = tensor_util.ConstantValue(op.inputs[1]) + else: + k = op.get_attr("k") + last = input_shape[-1].value + if last is not None and last < k: + raise ValueError("input.shape %s must have last dimension >= k = %d" % + (input_shape, k)) + output_shape = input_shape[:-1].concatenate([k]) + return [output_shape, output_shape] @ops.RegisterShape("BatchNormWithGlobalNormalization") @@ -470,4 +474,39 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): ret.set_shape(x.get_shape()) return ret + +def top_k(input, k=1, sorted=True, name=None): + """Finds values and indices of the `k` largest entries for the last dimension. 
+ + If the input is a vector (rank-1), finds the `k` largest entries in the vector + and outputs their values and indices as vectors. Thus `values[j]` is the + `j`-th largest entry in `input`, and its index is `indices[j]`. + + For matrices (resp. higher rank input), computes the top `k` entries in each + row (resp. vector along the last dimension). Thus, + + values.shape = indices.shape = input.shape[:-1] + [k] + + If two elements are equal, the lower-index element appears first. + + Args: + input: 1-D or higher `Tensor` with last dimension at least `k`. + k: 0-D `int32` `Tensor`. Number of top elements to look for along the last + dimension (along each row for matrices). + sorted: If true the resulting `k` elements will be sorted by the values in + descending order. + name: Optional name for the operation. + + Returns: + values: The `k` largest elements along each last dimensional slice. + indices: The indices of `values` within the last dimension of `input`. + """ + # TODO(irving): Always use v2 once the GraphDef mechanism is unstuck. + if isinstance(k, ops.Tensor): + op = gen_nn_ops._top_kv2 + else: + op = gen_nn_ops._top_k + return op(input, k=k, sorted=sorted, name=name) + + # pylint: enable=invalid-name |