Diffstat (limited to 'tensorflow/core/kernels/unique_op.cc')
-rw-r--r-- tensorflow/core/kernels/unique_op.cc | 113
1 file changed, 16 insertions(+), 97 deletions(-)
diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc
index d087784c8a..701c5f6d2b 100644
--- a/tensorflow/core/kernels/unique_op.cc
+++ b/tensorflow/core/kernels/unique_op.cc
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <functional>
#include <unordered_map>
#include <utility>
@@ -22,7 +21,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/hash/hash.h"
namespace tensorflow {
@@ -35,6 +33,8 @@ class UniqueOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
+ errors::InvalidArgument("unique expects a 1D vector."));
// TODO(dga): Make unique polymorphic for returning int32 and int64
// vectors to support large tensors.
OP_REQUIRES(context,
@@ -42,102 +42,31 @@ class UniqueOp : public OpKernel {
errors::InvalidArgument(
"unique does not support input tensors larger than ",
std::numeric_limits<int32>::max(), " elements"));
-
- int64 axis = 0;
- std::vector<int64> new_sizes{1, input.NumElements(), 1};
- if (context->num_inputs() == 1) {
- OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
- errors::InvalidArgument("unique expects a 1D vector."));
- } else {
- // In case of UniqueV2, the axis is a 1D vector. The purpose is
- // to allow specifying either "no axis" or "axis". The `[]` means
- // "no axis", while `[x]` means `axis = x`.
- const Tensor& axis_tensor = context->input(1);
- OP_REQUIRES(context, TensorShapeUtils::IsVector(axis_tensor.shape()),
- errors::InvalidArgument("axis expects a 1D vector."));
- OP_REQUIRES(
- context, axis_tensor.NumElements() <= 1,
- errors::InvalidArgument(
- "axis does not support input tensors larger than 1 elements"));
- if (axis_tensor.NumElements() == 0) {
- OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
- errors::InvalidArgument("unique expects a 1D vector."));
- } else {
- auto axis_vec = axis_tensor.vec<int64>();
- axis = axis_vec(0);
- axis = axis < 0 ? axis + input.dims() : axis;
- OP_REQUIRES(context, 0 <= axis && axis < input.dims(),
- errors::InvalidArgument("axis has to be between [0, ",
- input.dims(), ")"));
- if (axis > 0) {
- for (int64 i = 0; i < axis; i++) {
- new_sizes[0] *= input.dim_size(i);
- }
- }
- new_sizes[1] = input.dim_size(axis);
- if (axis + 1 < input.dims()) {
- for (int64 i = axis + 1; i < input.dims(); i++) {
- new_sizes[2] *= input.dim_size(i);
- }
- }
- }
- }
-
- auto Tin = input.shaped<T, 3>(new_sizes);
+ auto Tin = input.vec<T>();
+ const int64 N = static_cast<int64>(Tin.size());
Tensor* idx = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(
- 1, TensorShape({Tin.dimension(1)}), &idx));
+ OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
+ {0}, 1, input.shape(), &idx));
auto idx_vec = idx->template vec<TIndex>();
- auto hash_fn = [&Tin](const int64& key) -> unsigned long {
- size_t h = 0;
- for (int64 i = 0; i < Tin.dimension(0); i++) {
- for (int64 j = 0; j < Tin.dimension(2); j++) {
- h = Hash64Combine(h, hash<T>{}(Tin(i, key, j)));
- }
- }
- return h;
- };
-
- auto equal_to_fn = [&Tin](const int64& lhs, const int64& rhs) {
- for (int64 i = 0; i < Tin.dimension(0); i++) {
- for (int64 j = 0; j < Tin.dimension(2); j++) {
- if (Tin(i, lhs, j) != Tin(i, rhs, j)) {
- return false;
- }
- }
- }
- return true;
- };
-
- std::unordered_map<int64, int64, decltype(hash_fn), decltype(equal_to_fn)>
- uniq(0, hash_fn, equal_to_fn);
-
- uniq.reserve(2 * Tin.dimension(1));
-
- for (int64 i = 0, j = 0; i < Tin.dimension(1); ++i) {
- auto it = uniq.insert(std::make_pair(i, j));
+ std::unordered_map<T, TIndex> uniq;
+ uniq.reserve(2 * N);
+ for (int64 i = 0, j = 0; i < N; ++i) {
+ auto it = uniq.insert(std::make_pair(Tin(i), j));
idx_vec(i) = it.first->second;
if (it.second) {
++j;
}
}
-
int64 uniq_size = static_cast<int64>(uniq.size());
- new_sizes[1] = uniq_size;
- TensorShape output_shape(input.shape());
- output_shape.set_dim(axis, uniq_size);
Tensor* output = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
- auto Tout = output->shaped<T, 3>(new_sizes);
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 0, TensorShape({uniq_size}), &output));
+ auto output_vec = output->template vec<T>();
for (auto it : uniq) {
- for (int64 i = 0; i < Tin.dimension(0); i++) {
- for (int64 j = 0; j < Tin.dimension(2); j++) {
- Tout(i, it.second, j) = Tin(i, it.first, j);
- }
- }
+ output_vec(it.second) = it.first;
}
if (num_outputs() > 2) {
@@ -145,7 +74,7 @@ class UniqueOp : public OpKernel {
2, TensorShape({uniq_size}), &output));
auto count_output_vec = output->template vec<TIndex>();
count_output_vec.setZero();
- for (int64 i = 0; i < Tin.dimension(1); ++i) {
+ for (int64 i = 0; i < N; ++i) {
count_output_vec(idx_vec(i))++;
}
}
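
The replacement path above reduces Unique to a single pass over a flat vector: each element is looked up in an unordered_map keyed by value; a hit reuses the previously assigned id, a miss claims the next one. Below is a minimal standalone sketch of the same idea outside the TensorFlow kernel machinery; the names unique_1d and UniqueResult are illustrative, not part of the codebase.

#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Sketch of the single-pass unique algorithm used by the kernel above.
template <typename T>
struct UniqueResult {
  std::vector<T> values;         // each distinct value, in first-seen order
  std::vector<int64_t> indices;  // indices[i] = position of input[i] in values
};

template <typename T>
UniqueResult<T> unique_1d(const std::vector<T>& input) {
  UniqueResult<T> result;
  result.indices.resize(input.size());
  std::unordered_map<T, int64_t> seen;
  seen.reserve(2 * input.size());  // mirrors the kernel's reserve heuristic
  int64_t next_id = 0;
  for (std::size_t i = 0; i < input.size(); ++i) {
    // insert() leaves the map untouched when the key already exists;
    // it.second reports whether a new entry was created.
    auto it = seen.insert({input[i], next_id});
    result.indices[i] = it.first->second;
    if (it.second) {
      result.values.push_back(input[i]);
      ++next_id;
    }
  }
  return result;
}

The kernel instead allocates the output tensor after the loop and scatters values by iterating the map; since ids are assigned in input order, both variants produce first-occurrence order. Counts, when requested, follow from the index vector alone, as the UniqueWithCounts branch above shows.
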
@@ -163,16 +92,6 @@ class UniqueOp : public OpKernel {
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("out_idx"), \
UniqueOp<type, int64>); \
- REGISTER_KERNEL_BUILDER(Name("UniqueV2") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("out_idx"), \
- UniqueOp<type, int32>); \
- REGISTER_KERNEL_BUILDER(Name("UniqueV2") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int64>("out_idx"), \
- UniqueOp<type, int64>); \
REGISTER_KERNEL_BUILDER(Name("UniqueWithCounts") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
@@ -257,5 +176,5 @@ REGISTER_KERNEL_BUILDER(Name("Unique")
.HostMemory("y")
.HostMemory("idx"),
UniqueOp<int64, int64>);
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
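
For comparison, the removed UniqueV2 path deduplicated slices along an arbitrary axis by reshaping the input to [outer, axis_dim, inner] and keying the unordered_map on slice indices, with a custom hash and equality functor that walk the whole slice. A hedged standalone sketch of that technique for the 2-D "unique columns" case follows; a plain multiply-and-add combiner stands in for the kernel's Hash64Combine, and the names are illustrative.

#include <cstddef>
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <vector>

// Treat `data` as a row-major [rows, cols] matrix and assign each column a
// unique-column id, comparing whole columns for equality.
std::vector<int64_t> unique_columns(const std::vector<int>& data,
                                    int64_t rows, int64_t cols) {
  auto hash_fn = [&](int64_t col) {
    std::size_t h = 0;
    for (int64_t r = 0; r < rows; ++r)
      h = h * 31 + std::hash<int>{}(data[r * cols + col]);  // simple combiner
    return h;
  };
  auto eq_fn = [&](int64_t a, int64_t b) {
    for (int64_t r = 0; r < rows; ++r)
      if (data[r * cols + a] != data[r * cols + b]) return false;
    return true;
  };
  // Key the map on column indices; hash/equality dereference into the data.
  std::unordered_map<int64_t, int64_t, decltype(hash_fn), decltype(eq_fn)>
      uniq(0, hash_fn, eq_fn);
  std::vector<int64_t> idx(cols);
  int64_t next_id = 0;
  for (int64_t c = 0; c < cols; ++c) {
    auto it = uniq.insert({c, next_id});
    idx[c] = it.first->second;
    if (it.second) ++next_id;
  }
  return idx;  // idx[c] is the id of the first column equal to column c
}

This is exactly the machinery the hunks above delete, along with the UniqueV2 kernel registrations.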