aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/sparse_concat_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/sparse_concat_op.cc')
-rw-r--r--tensorflow/core/kernels/sparse_concat_op.cc139
1 files changed, 139 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/sparse_concat_op.cc b/tensorflow/core/kernels/sparse_concat_op.cc
new file mode 100644
index 0000000000..72c267a47d
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_concat_op.cc
@@ -0,0 +1,139 @@
+#define EIGEN_USE_THREADS
+
+#include <algorithm>
+#include <unordered_map>
+#include <utility>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+namespace tensorflow {
+
+template <typename T>
+class SparseConcatOp : public OpKernel {
+ public:
+ explicit SparseConcatOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("concat_dim", &concat_dim_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ OpInputList inds;
+ OP_REQUIRES_OK(context, context->input_list("indices", &inds));
+ const int N = inds.size();
+ for (int i = 0; i < N; i++) {
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrix(inds[i].shape()),
+ errors::InvalidArgument(
+ "Input indices should be a matrix but received shape ",
+ inds[i].shape().DebugString(), " at position ", i));
+ }
+
+ OpInputList vals;
+ OP_REQUIRES_OK(context, context->input_list("values", &vals));
+ OP_REQUIRES(context, vals.size() == N,
+ errors::InvalidArgument("Expected ", N, " input values, got ",
+ vals.size()));
+ for (int i = 0; i < N; i++) {
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(vals[i].shape()),
+ errors::InvalidArgument(
+ "Input values should be a vector but received shape ",
+ vals[i].shape().DebugString(), " at position ", i));
+ }
+
+ OpInputList shapes;
+ OP_REQUIRES_OK(context, context->input_list("shapes", &shapes));
+ OP_REQUIRES(context, shapes.size() == N,
+ errors::InvalidArgument("Expected ", N, " input shapes, got ",
+ shapes.size()));
+ for (int i = 0; i < N; i++) {
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(shapes[i].shape()),
+ errors::InvalidArgument(
+ "Input shapes should be a vector but received shape ",
+ shapes[i].shape().DebugString(), " at position ", i));
+ }
+
+ const TensorShape input_shape(shapes[0].vec<int64>());
+ OP_REQUIRES(
+ context, concat_dim_ >= 0 && concat_dim_ < input_shape.dims(),
+ errors::InvalidArgument("Concat dimension must be between 0 and rank (",
+ input_shape.dims(), "), got ", concat_dim_));
+ for (int i = 1; i < N; ++i) {
+ const TensorShape current_shape(shapes[i].vec<int64>());
+ OP_REQUIRES(context, current_shape.dims() == input_shape.dims(),
+ errors::InvalidArgument(
+ "Ranks of all input tensors must match: expected ",
+ input_shape.dims(), " but got ", current_shape.dims(),
+ " at position ", i));
+ for (int j = 0; j < input_shape.dims(); ++j) {
+ if (j != concat_dim_) {
+ OP_REQUIRES(
+ context, input_shape.dim_size(j) == current_shape.dim_size(j),
+ errors::InvalidArgument(
+ "Input shapes must match: expected ", input_shape.dim_size(j),
+ " for dimension ", j, " but got ", current_shape.dim_size(j),
+ " at position ", i));
+ }
+ }
+ }
+
+ // The input and output sparse tensors are assumed to be ordered along
+ // increasing dimension number. But in order for concat to work properly,
+ // order[0] must be concat_dim. So we will reorder the inputs to the
+ // concat ordering, concatenate, then reorder back to the standard order.
+ // We make a deep copy of the input tensors to ensure that the in-place
+ // reorder doesn't create race conditions for other ops that may be
+ // concurrently reading the indices and values tensors.
+
+ gtl::InlinedVector<int64, 8> std_order(input_shape.dims());
+ std::iota(std_order.begin(), std_order.end(), 0);
+
+ std::vector<int64> concat_order;
+ concat_order.reserve(input_shape.dims());
+ concat_order.push_back(concat_dim_);
+ for (int j = 0; j < input_shape.dims(); ++j) {
+ if (j != concat_dim_) {
+ concat_order.push_back(j);
+ }
+ }
+
+ std::vector<sparse::SparseTensor> sp_inputs;
+ for (int i = 0; i < N; ++i) {
+ const TensorShape current_shape(shapes[i].vec<int64>());
+ sp_inputs.emplace_back(tensor::DeepCopy(inds[i]),
+ tensor::DeepCopy(vals[i]), current_shape,
+ std_order);
+ sp_inputs[i].Reorder<T>(concat_order);
+ }
+
+ sparse::SparseTensor concat = sparse::SparseTensor::Concat<T>(sp_inputs);
+ concat.Reorder<T>(std_order);
+
+ context->set_output(0, concat.indices());
+ context->set_output(1, concat.values());
+
+ Tensor* output_shape_out = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 2, TensorShape({concat.shape().dims()}),
+ &output_shape_out));
+ auto output_shape = output_shape_out->vec<int64>();
+ for (int j = 0; j < concat.shape().dims(); ++j) {
+ output_shape(j) = concat.shape().dim_size(j);
+ }
+ }
+
+ private:
+ int concat_dim_;
+};
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SparseConcat").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ SparseConcatOp<type>)
+
+TF_CALL_ALL_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+} // namespace tensorflow