aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/unique_op.cc
blob: 61f4a54583d53b61d453eab620e29dac0b29ea6a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include <limits>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/public/status.h"
#include "tensorflow/core/public/tensor.h"
#include "tensorflow/core/public/tensor_shape.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename T>
class UniqueOp : public OpKernel {
 public:
  explicit UniqueOp(OpKernelConstruction* context) : OpKernel(context) {
    const DataType dt = DataTypeToEnum<T>::v();
    OP_REQUIRES_OK(context, context->MatchSignature({dt}, {dt, DT_INT32}));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
                errors::InvalidArgument("unique expects a 1D vector."));
    auto Tin = input.vec<T>();
    const int N = Tin.size();

    Tensor* idx = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, input.shape(), &idx));
    auto idx_vec = idx->template vec<int32>();

    std::unordered_map<T, int32> uniq;
    uniq.reserve(2 * N);
    for (int i = 0, j = 0; i < N; ++i) {
      auto it = uniq.insert(std::make_pair(Tin(i), j));
      idx_vec(i) = it.first->second;
      if (it.second) {
        ++j;
      }
    }
    int32 uniq_size = uniq.size();
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                0, TensorShape({uniq_size}), &output));
    auto output_vec = output->template vec<T>();

    for (auto it : uniq) {
      output_vec(it.second) = it.first;
    }
  }
};

// Registers UniqueOp<dtype> as the CPU kernel for the "Unique" op when the
// op's "T" attribute is `dtype`. The macro is applied through
// TF_CALL_REAL_NUMBER_TYPES (presumably once per real numeric type; see
// register_types.h) and undefined straight afterwards so it stays file-local.
#define REGISTER_UNIQUE(dtype)                                      \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("Unique").Device(DEVICE_CPU).TypeConstraint<dtype>("T"), \
      UniqueOp<dtype>)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_UNIQUE);
#undef REGISTER_UNIQUE
}  // namespace tensorflow