diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/split_op.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/split_op.cc | 208 |
1 files changed, 208 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc new file mode 100644 index 0000000000..18c4c648db --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -0,0 +1,208 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific Ops for split. + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { +namespace { + +class SplitOp : public XlaOpKernel { + public: + explicit SplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape index_shape = ctx->InputShape(0); + xla::Literal literal_index; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal_index)); + + int32 split_dim; + if (index_shape.dims() == 0) { + split_dim = xla::LiteralUtil::Get<int>(literal_index, {}); + } else { + OP_REQUIRES( + ctx, index_shape.dims() == 1, + errors::InvalidArgument("split_index input to Split Op must be a " + "scalar or a vector with 1 element")); + OP_REQUIRES( + ctx, index_shape.dim_size(0) == 1, + errors::InvalidArgument("split_index input to Split Op must be a " + "scalar or a vector with 1 element")); + split_dim = xla::LiteralUtil::Get<int>(literal_index, {0}); + } + const int32 num_split = num_outputs(); + const TensorShape input_shape = ctx->InputShape(1); + + OP_REQUIRES( + ctx, 0 <= split_dim && split_dim < input_shape.dims(), + errors::InvalidArgument("0 <= split_dim < number of input dimensions (", + input_shape.dims(), "), but got ", split_dim)); + + OP_REQUIRES( + ctx, num_split > 0, + errors::InvalidArgument( + "Number of ways to split should be > 0, but got ", num_split)); + + OP_REQUIRES(ctx, input_shape.dim_size(split_dim) % num_split == 0, + errors::InvalidArgument( + "Number of ways to split should evenly divide the split " + "dimension, but got split_dim ", + split_dim, " (size = ", input_shape.dim_size(split_dim), + ") ", "and num_split ", num_split)); + + // All the slices are the same size: this is the size along the + // split dimension. + const int32 slice_size = input_shape.dim_size(split_dim) / num_split; + + // The vectors we will use to define the slice. The entry for the + // split dimensions varies for each output. + std::vector<int64> begin; + std::vector<int64> limits; + for (int i = 0; i < input_shape.dims(); ++i) { + // Initially set up the limits to be the full size of the input: + // the split dimension is filled in below. + int64 dim = input_shape.dim_size(i); + begin.push_back(0); + limits.push_back(dim); + } + + auto input = ctx->Input(1); + + // Create each of the outputs. + for (int i = 0; i < num_split; ++i) { + // Slice out the ith split from the split dimension. + begin[split_dim] = i * slice_size; + limits[split_dim] = (i + 1) * slice_size; + ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits)); + } + } +}; + +REGISTER_XLA_OP("Split", SplitOp); + +class SplitVOp : public XlaOpKernel { + public: + explicit SplitVOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const int32 num_split = num_outputs(); + const TensorShape index_shape = ctx->InputShape(2); + xla::Literal literal_index; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &literal_index)); + + int32 split_dim; + OP_REQUIRES(ctx, index_shape.dims() == 0, + errors::InvalidArgument("split_dim input to Split Op must be a " + "scalar")); + split_dim = xla::LiteralUtil::Get<int>(literal_index, {}); + + xla::ComputationDataHandle input = ctx->Input(0); + const TensorShape input_shape = ctx->InputShape(0); + + OP_REQUIRES(ctx, input_shape.dims() > 0, + errors::InvalidArgument("Can't split a 0 dimensional input")); + + OP_REQUIRES( + ctx, 0 <= split_dim && split_dim < input_shape.dims(), + errors::InvalidArgument("0 <= split_dim < number of input dimensions (", + input_shape.dims(), "), but got ", split_dim)); + + OP_REQUIRES( + ctx, num_split > 0, + errors::InvalidArgument( + "Number of ways to split should be > 0, but got ", num_split)); + + // check that sizes are correct + int total_split_size = 0; + int neg_one_dim = -1; + std::vector<int64> split_sizes_vec(num_split, -1); + const TensorShape split_size_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, split_size_shape.dims() == 1 && + split_size_shape.num_elements() == num_split, + errors::InvalidArgument( + "shape of tensor describing " + " the output must have dimension 1 and the same " + " number of elements as the output. Got ", + split_size_shape.dims(), "-D and ", + split_size_shape.num_elements(), " elements")); + // get the dimension of this split + xla::Literal split_size_literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &split_size_literal)); + + for (int i = 0; i < num_split; ++i) { + int slice_size; + slice_size = xla::LiteralUtil::Get<int>(split_size_literal, {i}); + if (slice_size == -1) { + OP_REQUIRES( + ctx, neg_one_dim == -1, + errors::InvalidArgument("Only one dimensions can have a value of" + "-1. Second one found at dimension ", + i)); + neg_one_dim = i; + } else { + split_sizes_vec[i] = slice_size; + total_split_size += slice_size; + } + } + + OP_REQUIRES( + ctx, (neg_one_dim == -1 && + total_split_size == input_shape.dim_size(split_dim)) || + (neg_one_dim >= 0 && + total_split_size <= input_shape.dim_size(split_dim)), + errors::InvalidArgument("Determined shape must either match " + "input shape along split_dim exactly if " + "fully specified, or be less than the size of " + "the input along split_dim if not fully " + "specified. Got: ", + total_split_size)); + + if (neg_one_dim >= 0) { + split_sizes_vec[neg_one_dim] = + input_shape.dim_size(split_dim) - total_split_size; + } + + // The vectors we will use to define the slice. The entry for the + // split dimensions varies for each output. + std::vector<int64> begin(input_shape.dims(), 0); + auto dim_sizes = input_shape.dim_sizes(); + std::vector<int64> limits(dim_sizes.begin(), dim_sizes.end()); + + for (int i = 0; i < num_split; ++i) { + TensorShape output_shape(input_shape); + int slice_size = split_sizes_vec[i]; + output_shape.set_dim(split_dim, slice_size); + + // Slice out the ith split from the split dimension. + limits[split_dim] = begin[split_dim] + slice_size; + ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits)); + begin[split_dim] = limits[split_dim]; + } + } +}; + +REGISTER_XLA_OP("SplitV", SplitVOp); + +} // namespace +} // namespace tensorflow |