// Source: tensorflow/core/kernels/topk_op.cc
// (blob 7fdce6cb7190ffa5f799853e27d18b9e33f2971a)
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/topk_op.h"

#include <algorithm>
#include <numeric>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

// Device tag types used as the `Device` template argument to pick the CPU or
// GPU variant of the kernel and functor below.
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

template <typename Device, typename T>
class TopK : public OpKernel {
 public:
  explicit TopK(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_));
    if (num_inputs() < 2) {  // k is an attr (TopK).
      OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
    } else {  // k is an input (TopKV2), so we won't know it until Compute.
      k_ = -1;
    }
  }

  void Compute(OpKernelContext* context) override {
    int k = k_;
    if (num_inputs() >= 2) {
      const auto& k_in = context->input(1);
      OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()),
                  errors::InvalidArgument("k must be scalar, got shape ",
                                          k_in.shape().DebugString()));
      k = k_in.scalar<int32>()();
    }
    OP_REQUIRES(context, k >= 0,
                errors::InvalidArgument("Need k >= 0, got ", k));
    const auto& input_in = context->input(0);
    OP_REQUIRES(context, input_in.dims() >= 1,
                errors::InvalidArgument("input must be >= 1-D, got shape ",
                                        input_in.shape().DebugString()));
    OP_REQUIRES(context, input_in.dim_size(input_in.dims() - 1) >= k,
                errors::InvalidArgument(
                    "input must have at least k columns. Had ",
                    input_in.dim_size(input_in.dims() - 1), ", needed ", k));

    const auto& input = input_in.flat_inner_dims<T>();

    const int64 num_rows = input.dimension(0);  // generally batch_size
    const int64 num_cols = input.dimension(1);

    TensorShape output_shape = input_in.shape();
    output_shape.set_dim(input_in.dims() - 1, k);
    Tensor* values_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &values_out));
    Tensor* indices_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, output_shape, &indices_out));

    // Nothing to do for top-nothing.
    if (k == 0) return;

    auto values = values_out->flat_inner_dims<T>();
    auto indices = indices_out->flat_inner_dims<int32>();
    Status s = functor::TopKFunctor<Device, T>::Compute(
        context, sorted_, k, input, num_rows, num_cols, values, indices);
    OP_REQUIRES_OK(context, s);
  }

 private:
  int k_;
  bool sorted_;
};

namespace functor {

// CPU implementation of the top-k selection.
//
// For every row of `input` (a [num_rows, num_cols] matrix) this writes the k
// largest entries into the matching row of `values` and their column
// positions into the matching row of `indices`.
//
// Args:
//   context:  kernel context; supplies the Eigen CPU device and the worker
//             thread pool used to shard rows.
//   sorted:   when true, each output row is in descending value order. When
//             false and k < num_cols the order within a row is whatever the
//             TopN heap yields. The k == 1 and k == num_cols paths produce
//             the same output regardless of this flag.
//   k:        number of entries to select per row. The caller in this file
//             guarantees 1 <= k <= num_cols.
//   input:    values to select from.
//   num_rows, num_cols: dimensions of `input`.
//   values, indices: outputs, written in place.
//
// Equal values are ordered by increasing column index.
//
// Returns Status::OK(); no failure path exists in this implementation.
template <typename T>
struct TopKFunctor<CPUDevice, T> {
  static EIGEN_ALWAYS_INLINE Status
  Compute(OpKernelContext* context, bool sorted, int k,
          const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows,
          const int64 num_cols, typename TTypes<T, 2>::Tensor values,
          typename TTypes<int, 2>::Tensor indices) {
    // Special case for k == 1: a single max-reduction over the columns,
    // followed by a linear scan to recover the column of the maximum.
    if (k == 1) {
      const CPUDevice& d = context->eigen_device<CPUDevice>();
#ifdef EIGEN_HAS_INDEX_LIST
      typename Eigen::IndexList<Eigen::type2index<1>> reduce_on_cols;
      typename Eigen::IndexList<int, Eigen::type2index<1>> rows_by_one;
      rows_by_one.set(0, num_rows);
#else
      Eigen::array<int, 1> reduce_on_cols = {1};
      Eigen::array<int, 2> rows_by_one = {static_cast<int>(num_rows), 1};
#endif

      values.device(d) =
          input.maximum(/*dims=*/reduce_on_cols).eval().reshape(rows_by_one);
      // Get the indices of the maximum values. Row/column counters are int64
      // to match num_rows/num_cols and avoid overflowing a 32-bit counter.
      for (int64 r = 0; r < num_rows; ++r) {
        // Default to column 0 so the index is always initialized, even when
        // no element compares equal to the reduced maximum (for example when
        // the row contains NaN, which never compares equal to anything).
        indices(r, 0) = 0;
        for (int64 c = 0; c < num_cols; ++c) {
          if (values(r, 0) == input(r, c)) {
            indices(r, 0) = static_cast<int32>(c);
            break;
          }
        }
        // Keep the invariant values(r, j) == input(r, indices(r, j)) that the
        // general path below also provides.
        values(r, 0) = input(r, indices(r, 0));
      }

      return Status::OK();
    }

    // Processes rows [start_batch, limit_batch). The bounds are int64 because
    // Shard hands out int64 ranges and num_rows is int64; narrower types would
    // silently truncate for very large batches. The lambda only runs inside
    // the Shard call below, so capturing locals by reference is safe.
    auto SortIndices = [&](int64 start_batch, int64 limit_batch) {
      for (int64 row = start_batch; row < limit_batch; ++row) {
        const T* input_data = &input(row, 0);
        // The k index slots for this row. `row_begin + k` is the same address
        // the expression `&indices(row, k)` would yield for contiguous rows,
        // but avoids indexing one column past the end of the tensor.
        auto* const row_begin = &indices(row, 0);
        auto* const row_end = row_begin + k;

        // TODO(ebrevdo): For large k < num_cols, instead of using
        // TopN, it may be faster to create a temporary vector of
        // values 0..num_cols - 1 and then use std::partial_sort_copy
        // of this into indices. Choosing the appropriate minimum k or
        // ratio of k/num_cols will require some experimentation.
        if (k == num_cols) {
          // Orders column indices by descending value; ties are unordered.
          const auto comp = [input_data](const int32 lhs, const int32 rhs) {
            return input_data[rhs] < input_data[lhs];
          };
          // Set the initial array of indices 0 ... k - 1.
          std::iota(row_begin, row_end, 0);
          // We want an in-place sort, but we can cheat because we're sorting
          // indices that started out sorted.  First, do a std::sort, which
          // is notably faster than std::stable_sort.
          std::sort(row_begin, row_end, comp);
          // Then, for runs of adjacent elements that were equal, sort the
          // indices in those runs in increasing order.
          for (auto* run_begin = row_begin; run_begin != row_end;) {
            auto* run_end = run_begin + 1;
            if (run_end == row_end) break;
            if (input_data[*run_begin] == input_data[*run_end]) {
              while (++run_end != row_end) {
                if (input_data[*run_begin] != input_data[*run_end]) break;
              }
              std::sort(run_begin, run_end);
            }
            run_begin = run_end;
          }
        } else {
          // Orders column indices by descending value, breaking ties in
          // favor of the smaller column index.
          const auto stable_comp = [input_data](const int32 lhs,
                                                const int32 rhs) {
            if (input_data[rhs] < input_data[lhs]) {
              return true;
            } else if (input_data[rhs] > input_data[lhs]) {
              return false;
            } else {
              return lhs < rhs;
            }
          };
          // Use the TopN heap object to sort.
          gtl::TopN<int32, decltype(stable_comp)> filter(k, stable_comp);
          filter.reserve(num_cols);
          for (int32 c = 0; c < num_cols; ++c) {
            filter.push(c);
          }

          int32 i = 0;
          if (sorted) {
            // Extract() hands back an owned vector of the retained indices.
            std::unique_ptr<std::vector<int32>> top_k(filter.Extract());
            for (auto top_k_it = top_k->begin(); top_k_it != top_k->end();
                 ++top_k_it, ++i) {
              indices(row, i) = *top_k_it;
            }
          } else {
            for (auto top_k_it = filter.unsorted_begin();
                 top_k_it != filter.unsorted_end(); ++top_k_it, ++i) {
              indices(row, i) = *top_k_it;
            }
          }
        }
        // Now that the indices are in their final order, copy the values
        // over in that same order.
        std::transform(
            row_begin, row_end, &values(row, 0),
            [row, &input](const int32 loc) { return input(row, loc); });
      }  // for (int64 row = ...
    };

    // Guesstimate of cost; 4*N*log(K) where N == num_cols.
    // If K == N, assume the cost is N*log(K + 1).
    const double cmp_cost = 3 * Eigen::TensorOpCost::AddCost<int32>() +
                            Eigen::TensorOpCost::AddCost<T>();
    const double base_cost =
        cmp_cost *
        static_cast<double>(num_cols *
                            Eigen::numext::log2(static_cast<float>(k + 1)));
    const double sort_cost = (k == num_cols) ? base_cost : 4 * base_cost;
    const double copy_cost = 2 * k * Eigen::TensorOpCost::AddCost<T>();
    const double total_cost = sort_cost + copy_cost;
    // Clamp before converting so an oversized estimate cannot overflow int64.
    const int64 final_cost = (total_cost >= static_cast<double>(kint64max))
                                 ? kint64max
                                 : static_cast<int64>(total_cost);
    // Bind by reference: only two fields are read, so no copy is needed.
    const auto& worker_threads =
        *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, num_rows,
          final_cost, SortIndices);

    return Status::OK();
  }
};

}  // namespace functor

// Registers TopK<CPUDevice, type> as the CPU kernel for the op named `name`,
// constrained to the case where the op's "T" attr equals `type`.
#define REGISTER_KERNELS_NAME(name, type)                       \
  REGISTER_KERNEL_BUILDER(                                      \
      Name(#name).Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      TopK<CPUDevice, type>)

// Registers the CPU kernel for `type` under both op names served by the TopK
// class: "TopK" (k as an attr) and "TopKV2" (k as an input).
#define REGISTER_KERNELS(type)       \
  REGISTER_KERNELS_NAME(TopK, type); \
  REGISTER_KERNELS_NAME(TopKV2, type)

// Apply REGISTER_KERNELS to the types enumerated by TF_CALL_REAL_NUMBER_TYPES
// (presumably every real-number type; see register_types.h to confirm).
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
// The helper macros are only needed for the registrations above.
#undef REGISTER_KERNELS_NAME
#undef REGISTER_KERNELS

#ifdef GOOGLE_CUDA

namespace functor {
// Forward-declares the GPUDevice specialization of TopKFunctor<..., T>::Compute
// and marks the struct as an extern template, so this translation unit refers
// to the GPU implementation without instantiating it here.
// NOTE(review): the definitions are presumably provided by a separately
// compiled CUDA source file - confirm against the build rules.
#define DECLARE_GPU_SPEC(T)                                                  \
  template <>                                                                \
  Status TopKFunctor<GPUDevice, T>::Compute(                                 \
      OpKernelContext* context, bool sorted, int k,                          \
      const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows, \
      const int64 num_cols, typename TTypes<T, 2>::Tensor values,            \
      typename TTypes<int, 2>::Tensor indices);                              \
  extern template struct functor::TopKFunctor<GPUDevice, T>;

// Declare the GPU specialization for every type that gets a GPU kernel
// registration in this file.
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);

#undef DECLARE_GPU_SPEC

}  // namespace functor

// Registers TopK<GPUDevice, type> as the GPU kernel for both "TopK" and
// "TopKV2". For "TopKV2" the "k" input is declared as host memory; the kernel's
// Compute method reads that scalar directly on the host.
#define REGISTER_KERNELS(type)                                   \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("TopK").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      TopK<GPUDevice, type>)                                     \
  REGISTER_KERNEL_BUILDER(Name("TopKV2")                         \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("k"),                  \
                          TopK<GPUDevice, type>)

// Same type lists as the GPU functor declarations made earlier in this file.
TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);

#undef REGISTER_KERNELS

#endif  // end GOOGLE_CUDA

}  // end namespace tensorflow