aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/listdiff_op.cc
blob: bc3d6c491286885b530813ca3ba933fbce197a85 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#include <limits>
#include <string>
#include <unordered_set>
#include <utility>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/public/status.h"
#include "tensorflow/core/public/tensor.h"
#include "tensorflow/core/public/tensor_shape.h"

namespace tensorflow {
template <typename T>
class ListDiffOp : public OpKernel {
 public:
  explicit ListDiffOp(OpKernelConstruction* context) : OpKernel(context) {
    const DataType dt = DataTypeToEnum<T>::v();
    OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt, DT_INT32}));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& x = context->input(0);
    const Tensor& y = context->input(1);

    OP_REQUIRES(context, TensorShapeUtils::IsVector(x.shape()),
                errors::InvalidArgument("x should be a 1D vector."));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(y.shape()),
                errors::InvalidArgument("y should be a 1D vector."));

    std::unordered_set<T> y_set;
    const auto Ty = y.vec<T>();
    const int y_size = Ty.size();
    y_set.reserve(y_size);
    for (int i = 0; i < y_size; ++i) {
      y_set.insert(Ty(i));
    }

    // Compute the size of the output.
    const auto Tx = x.vec<T>();
    const int x_size = Tx.size();

    int out_size = 0;
    for (int i = 0; i < x_size; ++i) {
      if (y_set.count(Tx(i)) == 0) {
        ++out_size;
      }
    }

    // Allocate and populate outputs.
    Tensor* out = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, {out_size}, &out));
    auto Tout = out->vec<T>();

    Tensor* indices = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {out_size}, &indices));
    auto Tindices = indices->vec<int32>();

    for (int i = 0, p = 0; i < x_size; ++i) {
      if (y_set.count(Tx(i)) == 0) {
        Tout(p) = Tx(i);
        Tindices(p) = i;
        p++;
      }
    }
  }
};

// Registers the CPU "ListDiff" kernel for one element type, binding the op's
// "T" type attribute to ListDiffOp<ELEM_T>.
#define REGISTER_LIST_DIFF_CPU_KERNEL(ELEM_T)                          \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("ListDiff").Device(DEVICE_CPU).TypeConstraint<ELEM_T>("T"), \
      ListDiffOp<ELEM_T>)

// One registration per type expanded by TF_CALL_REAL_NUMBER_TYPES, plus one
// for string.
TF_CALL_REAL_NUMBER_TYPES(REGISTER_LIST_DIFF_CPU_KERNEL);
REGISTER_LIST_DIFF_CPU_KERNEL(string);

// The helper macro is only meaningful inside this file.
#undef REGISTER_LIST_DIFF_CPU_KERNEL

}  // namespace tensorflow