diff options
Diffstat (limited to 'tensorflow/core/kernels/unique_op.cc')
-rw-r--r-- | tensorflow/core/kernels/unique_op.cc | 62 |
1 files changed, 52 insertions, 10 deletions
diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc
index 6e51696d6f..701c5f6d2b 100644
--- a/tensorflow/core/kernels/unique_op.cc
+++ b/tensorflow/core/kernels/unique_op.cc
@@ -26,7 +26,7 @@ namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
 
-template <typename T>
+template <typename T, typename TIndex>
 class UniqueOp : public OpKernel {
  public:
   explicit UniqueOp(OpKernelConstruction* context) : OpKernel(context) {}
@@ -48,9 +48,9 @@ class UniqueOp : public OpKernel {
     Tensor* idx = nullptr;
     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                 {0}, 1, input.shape(), &idx));
-    auto idx_vec = idx->template vec<int32>();
+    auto idx_vec = idx->template vec<TIndex>();
 
-    std::unordered_map<T, int32> uniq;
+    std::unordered_map<T, TIndex> uniq;
     uniq.reserve(2 * N);
     for (int64 i = 0, j = 0; i < N; ++i) {
       auto it = uniq.insert(std::make_pair(Tin(i), j));
@@ -72,7 +72,7 @@ class UniqueOp : public OpKernel {
     if (num_outputs() > 2) {
       OP_REQUIRES_OK(context, context->allocate_output(
                                   2, TensorShape({uniq_size}), &output));
-      auto count_output_vec = output->template vec<int32>();
+      auto count_output_vec = output->template vec<TIndex>();
       count_output_vec.setZero();
       for (int64 i = 0; i < N; ++i) {
         count_output_vec(idx_vec(i))++;
@@ -86,12 +86,22 @@ class UniqueOp : public OpKernel {
                               .Device(DEVICE_CPU)                \
                               .TypeConstraint<type>("T")         \
                               .TypeConstraint<int32>("out_idx"), \
-                          UniqueOp<type>);                       \
+                          UniqueOp<type, int32>);                \
+  REGISTER_KERNEL_BUILDER(Name("Unique")                         \
+                              .Device(DEVICE_CPU)                \
+                              .TypeConstraint<type>("T")         \
+                              .TypeConstraint<int64>("out_idx"), \
+                          UniqueOp<type, int64>);                \
   REGISTER_KERNEL_BUILDER(Name("UniqueWithCounts")               \
                               .Device(DEVICE_CPU)                \
                               .TypeConstraint<type>("T")         \
                               .TypeConstraint<int32>("out_idx"), \
-                          UniqueOp<type>)
+                          UniqueOp<type, int32>)                 \
+  REGISTER_KERNEL_BUILDER(Name("UniqueWithCounts")               \
+                              .Device(DEVICE_CPU)                \
+                              .TypeConstraint<type>("T")         \
+                              .TypeConstraint<int64>("out_idx"), \
+                          UniqueOp<type, int64>)
 TF_CALL_REAL_NUMBER_TYPES(REGISTER_UNIQUE);
 REGISTER_UNIQUE(string)
 #undef REGISTER_UNIQUE
@@ -107,7 +117,15 @@ REGISTER_KERNEL_BUILDER(Name("Unique")
                             .HostMemory("x")
                             .HostMemory("y")
                             .HostMemory("idx"),
-                        UniqueOp<int32>);
+                        UniqueOp<int32, int32>);
+REGISTER_KERNEL_BUILDER(Name("Unique")
+                            .Device(DEVICE_GPU)
+                            .TypeConstraint<int32>("T")
+                            .TypeConstraint<int64>("out_idx")
+                            .HostMemory("x")
+                            .HostMemory("y")
+                            .HostMemory("idx"),
+                        UniqueOp<int32, int64>);
 REGISTER_KERNEL_BUILDER(Name("Unique")
                             .Device(DEVICE_GPU)
                             .TypeConstraint<int64>("T")
@@ -115,7 +133,15 @@ REGISTER_KERNEL_BUILDER(Name("Unique")
                             .HostMemory("x")
                             .HostMemory("y")
                             .HostMemory("idx"),
-                        UniqueOp<int64>);
+                        UniqueOp<int64, int32>);
+REGISTER_KERNEL_BUILDER(Name("Unique")
+                            .Device(DEVICE_GPU)
+                            .TypeConstraint<int64>("T")
+                            .TypeConstraint<int64>("out_idx")
+                            .HostMemory("x")
+                            .HostMemory("y")
+                            .HostMemory("idx"),
+                        UniqueOp<int64, int64>);
 
 #ifdef TENSORFLOW_USE_SYCL
 REGISTER_KERNEL_BUILDER(Name("Unique")
@@ -125,7 +151,7 @@ REGISTER_KERNEL_BUILDER(Name("Unique")
                             .HostMemory("x")
                             .HostMemory("y")
                             .HostMemory("idx"),
-                        UniqueOp<int32>);
+                        UniqueOp<int32, int32>);
 REGISTER_KERNEL_BUILDER(Name("Unique")
                             .Device(DEVICE_SYCL)
                             .TypeConstraint<int64>("T")
@@ -133,6 +159,22 @@ REGISTER_KERNEL_BUILDER(Name("Unique")
                             .HostMemory("x")
                             .HostMemory("y")
                             .HostMemory("idx"),
-                        UniqueOp<int64>);
+                        UniqueOp<int64, int32>);
+REGISTER_KERNEL_BUILDER(Name("Unique")
+                            .Device(DEVICE_SYCL)
+                            .TypeConstraint<int32>("T")
+                            .TypeConstraint<int64>("out_idx")
+                            .HostMemory("x")
+                            .HostMemory("y")
+                            .HostMemory("idx"),
+                        UniqueOp<int32, int64>);
+REGISTER_KERNEL_BUILDER(Name("Unique")
+                            .Device(DEVICE_SYCL)
+                            .TypeConstraint<int64>("T")
+                            .TypeConstraint<int64>("out_idx")
+                            .HostMemory("x")
+                            .HostMemory("y")
+                            .HostMemory("idx"),
+                        UniqueOp<int64, int64>);
 #endif // TENSORFLOW_USE_SYCL
 } // namespace tensorflow