about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels/unique_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/unique_op.cc')
-rw-r--r-- tensorflow/core/kernels/unique_op.cc | 62
1 files changed, 52 insertions, 10 deletions
diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc
index 6e51696d6f..701c5f6d2b 100644
--- a/tensorflow/core/kernels/unique_op.cc
+++ b/tensorflow/core/kernels/unique_op.cc
@@ -26,7 +26,7 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
-template <typename T>
+template <typename T, typename TIndex>
class UniqueOp : public OpKernel {
public:
explicit UniqueOp(OpKernelConstruction* context) : OpKernel(context) {}
@@ -48,9 +48,9 @@ class UniqueOp : public OpKernel {
Tensor* idx = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{0}, 1, input.shape(), &idx));
- auto idx_vec = idx->template vec<int32>();
+ auto idx_vec = idx->template vec<TIndex>();
- std::unordered_map<T, int32> uniq;
+ std::unordered_map<T, TIndex> uniq;
uniq.reserve(2 * N);
for (int64 i = 0, j = 0; i < N; ++i) {
auto it = uniq.insert(std::make_pair(Tin(i), j));
@@ -72,7 +72,7 @@ class UniqueOp : public OpKernel {
if (num_outputs() > 2) {
OP_REQUIRES_OK(context, context->allocate_output(
2, TensorShape({uniq_size}), &output));
- auto count_output_vec = output->template vec<int32>();
+ auto count_output_vec = output->template vec<TIndex>();
count_output_vec.setZero();
for (int64 i = 0; i < N; ++i) {
count_output_vec(idx_vec(i))++;
@@ -86,12 +86,22 @@ class UniqueOp : public OpKernel {
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("out_idx"), \
- UniqueOp<type>); \
+ UniqueOp<type, int32>); \
+ REGISTER_KERNEL_BUILDER(Name("Unique") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_idx"), \
+ UniqueOp<type, int64>); \
REGISTER_KERNEL_BUILDER(Name("UniqueWithCounts") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("out_idx"), \
- UniqueOp<type>)
+ UniqueOp<type, int32>); \
+ REGISTER_KERNEL_BUILDER(Name("UniqueWithCounts") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_idx"), \
+ UniqueOp<type, int64>)
TF_CALL_REAL_NUMBER_TYPES(REGISTER_UNIQUE);
REGISTER_UNIQUE(string)
#undef REGISTER_UNIQUE
@@ -107,7 +117,15 @@ REGISTER_KERNEL_BUILDER(Name("Unique")
.HostMemory("x")
.HostMemory("y")
.HostMemory("idx"),
- UniqueOp<int32>);
+ UniqueOp<int32, int32>);
+REGISTER_KERNEL_BUILDER(Name("Unique")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int64>("out_idx")
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("idx"),
+ UniqueOp<int32, int64>);
REGISTER_KERNEL_BUILDER(Name("Unique")
.Device(DEVICE_GPU)
.TypeConstraint<int64>("T")
@@ -115,7 +133,15 @@ REGISTER_KERNEL_BUILDER(Name("Unique")
.HostMemory("x")
.HostMemory("y")
.HostMemory("idx"),
- UniqueOp<int64>);
+ UniqueOp<int64, int32>);
+REGISTER_KERNEL_BUILDER(Name("Unique")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int64>("T")
+ .TypeConstraint<int64>("out_idx")
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("idx"),
+ UniqueOp<int64, int64>);
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("Unique")
@@ -125,7 +151,7 @@ REGISTER_KERNEL_BUILDER(Name("Unique")
.HostMemory("x")
.HostMemory("y")
.HostMemory("idx"),
- UniqueOp<int32>);
+ UniqueOp<int32, int32>);
REGISTER_KERNEL_BUILDER(Name("Unique")
.Device(DEVICE_SYCL)
.TypeConstraint<int64>("T")
@@ -133,6 +159,22 @@ REGISTER_KERNEL_BUILDER(Name("Unique")
.HostMemory("x")
.HostMemory("y")
.HostMemory("idx"),
- UniqueOp<int64>);
+ UniqueOp<int64, int32>);
+REGISTER_KERNEL_BUILDER(Name("Unique")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int64>("out_idx")
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("idx"),
+ UniqueOp<int32, int64>);
+REGISTER_KERNEL_BUILDER(Name("Unique")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int64>("T")
+ .TypeConstraint<int64>("out_idx")
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("idx"),
+ UniqueOp<int64, int64>);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow