Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/concat_op.cc')
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/concat_op.cc  210
1 file changed, 210 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
new file mode 100644
index 0000000000..96ef2ac20c
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
@@ -0,0 +1,210 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific Concat Ops.
+
+#include <limits>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace {
+
+// --------------------------------------------------------------------------
+class ConcatBaseOp : public XlaOpKernel {
+ public:
+ ConcatBaseOp(OpKernelConstruction* c, int axis_index)
+ : XlaOpKernel(c), axis_index_(axis_index) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape concat_dim_tensor_shape = ctx->InputShape(axis_index_);
+ OP_REQUIRES(
+ ctx, IsLegacyScalar(concat_dim_tensor_shape),
+ errors::InvalidArgument(
+ "Concat dim tensor should be a scalar integer, but got shape ",
+ concat_dim_tensor_shape.DebugString()));
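+    // XLA shapes must be known at compile time, so the concat dimension is
+    // read as a compile-time constant; ConstantInput fails if the operand
+    // cannot be resolved to a constant.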
+ xla::Literal literal;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(axis_index_, &literal));
+ // TODO(annarev): add a helper to support int64 input.
+ const int32 concat_dim = xla::LiteralUtil::Get<int>(literal, {});
+
+ std::vector<xla::ComputationDataHandle> values;
+ std::vector<TensorShape> shapes;
+ OP_REQUIRES_OK(ctx, ctx->InputList("values", &values, &shapes));
+ const int N = values.size();
+ const int input_dims = shapes[0].dims();
+ const TensorShape& input_shape = shapes[0];
+
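+    // A negative concat_dim counts from the end: e.g., with 4-D inputs,
+    // concat_dim = -1 is equivalent to axis = 3.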
+ int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
+ OP_REQUIRES(ctx,
+ (0 <= axis && axis < input_dims) ||
+ (allow_legacy_scalars() && concat_dim == 0),
+ errors::InvalidArgument(
+ "ConcatOp : Expected concatenating dimensions in the range "
+ "[",
+ -input_dims, ", ", input_dims, "), but got ", concat_dim));
+
+    // Make a vector holding the ComputationDataHandle for each input,
+    // reshaping any scalar inputs to 1-vectors along the way.
+ std::vector<xla::ComputationDataHandle> input_data;
+ int output_concat_dim = 0;
+ const bool input_is_scalar = IsLegacyScalar(input_shape);
+ for (int i = 0; i < N; ++i) {
+ xla::ComputationDataHandle handle = values[i];
+ const TensorShape& in_shape = shapes[i];
+ const bool in_is_scalar = IsLegacyScalar(in_shape);
+ OP_REQUIRES(
+ ctx,
+ in_shape.dims() == input_dims || (input_is_scalar && in_is_scalar),
+ errors::InvalidArgument(
+ "ConcatOp : Ranks of all input tensors should match: shape[0] = ",
+ input_shape.DebugString(), " vs. shape[", i, "] = ",
+ in_shape.DebugString()));
+ if (in_shape.dims() == 0) {
+ // Inputs that come in as scalars must be reshaped to 1-vectors.
+ input_data.push_back(ctx->builder()->Reshape(handle, {1}));
+ } else {
+ input_data.push_back(handle);
+ }
+ output_concat_dim += in_shape.dims() > 0 ? in_shape.dim_size(axis) : 1;
+ }
+
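+    // A single ConcatInDim performs the concatenation: e.g., inputs of
+    // shapes [2, 2] and [2, 3] concatenated along axis 1 yield an output
+    // of shape [2, 5].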
+ VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis;
+ ctx->SetOutput(0, ctx->builder()->ConcatInDim(input_data, axis));
+ }
+
+ private:
+ int axis_index_;
+};
+
+class ConcatOp : public ConcatBaseOp {
+ public:
+ explicit ConcatOp(OpKernelConstruction* c)
+ : ConcatBaseOp(c, /* axis_index */ 0) {}
+};
+
+// The ConcatV2 operation is the same as Concat, except that 'concat_dim'
+// is the last input rather than the first and is renamed 'axis'.
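+// E.g., Concat(concat_dim, x, y) and ConcatV2(x, y, axis=concat_dim)
+// produce the same result.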
+class ConcatV2Op : public ConcatBaseOp {
+ public:
+ explicit ConcatV2Op(OpKernelConstruction* c)
+ : ConcatBaseOp(c, /* axis_index */ c->num_inputs() - 1) {}
+};
+
+REGISTER_XLA_OP("Concat", ConcatOp);
+REGISTER_XLA_OP("ConcatV2", ConcatV2Op);
+
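+// ConcatOffset computes, for each input to a Concat, the offset of that
+// input within the concatenated output. TensorFlow's gradient for Concat
+// uses these offsets as the 'begin' argument of a Slice op per input.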
+class ConcatOffsetOp : public XlaOpKernel {
+ public:
+ explicit ConcatOffsetOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape concat_dim_shape = ctx->InputShape(0);
+ OP_REQUIRES(
+ ctx, IsLegacyScalar(concat_dim_shape),
+ errors::InvalidArgument(
+ "Concat dim tensor should be a scalar integer, but got shape ",
+ concat_dim_shape.DebugString()));
+ for (int i = 1; i < ctx->num_inputs(); ++i) {
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ctx->InputShape(i)),
+ errors::InvalidArgument("input ", i,
+ " should be a vector, but got shape ",
+ ctx->InputShape(i).DebugString()));
+ }
+    // Suppose a Concat() op needs to concatenate N tensors, each of
+    // which has the same number of dimensions. Their shapes match
+    // except in the concat dimension.
+    //
+    // E.g., say we want to concatenate 3 tensors along the 2nd
+    // dimension, and their shapes are:
+ //
+ // [2, 2, 5, 7]
+ // [2, 3, 5, 7]
+ // [2, 4, 5, 7]
+ //
+ // Here, N=3, cdim=1, dims=4. The concatenated tensor has shape
+ // [2,9,5,7]. We will compute the cumulative sum along the 2nd
+ // dimension to figure out each input's offset in the concatenated
+ // output:
+ // [0, 0, 0, 0]
+ // [0, 2, 0, 0]
+ // [0, 5, 0, 0]
+ const int32 N = ctx->num_inputs() - 1;
+ const TensorShape inp0_shape = ctx->InputShape(1);
+ xla::Literal inp0_literal;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &inp0_literal));
+ const int64 dims = inp0_shape.num_elements();
+
+ xla::Literal concat_dim_literal;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &concat_dim_literal));
+ const int64 cdim = xla::LiteralUtil::Get<int>(concat_dim_literal, {});
+
+ VLOG(1) << "ConcatOffset " << cdim << "," << dims;
+ int32 axis = cdim < 0 ? cdim + dims : cdim;
+ OP_REQUIRES(ctx, FastBoundsCheck(axis, dims),
+ errors::InvalidArgument("Concat dim is out of range: ", axis,
+ " vs. ", dims));
+ int32 offset = 0;
+ for (int i = 0; i < N; ++i) {
+ const TensorShape inp_shape = ctx->InputShape(1 + i);
+      OP_REQUIRES(ctx, dims == inp_shape.num_elements(),
+                  errors::InvalidArgument("input ", i, " should contain ", dims,
+                                          " elements, but got ",
+                                          inp_shape.num_elements()));
+ xla::Literal inp_literal;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(1 + i, &inp_literal));
+
+ Tensor out_constant(DT_INT32, TensorShape({dims}));
+ auto out_vec = out_constant.vec<int32>();
+ for (int64 j = 0; j < dims; ++j) {
+ if (j == axis) {
+ out_vec(j) = offset;
+ offset += xla::LiteralUtil::Get<int>(inp_literal, {j});
+ } else {
+ const int32 inp0_element =
+ xla::LiteralUtil::Get<int>(inp0_literal, {j});
+ const int32 inp_element =
+ xla::LiteralUtil::Get<int>(inp_literal, {j});
+ OP_REQUIRES(
+ ctx, (inp0_element == inp_element),
+ errors::InvalidArgument("input[", i, ",", j, "] mismatch: ",
+ inp0_element, " vs. ", inp_element));
+ out_vec(j) = 0;
+ }
+ }
+
+ ctx->SetConstantOutput(i, out_constant);
+ }
+ }
+};
+
+REGISTER_XLA_OP("ConcatOffset", ConcatOffsetOp);
+
+} // namespace
+} // namespace tensorflow