Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/split_op.cc')
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/split_op.cc | 208
1 file changed, 208 insertions(+), 0 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc
new file mode 100644
index 0000000000..18c4c648db
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc
@@ -0,0 +1,208 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific Ops for split.
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+namespace {
+
+class SplitOp : public XlaOpKernel {
+ public:
+ explicit SplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
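+ // The split dimension (input 0) must be a compile-time constant, since
+ // XLA requires static output shapes; ConstantInput resolves it during
+ // compilation.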
+ const TensorShape index_shape = ctx->InputShape(0);
+ xla::Literal literal_index;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal_index));
+
+ int32 split_dim;
+ if (index_shape.dims() == 0) {
+ split_dim = xla::LiteralUtil::Get<int>(literal_index, {});
+ } else {
+ OP_REQUIRES(
+ ctx, index_shape.dims() == 1,
+ errors::InvalidArgument("split_index input to Split Op must be a "
+ "scalar or a vector with 1 element"));
+ OP_REQUIRES(
+ ctx, index_shape.dim_size(0) == 1,
+ errors::InvalidArgument("split_index input to Split Op must be a "
+ "scalar or a vector with 1 element"));
+ split_dim = xla::LiteralUtil::Get<int>(literal_index, {0});
+ }
+ const int32 num_split = num_outputs();
+ const TensorShape input_shape = ctx->InputShape(1);
+
+ OP_REQUIRES(
+ ctx, 0 <= split_dim && split_dim < input_shape.dims(),
+ errors::InvalidArgument("0 <= split_dim < number of input dimensions (",
+ input_shape.dims(), "), but got ", split_dim));
+
+ OP_REQUIRES(
+ ctx, num_split > 0,
+ errors::InvalidArgument(
+ "Number of ways to split should be > 0, but got ", num_split));
+
+ OP_REQUIRES(ctx, input_shape.dim_size(split_dim) % num_split == 0,
+ errors::InvalidArgument(
+ "Number of ways to split should evenly divide the split "
+ "dimension, but got split_dim ",
+ split_dim, " (size = ", input_shape.dim_size(split_dim),
+ ") ", "and num_split ", num_split));
+
+ // All the slices are the same size: this is the size along the
+ // split dimension.
+ const int32 slice_size = input_shape.dim_size(split_dim) / num_split;
+
+ // The vectors we will use to define the slice. The entry for the
+ // split dimension varies for each output.
+ std::vector<int64> begin;
+ std::vector<int64> limits;
+ for (int i = 0; i < input_shape.dims(); ++i) {
+ // Initially set up the limits to be the full size of the input:
+ // the split dimension is filled in below.
+ int64 dim = input_shape.dim_size(i);
+ begin.push_back(0);
+ limits.push_back(dim);
+ }
+
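+ // Input 1 holds the tensor to split; XLA's Slice takes inclusive start
+ // indices and exclusive limit indices.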
+ auto input = ctx->Input(1);
+
+ // Create each of the outputs.
+ for (int i = 0; i < num_split; ++i) {
+ // Slice out the ith split from the split dimension.
+ begin[split_dim] = i * slice_size;
+ limits[split_dim] = (i + 1) * slice_size;
+ ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits));
+ }
+ }
+};
+
+REGISTER_XLA_OP("Split", SplitOp);
+
+class SplitVOp : public XlaOpKernel {
+ public:
+ explicit SplitVOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const int32 num_split = num_outputs();
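+ // As in SplitOp, the split dimension (here input 2) must be a
+ // compile-time constant so the output shapes are static.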
+ const TensorShape index_shape = ctx->InputShape(2);
+ xla::Literal literal_index;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &literal_index));
+
+ int32 split_dim;
+ OP_REQUIRES(ctx, index_shape.dims() == 0,
+ errors::InvalidArgument("split_dim input to Split Op must be a "
+ "scalar"));
+ split_dim = xla::LiteralUtil::Get<int>(literal_index, {});
+
+ xla::ComputationDataHandle input = ctx->Input(0);
+ const TensorShape input_shape = ctx->InputShape(0);
+
+ OP_REQUIRES(ctx, input_shape.dims() > 0,
+ errors::InvalidArgument("Can't split a 0 dimensional input"));
+
+ OP_REQUIRES(
+ ctx, 0 <= split_dim && split_dim < input_shape.dims(),
+ errors::InvalidArgument("0 <= split_dim < number of input dimensions (",
+ input_shape.dims(), "), but got ", split_dim));
+
+ OP_REQUIRES(
+ ctx, num_split > 0,
+ errors::InvalidArgument(
+ "Number of ways to split should be > 0, but got ", num_split));
+
+ // Check that the requested split sizes are consistent; at most one
+ // entry may be -1, meaning its size is inferred.
+ int total_split_size = 0;
+ int neg_one_dim = -1;
+ std::vector<int64> split_sizes_vec(num_split, -1);
+ const TensorShape split_size_shape = ctx->InputShape(1);
+ OP_REQUIRES(ctx, split_size_shape.dims() == 1 &&
+ split_size_shape.num_elements() == num_split,
+ errors::InvalidArgument(
+ "shape of tensor describing "
+ " the output must have dimension 1 and the same "
+ " number of elements as the output. Got ",
+ split_size_shape.dims(), "-D and ",
+ split_size_shape.num_elements(), " elements"));
+ // Read the constant split sizes from input 1.
+ xla::Literal split_size_literal;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &split_size_literal));
+
+ for (int i = 0; i < num_split; ++i) {
+ int slice_size = xla::LiteralUtil::Get<int>(split_size_literal, {i});
+ if (slice_size == -1) {
+ OP_REQUIRES(
+ ctx, neg_one_dim == -1,
+ errors::InvalidArgument("Only one dimensions can have a value of"
+ "-1. Second one found at dimension ",
+ i));
+ neg_one_dim = i;
+ } else {
+ split_sizes_vec[i] = slice_size;
+ total_split_size += slice_size;
+ }
+ }
+
+ OP_REQUIRES(
+ ctx, (neg_one_dim == -1 &&
+ total_split_size == input_shape.dim_size(split_dim)) ||
+ (neg_one_dim >= 0 &&
+ total_split_size <= input_shape.dim_size(split_dim)),
+ errors::InvalidArgument("Determined shape must either match "
+ "input shape along split_dim exactly if "
+ "fully specified, or be less than the size of "
+ "the input along split_dim if not fully "
+ "specified. Got: ",
+ total_split_size));
+
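+ // Infer the size of the -1 entry from what remains of the split
+ // dimension.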
+ if (neg_one_dim >= 0) {
+ split_sizes_vec[neg_one_dim] =
+ input_shape.dim_size(split_dim) - total_split_size;
+ }
+
+ // The vectors we will use to define the slice. The entry for the
+ // split dimension varies for each output.
+ std::vector<int64> begin(input_shape.dims(), 0);
+ auto dim_sizes = input_shape.dim_sizes();
+ std::vector<int64> limits(dim_sizes.begin(), dim_sizes.end());
+
+ for (int i = 0; i < num_split; ++i) {
+ TensorShape output_shape(input_shape);
+ int slice_size = split_sizes_vec[i];
+ output_shape.set_dim(split_dim, slice_size);
+
+ // Slice out the ith split from the split dimension.
+ limits[split_dim] = begin[split_dim] + slice_size;
+ ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits));
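+ // Advance the slice start past this split for the next output.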
+ begin[split_dim] = limits[split_dim];
+ }
+ }
+};
+
+REGISTER_XLA_OP("SplitV", SplitVOp);
+
+} // namespace
+} // namespace tensorflow