1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
|
#include <limits>
#include <string>
#include <unordered_set>
#include <utility>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/public/status.h"
#include "tensorflow/core/public/tensor.h"
#include "tensorflow/core/public/tensor_shape.h"
namespace tensorflow {
template <typename T>
class ListDiffOp : public OpKernel {
public:
explicit ListDiffOp(OpKernelConstruction* context) : OpKernel(context) {
const DataType dt = DataTypeToEnum<T>::v();
OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt, DT_INT32}));
}
void Compute(OpKernelContext* context) override {
const Tensor& x = context->input(0);
const Tensor& y = context->input(1);
OP_REQUIRES(context, TensorShapeUtils::IsVector(x.shape()),
errors::InvalidArgument("x should be a 1D vector."));
OP_REQUIRES(context, TensorShapeUtils::IsVector(y.shape()),
errors::InvalidArgument("y should be a 1D vector."));
std::unordered_set<T> y_set;
const auto Ty = y.vec<T>();
const int y_size = Ty.size();
y_set.reserve(y_size);
for (int i = 0; i < y_size; ++i) {
y_set.insert(Ty(i));
}
// Compute the size of the output.
const auto Tx = x.vec<T>();
const int x_size = Tx.size();
int out_size = 0;
for (int i = 0; i < x_size; ++i) {
if (y_set.count(Tx(i)) == 0) {
++out_size;
}
}
// Allocate and populate outputs.
Tensor* out = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, {out_size}, &out));
auto Tout = out->vec<T>();
Tensor* indices = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(1, {out_size}, &indices));
auto Tindices = indices->vec<int32>();
for (int i = 0, p = 0; i < x_size; ++i) {
if (y_set.count(Tx(i)) == 0) {
Tout(p) = Tx(i);
Tindices(p) = i;
p++;
}
}
}
};
// Registers a CPU kernel for the "ListDiff" op whose "T" attribute is
// constrained to `dtype`, implemented by ListDiffOp<dtype>. The helper macro
// is local to this file and is undefined again once all registrations are
// emitted.
#define REGISTER_LIST_DIFF_CPU_KERNEL(dtype)                          \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("ListDiff").Device(DEVICE_CPU).TypeConstraint<dtype>("T"), \
      ListDiffOp<dtype>)

// One registration for each type enumerated by TF_CALL_REAL_NUMBER_TYPES,
// plus one for string.
TF_CALL_REAL_NUMBER_TYPES(REGISTER_LIST_DIFF_CPU_KERNEL);
REGISTER_LIST_DIFF_CPU_KERNEL(string);

#undef REGISTER_LIST_DIFF_CPU_KERNEL
} // namespace tensorflow
|