aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/unique_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/unique_op.cc')
-rw-r--r--tensorflow/core/kernels/unique_op.cc61
1 files changed, 61 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc
new file mode 100644
index 0000000000..61f4a54583
--- /dev/null
+++ b/tensorflow/core/kernels/unique_op.cc
@@ -0,0 +1,61 @@
+#include <unordered_map>
+#include <utility>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename T>
+class UniqueOp : public OpKernel {
+ public:
+ explicit UniqueOp(OpKernelConstruction* context) : OpKernel(context) {
+ const DataType dt = DataTypeToEnum<T>::v();
+ OP_REQUIRES_OK(context, context->MatchSignature({dt}, {dt, DT_INT32}));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
+ errors::InvalidArgument("unique expects a 1D vector."));
+ auto Tin = input.vec<T>();
+ const int N = Tin.size();
+
+ Tensor* idx = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(1, input.shape(), &idx));
+ auto idx_vec = idx->template vec<int32>();
+
+ std::unordered_map<T, int32> uniq;
+ uniq.reserve(2 * N);
+ for (int i = 0, j = 0; i < N; ++i) {
+ auto it = uniq.insert(std::make_pair(Tin(i), j));
+ idx_vec(i) = it.first->second;
+ if (it.second) {
+ ++j;
+ }
+ }
+ int32 uniq_size = uniq.size();
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 0, TensorShape({uniq_size}), &output));
+ auto output_vec = output->template vec<T>();
+
+ for (auto it : uniq) {
+ output_vec(it.second) = it.first;
+ }
+ }
+};
+
+#define REGISTER_UNIQUE(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Unique").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ UniqueOp<type>)
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_UNIQUE);
+#undef REGISTER_UNIQUE
+} // namespace tensorflow