diff options
author | 2016-12-01 16:28:32 -0800 | |
---|---|---|
committer | 2016-12-01 16:41:58 -0800 | |
commit | 1af94c269874440373c1d18d823110b1f5eabc19 (patch) | |
tree | 3f2d0d350f2f6eaffb3ef2c3dfdca4f0afced61d /tensorflow/core/kernels/set_kernels.cc | |
parent | 8b8e0aa63c9022d9e566f21876ed0cbec766af22 (diff) |
Moves metrics/sets and tensor_util.convert_to_tensor_or_sparse_tensor from contrib to core.
Change: 140793359
Diffstat (limited to 'tensorflow/core/kernels/set_kernels.cc')
-rw-r--r-- | tensorflow/core/kernels/set_kernels.cc | 733 |
1 files changed, 733 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/set_kernels.cc b/tensorflow/core/kernels/set_kernels.cc new file mode 100644 index 0000000000..61fe250206 --- /dev/null +++ b/tensorflow/core/kernels/set_kernels.cc @@ -0,0 +1,733 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Ops for operating with sets. They are not checked in +// to TensorFlow because we would first like to demonstrate successful +// end-to-end use of these ops in eval and polish the api a bit like taking two +// SparseTensor rather than one dense and one sparse. + +#define EIGEN_USE_THREADS + +#include <algorithm> +#include <numeric> +// TODO(ptucker): Consider switching back to hash_set - I had trouble getting it +// to work with string values. +#include <set> +#include <string> + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/util/sparse/sparse_tensor.h" + +namespace tensorflow { + +// Validate rank >= 2. 
+void CheckRankAtLeast2(OpKernelContext* ctx, const TensorShape& shape) { + const auto rank = shape.dims(); + OP_REQUIRES(ctx, rank >= 2, + errors::InvalidArgument("Invalid rank ", rank, ".")); +} + +// Return group shape, which is the 1st n-1 dimensions of shape. +const TensorShape GroupShape(OpKernelContext* ctx, + const TensorShape& input_shape) { + CheckRankAtLeast2(ctx, input_shape); + TensorShape shape(input_shape); + shape.RemoveDim(shape.dims() - 1); + return TensorShape(shape); +} + +// Validate sparse indices are valid. This is O(n), so use sparingly. +void CheckSparseTensorIndices(OpKernelContext* ctx, + const sparse::SparseTensor& st) { + OP_REQUIRES_OK(ctx, st.IndicesValid()); +} + +// Build `SparseTensor` from indices, values, and shape in inputs +// [base_index, base_index + 3), and validate its rank and indices. +sparse::SparseTensor SparseTensorFromContext(OpKernelContext* ctx, + const int32 base_index, + bool validate_indices) { + // Assume row-major order. + const TensorShape shape = + TensorShape(ctx->input(base_index + 2).vec<int64>()); + CheckRankAtLeast2(ctx, shape); + std::vector<int64> order(shape.dims()); + std::iota(order.begin(), order.end(), 0); + + const sparse::SparseTensor st(ctx->input(base_index), + ctx->input(base_index + 1), shape, order); + if (validate_indices) { + CheckSparseTensorIndices(ctx, st); + } + return st; +} + +// TODO(ptucker): CheckGroup is just a sanity check on the result of +// SparseTensor.group, consider removing. +// `sparse_tensor_shape` is the shape of the `SparseTensor` from which group +// was created, and is used to sanity check the indices in `group'. +template <typename T> +void CheckGroup(OpKernelContext* ctx, const sparse::Group& group, + const TensorShape& sparse_tensor_shape) { + const auto& indices = group.indices(); + const auto& values = group.values<T>(); + + // Sanity check: group is non-empty, and indices and values are same size. 
+ const auto num_values = values.dimension(0); + OP_REQUIRES(ctx, indices.size() > 0, errors::Internal("Empty group.")); + OP_REQUIRES( + ctx, indices.dimension(0) == num_values, + errors::Internal("shape[0] of group indices ", indices.dimension(0), + " != values ", num_values, ".")); + + // Sanity check: valid indices. + const auto group_rank = indices.dimension(1); + const auto expected_rank = sparse_tensor_shape.dims(); + OP_REQUIRES(ctx, expected_rank == group_rank, + errors::Internal("Rank expected ", expected_rank, ", got ", + group_rank, ".")); + for (int32 j = 0; j < expected_rank; ++j) { + const auto dim_size = sparse_tensor_shape.dim_size(j); + OP_REQUIRES(ctx, dim_size > 0, errors::Internal("Invalid dim_size[", j, + "] = ", dim_size, ".")); + for (int64 i = 0; i < num_values; ++i) { + const auto index = indices(i, j); + OP_REQUIRES(ctx, dim_size > index, + errors::Internal("indices[", i, ", ", j, "] expected < ", + dim_size, ", got ", index, ".")); + } + } +} + +// This lets us calculate the row-major index into flattened output. +const gtl::InlinedVector<int64, 8> Strides(const TensorShape& shape) { + gtl::InlinedVector<int64, 8> result(shape.dims()); + int64 product = 1; + for (auto i = shape.dims() - 1; i >= 0; --i) { + result[i] = product; + product *= shape.dim_size(i); + } + return result; +} + +// TODO(ptucker): If memory becomes an issue, consider a 2-pass approach to +// eliminate the intermediate `values` data structure - iterate once to +// determine `num_values`, allocate output tensors, then write results directly +// to output tensors. + +// TODO(ptucker): Consider sharding work across multiple threads. See +// SparseCrossOp for an example. + +// Output `SparseTensor` of shape `output_shape`. `sets` contains a map of +// group indices (i.e., values for all but the last dimension of `output_shape`) +// to set values, each of which will occupy the last dimension of +// `output_shape`. 
+template <typename T> +void OutputSparseTensor(OpKernelContext* ctx, const TensorShape& output_shape, + const int64 num_values, + const std::map<std::vector<int64>, std::set<T>>& sets) { + // Allocate 3 output tensors for sparse data. + Tensor *out_indices_t, *out_values_t, *out_shape_t; + OP_REQUIRES_OK(ctx, ctx->allocate_output( + 0, TensorShape({num_values, output_shape.dims()}), + &out_indices_t)); + OP_REQUIRES_OK( + ctx, ctx->allocate_output(1, TensorShape({num_values}), &out_values_t)); + OP_REQUIRES_OK(ctx, ctx->allocate_output( + 2, TensorShape({output_shape.dims()}), &out_shape_t)); + auto out_indices_mat = out_indices_t->matrix<int64>(); + auto out_values_flat = out_values_t->vec<T>(); + + // For each set, write its indices and values to output tensors. + int64 value_index = 0; + for (auto it = sets.begin(); it != sets.end(); ++it) { + const auto& group_indices = it->first; + OP_REQUIRES( + ctx, group_indices.size() == output_shape.dims() - 1, + errors::Internal("Invalid number of indices ", group_indices.size(), + ", expected ", output_shape.dims() - 1, ".")); + const auto& set = it->second; + + // For each set item, write its indices and value to output tensors. + int64 group_value_index = 0; + for (auto value = set.begin(); value != set.end(); + ++value, ++value_index, ++group_value_index) { + // First n-1 dimensions are the group, last dimension is the position in + // the set. + for (int32 i = 0; i < group_indices.size(); ++i) { + out_indices_mat(value_index, i) = group_indices[i]; + } + out_indices_mat(value_index, group_indices.size()) = group_value_index; + + out_values_flat(value_index) = *value; + } + } + + // Write output shape. 
+ auto out_shape_flat = out_shape_t->vec<int64>(); + for (int32 i = 0; i < output_shape.dims(); ++i) { + out_shape_flat(i) = output_shape.dim_size(i); + } +} + +bool ValidateIndicesFromContext(OpKernelConstruction* ctx) { + bool result; + if (ctx->GetAttr("validate_indices", &result).ok()) { + return result; + } + return true; +} + +// Populate `result` set from group in `tensor`. "Group" is defined by +// `group_indices`, which are values for the first n-1 dimensions of +// `input_tensor`. `input_strides` is provided to avoid recalculating it +// multiple times, and is used to calculate the flat index into `input_tensor` +// values. +template <typename T> +void PopulateFromDenseGroup(OpKernelContext* ctx, const Tensor& input_tensor, + const gtl::InlinedVector<int64, 8>& input_strides, + const std::vector<int64>& group_indices, + std::set<T>* result) { + OP_REQUIRES(ctx, group_indices.size() == input_strides.size() - 1, + errors::Internal("group_indices.size ", group_indices.size(), + ", != input_strides.size-1 ", + input_strides.size() - 1, ".")); + result->clear(); + auto input_flat = input_tensor.flat<T>(); + const auto start = std::inner_product( + group_indices.begin(), group_indices.end(), input_strides.begin(), 0L); + const TensorShape& input_shape = input_tensor.shape(); + const auto end = start + input_shape.dim_size(input_shape.dims() - 1); + for (int64 i = start; i < end; ++i) { + result->insert(input_flat(i)); + } +} + +// Populate `result` set from `group`. `sparse_tensor_shape` is the shape of the +// `SparseTensor` from which group was created, and is used to sanity check the +// indices in `group'. 
+template <typename T> +void PopulateFromSparseGroup(OpKernelContext* ctx, const sparse::Group& group, + const TensorShape& sparse_tensor_shape, + std::set<T>* result) { + CheckGroup<T>(ctx, group, sparse_tensor_shape); + result->clear(); + const auto& group_values = group.values<T>(); + for (int64 i = 0; i < group_values.size(); ++i) { + result->insert(group_values(i)); + } +} + +template <typename T> +class SetSizeOp : public OpKernel { + public: + explicit SetSizeOp(OpKernelConstruction* ctx) + : OpKernel(ctx), validate_indices_(ValidateIndicesFromContext(ctx)) {} + + void Compute(OpKernelContext* ctx) override; + + private: + const bool validate_indices_; +}; + +template <typename T> +void SetSizeOp<T>::Compute(OpKernelContext* ctx) { + const sparse::SparseTensor set_st = + SparseTensorFromContext(ctx, 0, validate_indices_); + + // Output shape is same as input except for last dimension, which reduces to + // the set size of values along that dimension. + const TensorShape output_shape = GroupShape(ctx, set_st.shape()); + const auto output_strides = Strides(output_shape); + + Tensor* out_t; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &out_t)); + auto out = out_t->flat<int32>(); + out.device(ctx->eigen_cpu_device()) = out.constant(static_cast<int32>(0.0)); + + // Group by all but last dimension, create a set of group values, and add set + // size to output. 
+ sparse::SparseTensor::VarDimArray group_ix(set_st.order(), 0, + set_st.order().size() - 1); + std::set<T> group_set; + for (const auto& group : set_st.group(group_ix)) { + PopulateFromSparseGroup<T>(ctx, group, set_st.shape(), &group_set); + + const auto group_key = group.group(); + const auto output_index = std::inner_product( + group_key.begin(), group_key.end(), output_strides.begin(), 0L); + out(output_index) = group_set.size(); + } +} + +#define _SET_SIZE_REGISTER_KERNEL_BUILDER(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("SetSize").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ + SetSizeOp<T>); +_SET_SIZE_REGISTER_KERNEL_BUILDER(int8); +_SET_SIZE_REGISTER_KERNEL_BUILDER(int16); +_SET_SIZE_REGISTER_KERNEL_BUILDER(int32); +_SET_SIZE_REGISTER_KERNEL_BUILDER(int64); +_SET_SIZE_REGISTER_KERNEL_BUILDER(uint8); +_SET_SIZE_REGISTER_KERNEL_BUILDER(uint16); +_SET_SIZE_REGISTER_KERNEL_BUILDER(string); +#undef _SET_SIZE_REGISTER_KERNEL_BUILDER + +enum InputTypes { + DENSE_DENSE = 0, + DENSE_SPARSE = 1, + SPARSE_SPARSE = 2, +}; + +enum SetOperation { A_MINUS_B = 0, B_MINUS_A = 1, INTERSECTION = 2, UNION = 3 }; + +SetOperation SetOperationFromContext(OpKernelConstruction* ctx) { + string set_operation_str; + if (!ctx->GetAttr("set_operation", &set_operation_str).ok()) { + ctx->CtxFailure(errors::InvalidArgument("Missing set_operation.")); + } else { + std::transform(set_operation_str.begin(), set_operation_str.end(), + set_operation_str.begin(), ::tolower); + if ("a-b" == set_operation_str) { + return A_MINUS_B; + } + if ("b-a" == set_operation_str) { + return B_MINUS_A; + } + if ("intersection" == set_operation_str) { + return INTERSECTION; + } + if ("union" != set_operation_str) { + ctx->CtxFailure(errors::InvalidArgument("Invalid set_operation ", + set_operation_str, ".")); + } + } + // NOTE: This is not the default, this function fails if no 'set_operation' + // attribute is provided. 
+ return UNION; +} + +// Abstract base class for performing set operations across the last dimension +// of 2 input tensors. +template <typename T> +class SetOperationOp : public OpKernel { + public: + SetOperationOp(OpKernelConstruction* ctx, InputTypes input_types) + : OpKernel(ctx), + set_operation_(SetOperationFromContext(ctx)), + validate_indices_(ValidateIndicesFromContext(ctx)), + input_types_(input_types) {} + + void Compute(OpKernelContext* ctx) override; + + private: + void ApplySetOperation(const std::set<T>& set1, const std::set<T>& set2, + std::set<T>* result) const; + void ComputeDenseToDense(OpKernelContext* ctx) const; + void ComputeDenseToSparse(OpKernelContext* ctx) const; + void ComputeSparseToSparse(OpKernelContext* ctx) const; + const SetOperation set_operation_; + const bool validate_indices_; + const InputTypes input_types_; +}; + +template <typename T> +void SetOperationOp<T>::ApplySetOperation(const std::set<T>& set1, + const std::set<T>& set2, + std::set<T>* result) const { + switch (set_operation_) { + case A_MINUS_B: + std::set_difference(set1.begin(), set1.end(), set2.begin(), set2.end(), + std::inserter(*result, result->begin())); + break; + case B_MINUS_A: + std::set_difference(set2.begin(), set2.end(), set1.begin(), set1.end(), + std::inserter(*result, result->begin())); + break; + case INTERSECTION: + std::set_intersection(set1.begin(), set1.end(), set2.begin(), set2.end(), + std::inserter(*result, result->begin())); + break; + case UNION: + std::set_union(set1.begin(), set1.end(), set2.begin(), set2.end(), + std::inserter(*result, result->begin())); + break; + } +} + +// Validate shapes have the same dimensions. 
+void CheckShapesMatch(OpKernelContext* ctx, const TensorShape& shape1, + const TensorShape& shape2) { + OP_REQUIRES( + ctx, shape1 == shape2, + errors::InvalidArgument("Mismatched shapes ", shape1.DebugString(), + " vs ", shape2.DebugString(), ".")); +} + +// Validate ranks are the same, and all but last dimension are the same. +// Return GroupShape. +const TensorShape GroupShapeFromInputs(OpKernelContext* ctx, + const TensorShape& shape1, + const TensorShape& shape2) { + const TensorShape group_shape = GroupShape(ctx, shape1); + CheckShapesMatch(ctx, group_shape, GroupShape(ctx, shape2)); + return group_shape; +} + +// Split `flat_group_index` into separate dimensions based on `group_shape`. +void PopulateGroupIndices(const int64 flat_group_index, + const TensorShape& group_shape, + std::vector<int64>* group_indices) { + group_indices->clear(); + int64 running_flat_group_index = flat_group_index; + for (auto group_dim_index = group_shape.dims() - 1; group_dim_index >= 0; + --group_dim_index) { + const auto group_dim = group_shape.dim_size(group_dim_index); + group_indices->insert(group_indices->begin(), + running_flat_group_index % group_dim); + running_flat_group_index /= group_dim; + } +} + +// `ctx` contains set1 and set2 dense tensors. +// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each, +// and outputing the result `SparseTensor`. A "group" is a collection of values +// with the same first n-1 dimensions in set1 and set2. +template <typename T> +void SetOperationOp<T>::ComputeDenseToDense(OpKernelContext* ctx) const { + const Tensor& set1_t = ctx->input(0); + const Tensor& set2_t = ctx->input(1); + // The following should stay in sync with `_dense_to_dense_shape` shape + // assertions in python/ops/set_ops.py, and `SetShapeFn` for + // `DenseToDenseSetOperation` in ops/set_ops.cc. 
+ const TensorShape group_shape = + GroupShapeFromInputs(ctx, set1_t.shape(), set2_t.shape()); + + const auto set1_strides = Strides(set1_t.shape()); + const auto set2_strides = Strides(set2_t.shape()); + + std::map<std::vector<int64>, std::set<T>> group_sets; + int64 num_result_values = 0; + int64 max_set_size = 0; + + std::set<T> set1_group_set; + std::set<T> set2_group_set; + std::vector<int64> group_indices; + for (int64 flat_group_index = 0; + flat_group_index < group_shape.num_elements(); ++flat_group_index) { + PopulateGroupIndices(flat_group_index, group_shape, &group_indices); + PopulateFromDenseGroup<T>(ctx, set1_t, set1_strides, group_indices, + &set1_group_set); + PopulateFromDenseGroup<T>(ctx, set2_t, set2_strides, group_indices, + &set2_group_set); + + std::set<T> group_set; + ApplySetOperation(set1_group_set, set2_group_set, &group_set); + if (!group_set.empty()) { + group_sets[group_indices] = group_set; + const auto set_size = group_set.size(); + if (set_size > max_set_size) { + max_set_size = set_size; + } + num_result_values += set_size; + } + } + + TensorShape output_shape(group_shape); + output_shape.AddDim(max_set_size); + OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets); +} + +// `ctx` contains dense set1 and sparse set2 tensors. +// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each, +// and outputing the result `SparseTensor`. A "group" is a collection of values +// with the same first n-1 dimensions in set1 and set2. +template <typename T> +void SetOperationOp<T>::ComputeDenseToSparse(OpKernelContext* ctx) const { + const Tensor& set1_t = ctx->input(0); + const sparse::SparseTensor set2_st = + SparseTensorFromContext(ctx, 1, validate_indices_); + // The following should stay in sync with `_dense_to_sparse_shape` shape + // assertions in python/ops/set_ops.py, and `SetShapeFn` for + // `DenseToSparseSetOperation` in ops/set_ops.cc. 
+ const TensorShape group_shape = + GroupShapeFromInputs(ctx, set1_t.shape(), set2_st.shape()); + + const auto set1_strides = Strides(set1_t.shape()); + + std::map<std::vector<int64>, std::set<T>> group_sets; + int64 num_result_values = 0; + int64 max_set_size = 0; + + std::set<T> set1_group_set; + std::set<T> set2_group_set; + auto set2_grouper = set2_st.group(sparse::SparseTensor::VarDimArray( + set2_st.order(), 0, set2_st.order().size() - 1)); + auto set2_group_it = set2_grouper.begin(); + std::vector<int64> group_indices; + for (int64 flat_group_index = 0; + flat_group_index < group_shape.num_elements(); ++flat_group_index) { + PopulateGroupIndices(flat_group_index, group_shape, &group_indices); + + // Get values from set1. + PopulateFromDenseGroup<T>(ctx, set1_t, set1_strides, group_indices, + &set1_group_set); + + // Get values from set2, if applicable. + set2_group_set.clear(); + if (set2_group_it != set2_grouper.end()) { + const auto& group = *set2_group_it; + const auto set2_group_indices = group.group(); + OP_REQUIRES( + ctx, set2_group_indices.size() == group_indices.size(), + errors::InvalidArgument("Invalid number of group indices ", + set2_group_indices.size(), ", expected ", + group_indices.size(), ".")); + bool group_match = true; + for (int32 i = 0; group_match && (i < set2_group_indices.size()); ++i) { + if (set2_group_indices[i] != group_indices[i]) { + group_match = false; + } + } + if (group_match) { + PopulateFromSparseGroup<T>(ctx, group, set2_st.shape(), + &set2_group_set); + ++set2_group_it; + } + } + + std::set<T> group_set; + ApplySetOperation(set1_group_set, set2_group_set, &group_set); + if (!group_set.empty()) { + group_sets[group_indices] = group_set; + const auto set_size = group_set.size(); + if (set_size > max_set_size) { + max_set_size = set_size; + } + num_result_values += set_size; + } + } + + TensorShape output_shape(group_shape); + output_shape.AddDim(max_set_size); + OutputSparseTensor<T>(ctx, output_shape, num_result_values, 
group_sets); +} + +// This is used to determine which group iterator is less than the other, based +// on row-major ordering of indices. +// An empty index list indicates end of iteration, which is interpreted as "max" +// for the purposes of comparison; i.e., non-empty < empty. +// Return 0 if both groups are empty, or both non-empty with the same values. +// Return <0 if set1 <= set2, or set2 is empty. +// Return >0 if set2 <= set1, or set1 is empty. +void CompareGroups(OpKernelContext* ctx, + const std::vector<int64>& set1_group_indices, + const std::vector<int64>& set2_group_indices, + int64* result) { + if (set1_group_indices.empty()) { + *result = set2_group_indices.empty() ? 0 : 1; + return; + } + if (set2_group_indices.empty()) { + *result = set1_group_indices.empty() ? 0 : -1; + return; + } + OP_REQUIRES(ctx, set1_group_indices.size() == set2_group_indices.size(), + errors::InvalidArgument("Mismatched group dims ", + set1_group_indices.size(), " vs ", + set2_group_indices.size(), ".")); + for (int32 i = 0; i < set1_group_indices.size(); ++i) { + *result = set1_group_indices[i] - set2_group_indices[i]; + if (*result != 0) { + return; + } + } +} + +// Empty indices vector represents iteration end in `CompareGroups`. +const std::vector<int64> GROUP_ITER_END; + +// `ctx` contains set1 and set2 sparse tensors. +// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each, +// and outputing the result `SparseTensor`. A "group" is a collection of values +// with the same first n-1 dimensions in set1 and set2. 
+template <typename T> +void SetOperationOp<T>::ComputeSparseToSparse(OpKernelContext* ctx) const { + const sparse::SparseTensor set1_st = + SparseTensorFromContext(ctx, 0, validate_indices_); + const sparse::SparseTensor set2_st = + SparseTensorFromContext(ctx, 3, validate_indices_); + // The following should stay in sync with `_sparse_to_sparse_shape` shape + // assertions in python/ops/set_ops.py, and `SetShapeFn` for + // `SparseToSparseSetOperation` in ops/set_ops.cc. + const TensorShape group_shape = + GroupShapeFromInputs(ctx, set1_st.shape(), set2_st.shape()); + + const auto set1_strides = Strides(set1_st.shape()); + const auto set2_strides = Strides(set2_st.shape()); + + std::map<std::vector<int64>, std::set<T>> group_sets; + int64 num_result_values = 0; + int64 max_set_size = 0; + + std::set<T> set1_group_set; + std::set<T> set2_group_set; + auto set1_grouper = set1_st.group(sparse::SparseTensor::VarDimArray( + set1_st.order(), 0, set1_st.order().size() - 1)); + auto set1_group_it = set1_grouper.begin(); + auto set2_grouper = set2_st.group(sparse::SparseTensor::VarDimArray( + set2_st.order(), 0, set2_st.order().size() - 1)); + auto set2_group_it = set2_grouper.begin(); + + // Group by rows, and iterate over rows of both sets in parallel, creating a + // set for each row. + while ((set1_group_it != set1_grouper.end()) || + (set2_group_it != set2_grouper.end())) { + const std::vector<int64>& set1_group_indices = + (set1_group_it == set1_grouper.end()) ? GROUP_ITER_END + : (*set1_group_it).group(); + const std::vector<int64>& set2_group_indices = + (set2_group_it == set2_grouper.end()) ? GROUP_ITER_END + : (*set2_group_it).group(); + + int64 compare_groups; + CompareGroups(ctx, set1_group_indices, set2_group_indices, &compare_groups); + const std::vector<int64>* group_indices = nullptr; + + // Get values from set1, if applicable. 
+ set1_group_set.clear(); + if (compare_groups <= 0) { + PopulateFromSparseGroup<T>(ctx, *set1_group_it, set1_st.shape(), + &set1_group_set); + ++set1_group_it; + group_indices = &set1_group_indices; + } + + // Get values from set2, if applicable. + set2_group_set.clear(); + if (compare_groups >= 0) { + PopulateFromSparseGroup<T>(ctx, *set2_group_it, set2_st.shape(), + &set2_group_set); + ++set2_group_it; + group_indices = &set2_group_indices; + } + + std::set<T> group_set; + ApplySetOperation(set1_group_set, set2_group_set, &group_set); + if (!group_set.empty()) { + group_sets[*group_indices] = group_set; + const auto set_size = group_set.size(); + if (set_size > max_set_size) { + max_set_size = set_size; + } + num_result_values += set_size; + } + } + + TensorShape output_shape(group_shape); + output_shape.AddDim(max_set_size); + OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets); +} + +// Given set1 of shape [b, n1] and data_2 of shape [b, n2], populate result +// sparse tendor with [b, n3] values, where each row `i` contains the result of +// the set operation on elements from set1[i] and set2[i]. `n3` is the number +// of elements in that result row. 
+template <typename T> +void SetOperationOp<T>::Compute(OpKernelContext* ctx) { + switch (input_types_) { + case DENSE_DENSE: + ComputeDenseToDense(ctx); + break; + case DENSE_SPARSE: + ComputeDenseToSparse(ctx); + break; + case SPARSE_SPARSE: + ComputeSparseToSparse(ctx); + break; + } +} + +template <typename T> +class DenseToDenseSetOperationOp : public SetOperationOp<T> { + public: + explicit DenseToDenseSetOperationOp(OpKernelConstruction* ctx) + : SetOperationOp<T>(ctx, DENSE_DENSE) {} +}; + +#define _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \ + REGISTER_KERNEL_BUILDER(Name("DenseToDenseSetOperation") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T"), \ + DenseToDenseSetOperationOp<T>); +_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8); +_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16); +_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32); +_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64); +_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8); +_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16); +_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(string); +#undef _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER + +template <typename T> +class DenseToSparseSetOperationOp : public SetOperationOp<T> { + public: + explicit DenseToSparseSetOperationOp(OpKernelConstruction* ctx) + : SetOperationOp<T>(ctx, DENSE_SPARSE) {} +}; + +#define _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \ + REGISTER_KERNEL_BUILDER(Name("DenseToSparseSetOperation") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T"), \ + DenseToSparseSetOperationOp<T>); +_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8); +_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16); +_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32); +_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64); +_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8); 
+_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16); +_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(string); +#undef _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER + +template <typename T> +class SparseToSparseSetOperationOp : public SetOperationOp<T> { + public: + explicit SparseToSparseSetOperationOp(OpKernelConstruction* ctx) + : SetOperationOp<T>(ctx, SPARSE_SPARSE) {} +}; + +#define _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \ + REGISTER_KERNEL_BUILDER(Name("SparseToSparseSetOperation") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T"), \ + SparseToSparseSetOperationOp<T>); +_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8); +_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16); +_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32); +_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64); +_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8); +_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16); +_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(string); +#undef _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER + +} // namespace tensorflow |