about summary refs log tree commit diff homepage
path: root/tensorflow/core/kernels/unique_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/unique_op.cc')
-rw-r--r--  tensorflow/core/kernels/unique_op.cc  113
1 file changed, 97 insertions, 16 deletions
diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc
index 701c5f6d2b..d087784c8a 100644
--- a/tensorflow/core/kernels/unique_op.cc
+++ b/tensorflow/core/kernels/unique_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <functional>
#include <unordered_map>
#include <utility>
@@ -21,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/hash/hash.h"
namespace tensorflow {
@@ -33,8 +35,6 @@ class UniqueOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
- OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
- errors::InvalidArgument("unique expects a 1D vector."));
// TODO(dga): Make unique polymorphic for returning int32 and int64
// vectors to support large tensors.
OP_REQUIRES(context,
@@ -42,31 +42,102 @@ class UniqueOp : public OpKernel {
errors::InvalidArgument(
"unique does not support input tensors larger than ",
std::numeric_limits<int32>::max(), " elements"));
- auto Tin = input.vec<T>();
- const int64 N = static_cast<int64>(Tin.size());
+
+ int64 axis = 0;
+ std::vector<int64> new_sizes{1, input.NumElements(), 1};
+ if (context->num_inputs() == 1) {
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
+ errors::InvalidArgument("unique expects a 1D vector."));
+ } else {
+ // In case of UniqueV2, the axis is a 1D vector. The purpose is
+ // to allow specifying either "no axis" or "axis". The `[]` means
+ // "no axis", while `[x]` means `axis = x`.
+ const Tensor& axis_tensor = context->input(1);
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(axis_tensor.shape()),
+ errors::InvalidArgument("axis expects a 1D vector."));
+ OP_REQUIRES(
+ context, axis_tensor.NumElements() <= 1,
+ errors::InvalidArgument(
+ "axis does not support input tensors larger than 1 elements"));
+ if (axis_tensor.NumElements() == 0) {
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
+ errors::InvalidArgument("unique expects a 1D vector."));
+ } else {
+ auto axis_vec = axis_tensor.vec<int64>();
+ axis = axis_vec(0);
+ axis = axis < 0 ? axis + input.dims() : axis;
+ OP_REQUIRES(context, 0 <= axis && axis < input.dims(),
+ errors::InvalidArgument("axis has to be between [0, ",
+ input.dims(), ")"));
+ if (axis > 0) {
+ for (int64 i = 0; i < axis; i++) {
+ new_sizes[0] *= input.dim_size(i);
+ }
+ }
+ new_sizes[1] = input.dim_size(axis);
+ if (axis + 1 < input.dims()) {
+ for (int64 i = axis + 1; i < input.dims(); i++) {
+ new_sizes[2] *= input.dim_size(i);
+ }
+ }
+ }
+ }
+
+ auto Tin = input.shaped<T, 3>(new_sizes);
Tensor* idx = nullptr;
- OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
- {0}, 1, input.shape(), &idx));
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 1, TensorShape({Tin.dimension(1)}), &idx));
auto idx_vec = idx->template vec<TIndex>();
- std::unordered_map<T, TIndex> uniq;
- uniq.reserve(2 * N);
- for (int64 i = 0, j = 0; i < N; ++i) {
- auto it = uniq.insert(std::make_pair(Tin(i), j));
+ auto hash_fn = [&Tin](const int64& key) -> unsigned long {
+ size_t h = 0;
+ for (int64 i = 0; i < Tin.dimension(0); i++) {
+ for (int64 j = 0; j < Tin.dimension(2); j++) {
+ h = Hash64Combine(h, hash<T>{}(Tin(i, key, j)));
+ }
+ }
+ return h;
+ };
+
+ auto equal_to_fn = [&Tin](const int64& lhs, const int64& rhs) {
+ for (int64 i = 0; i < Tin.dimension(0); i++) {
+ for (int64 j = 0; j < Tin.dimension(2); j++) {
+ if (Tin(i, lhs, j) != Tin(i, rhs, j)) {
+ return false;
+ }
+ }
+ }
+ return true;
+ };
+
+ std::unordered_map<int64, int64, decltype(hash_fn), decltype(equal_to_fn)>
+ uniq(0, hash_fn, equal_to_fn);
+
+ uniq.reserve(2 * Tin.dimension(1));
+
+ for (int64 i = 0, j = 0; i < Tin.dimension(1); ++i) {
+ auto it = uniq.insert(std::make_pair(i, j));
idx_vec(i) = it.first->second;
if (it.second) {
++j;
}
}
+
int64 uniq_size = static_cast<int64>(uniq.size());
+ new_sizes[1] = uniq_size;
+ TensorShape output_shape(input.shape());
+ output_shape.set_dim(axis, uniq_size);
Tensor* output = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(
- 0, TensorShape({uniq_size}), &output));
- auto output_vec = output->template vec<T>();
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+ auto Tout = output->shaped<T, 3>(new_sizes);
for (auto it : uniq) {
- output_vec(it.second) = it.first;
+ for (int64 i = 0; i < Tin.dimension(0); i++) {
+ for (int64 j = 0; j < Tin.dimension(2); j++) {
+ Tout(i, it.second, j) = Tin(i, it.first, j);
+ }
+ }
}
if (num_outputs() > 2) {
@@ -74,7 +145,7 @@ class UniqueOp : public OpKernel {
2, TensorShape({uniq_size}), &output));
auto count_output_vec = output->template vec<TIndex>();
count_output_vec.setZero();
- for (int64 i = 0; i < N; ++i) {
+ for (int64 i = 0; i < Tin.dimension(1); ++i) {
count_output_vec(idx_vec(i))++;
}
}
@@ -92,6 +163,16 @@ class UniqueOp : public OpKernel {
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("out_idx"), \
UniqueOp<type, int64>); \
+ REGISTER_KERNEL_BUILDER(Name("UniqueV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("out_idx"), \
+ UniqueOp<type, int32>); \
+ REGISTER_KERNEL_BUILDER(Name("UniqueV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_idx"), \
+ UniqueOp<type, int64>); \
REGISTER_KERNEL_BUILDER(Name("UniqueWithCounts") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
@@ -176,5 +257,5 @@ REGISTER_KERNEL_BUILDER(Name("Unique")
.HostMemory("y")
.HostMemory("idx"),
UniqueOp<int64, int64>);
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow