1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
|
// See docs in ../ops/nn_ops.cc.
#define EIGEN_USE_THREADS
#include <limits>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/public/tensor_shape.h"
#include "tensorflow/core/public/tensor.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
namespace tensorflow {
template <typename T>
class TopK : public OpKernel {
public:
explicit TopK(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
}
void Compute(OpKernelContext* context) override {
const auto& input_in = context->input(0);
OP_REQUIRES(context, input_in.dims() == 2,
errors::InvalidArgument("input must be 2-dimensional"));
OP_REQUIRES(context, input_in.dim_size(1) >= k_,
errors::InvalidArgument("input must have at least k columns"));
const auto& input = input_in.matrix<T>();
const auto num_rows = input_in.dim_size(0); // generally batch_size
const auto num_cols = input_in.dim_size(1);
Tensor* values_out = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(
0, TensorShape({num_rows, k_}), &values_out));
Tensor* indices_out = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(
1, TensorShape({num_rows, k_}), &indices_out));
auto values = values_out->matrix<T>();
auto indices = indices_out->matrix<int32>();
gtl::TopN<std::pair<T, int32>> filter(k_);
for (int r = 0; r < num_rows; r++) {
for (int32 c = 0; c < num_cols; ++c) {
// The second element is the negated index, so that lower-index elements
// are considered larger than higher-index elements in case of ties.
filter.push(std::make_pair(input(r, c), -c));
}
std::unique_ptr<std::vector<std::pair<T, int32>>> top_k(filter.Extract());
for (int32 i = 0; i < k_; ++i) {
values(r, i) = (*top_k)[i].first;
indices(r, i) = -(*top_k)[i].second;
}
filter.Reset();
}
}
private:
int k_;
};
// Registers TopK<TYPE> as the CPU implementation of the "TopK" op for one
// element type. TF_CALL_REAL_NUMBER_TYPES expands it once per real numeric
// type; the helper macro is undefined immediately afterwards.
#define REGISTER_TOPK_CPU_KERNEL(TYPE)                           \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("TopK").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
      TopK<TYPE>)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_TOPK_CPU_KERNEL);

#undef REGISTER_TOPK_CPU_KERNEL
} // namespace tensorflow
|