diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/diag_op.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/diag_op.cc | 255 |
1 files changed, 255 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc new file mode 100644 index 0000000000..b89109ff6a --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -0,0 +1,255 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class DiagOp : public XlaOpKernel { + public: + explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* builder = ctx->builder(); + + const TensorShape input_shape = ctx->InputShape(0); + + auto dims = input_shape.dim_sizes(); + OP_REQUIRES(ctx, !dims.empty(), + errors::InvalidArgument("Expected 1 <= dims, got shape ", + input_shape.DebugString())); + + xla::ComputationDataHandle diag = ctx->Input(0); + + // Picture: + // tf.diag([1, 2, 3, 4]) ==> [[1, 0, 0, 0] + // [0, 2, 0, 0] + // [0, 0, 3, 0] + // [0, 0, 0, 4]] + + // Flattens the input to 1D. + int64 size = input_shape.num_elements(); + diag = builder->Reshape(diag, {size}); + + // Adds inter-element padding of 'size'. + xla::PaddingConfig config; + auto* dim = config.add_dimensions(); + dim->set_interior_padding(size); + diag = builder->Pad(diag, XlaHelpers::Zero(builder, input_type(0)), config); + + // Reshapes to the final shape. + std::vector<int64> new_dims(dims.size() * 2); + std::copy(dims.begin(), dims.end(), new_dims.begin()); + std::copy(dims.begin(), dims.end(), new_dims.begin() + dims.size()); + diag = builder->Reshape(diag, new_dims); + + ctx->SetOutput(0, diag); + } +}; + +REGISTER_XLA_OP("Diag", DiagOp); + +class DiagPartOp : public XlaOpKernel { + public: + explicit DiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* builder = ctx->builder(); + + const TensorShape input_shape = ctx->InputShape(0); + auto dims = input_shape.dim_sizes(); + + int num_dims = dims.size(); + const int out_dims = num_dims / 2; + + OP_REQUIRES(ctx, 2 <= num_dims, + errors::InvalidArgument("Expected 2 <= dims, got shape ", + input_shape.DebugString())); + OP_REQUIRES(ctx, num_dims % 2 == 0, + errors::InvalidArgument("The input tensor must have even rank; " + "got shape ", + input_shape.DebugString())); + int64 new_size = 1; + std::vector<int64> new_dims; + for (int i = 0; i < out_dims; i++) { + OP_REQUIRES( + ctx, dims[i] == dims[i + out_dims], + errors::InvalidArgument("Invalid shape ", input_shape.DebugString(), + ": dimensions ", i, " and ", i + out_dims, + " do not match.")); + new_size *= dims[i]; + new_dims.push_back(dims[i]); + } + + xla::ComputationDataHandle diag = ctx->Input(0); + + // TODO(b/30878775): use Slice with strides when supported, in place of + // the Pad -> Reshape -> Slice. + + // Picture: + // [[1, 0, 0, 0] pad and reshape to [[1, 0, 0, 0, 0], + // [0, 2, 0, 0] =================> [2, 0, 0, 0, 0], + // [0, 0, 3, 0] [3, 0, 0, 0, 0], + // [0, 0, 0, 4]] [4, 0, 0, 0, 0]] + // and then slice out the first column. + + // Flattens the input to 1D. + int64 size = input_shape.num_elements(); + diag = builder->Reshape(diag, {size}); + + // Adds padding after the last element of 'new_size'. + xla::PaddingConfig config; + auto* dim = config.add_dimensions(); + dim->set_edge_padding_high(new_size); + auto zero = XlaHelpers::Zero(builder, input_type(0)); + diag = builder->Pad(diag, zero, config); + + // Reshapes so the diagonal is now in the first column. + diag = builder->Reshape(diag, {new_size, new_size + 1}); + + // Slices out the first column and reshapes to the final shape. + diag = builder->Slice(diag, {0, 0}, {new_size, 1}); + diag = builder->Reshape(diag, new_dims); + + ctx->SetOutput(0, diag); + } +}; + +REGISTER_XLA_OP("DiagPart", DiagPartOp); + +class MatrixDiagOp : public XlaOpKernel { + public: + explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* builder = ctx->builder(); + + const TensorShape input_shape = ctx->InputShape(0); + + auto dims = input_shape.dim_sizes(); + OP_REQUIRES(ctx, !dims.empty(), + errors::InvalidArgument("Expected 1 <= dims, got shape ", + input_shape.DebugString())); + + xla::ComputationDataHandle diag = ctx->Input(0); + + int last_dim = dims.size() - 1; + int64 last_dim_size = input_shape.dim_size(last_dim); + + // Adds inter-element padding of 'last_dim_size' to the last dimension. + xla::PaddingConfig config = xla::MakeNoPaddingConfig(dims.size()); + auto* dim = config.mutable_dimensions(last_dim); + dim->set_interior_padding(last_dim_size); + diag = builder->Pad(diag, XlaHelpers::Zero(builder, input_type(0)), config); + + // Reshapes to the final shape. + dims.push_back(last_dim_size); + diag = builder->Reshape(diag, dims); + + ctx->SetOutput(0, diag); + } +}; + +REGISTER_XLA_OP("MatrixDiag", MatrixDiagOp); + +class MatrixDiagPartOp : public XlaOpKernel { + public: + explicit MatrixDiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* builder = ctx->builder(); + + const TensorShape input_shape = ctx->InputShape(0); + auto dims = input_shape.dim_sizes(); + + OP_REQUIRES(ctx, 2 <= dims.size(), + errors::InvalidArgument("Expected 2 <= dims, got shape ", + input_shape.DebugString())); + + xla::ComputationDataHandle diag = ctx->Input(0); + + int last_dim = dims.size() - 1; + int64 last_dim_size = dims[last_dim]; + + // The smaller of the last two dimension sizes. + int64 smaller_dim_size = std::min(dims[last_dim - 1], dims[last_dim]); + + // TODO(b/30878775): use Slice with strides when supported, in place of + // the Pad -> Reshape -> Slice. + + // Picture: for each 2D matrix in the tensor's last two dimensions: + // [[1, 0, 0, 0] pad and reshape to [[1, 0, 0, 0, 0], + // [0, 2, 0, 0] =================> [2, 0, 0, 0, 0], + // [0, 0, 3, 0]] [3, 0, 0, 0, 0], + // and then slice out the first column. + // + // Another example, with tall and narrow input. + // [[1, 0] pad and reshape to [[1, 0, 0], + // [0, 2] =================> [2, 0, 0]] + // [0, 0] + // [0, 0]] + + // Collapses the last two dimensions. + std::vector<int64> flattened_dims(dims.begin(), dims.end() - 1); + flattened_dims.back() *= dims.back(); + diag = builder->Reshape(diag, flattened_dims); + + // Slices or pads the last dimension to 'target_size'. + int64 actual_size = flattened_dims.back(); + int64 target_size = smaller_dim_size * (last_dim_size + 1); + if (actual_size < target_size) { + xla::PaddingConfig config = + xla::MakeNoPaddingConfig(flattened_dims.size()); + auto* dim = config.mutable_dimensions(flattened_dims.size() - 1); + dim->set_edge_padding_high(target_size - actual_size); + auto zero = XlaHelpers::Zero(builder, input_type(0)); + diag = builder->Pad(diag, zero, config); + } else if (actual_size > target_size) { + std::vector<int64> start(flattened_dims.size(), 0); + std::vector<int64> limits(flattened_dims.begin(), flattened_dims.end()); + limits[flattened_dims.size() - 1] = target_size; + diag = builder->Slice(diag, start, limits); + } + + // Reshape so the target values are in the first position of the last + // dimension. + std::vector<int64> unflattened_dims(dims.begin(), dims.end()); + dims[last_dim - 1] = smaller_dim_size; + dims[last_dim] = last_dim_size + 1; + diag = builder->Reshape(diag, dims); + + // Slices out the first column and reshapes to the final shape. + std::vector<int64> start(dims.size(), 0); + std::vector<int64> limits(dims.begin(), dims.end()); + limits[last_dim] = 1; + diag = builder->Slice(diag, start, limits); + + // Collapses away the last dimension. + dims.pop_back(); + diag = builder->Reshape(diag, dims); + + ctx->SetOutput(0, diag); + } +}; + +REGISTER_XLA_OP("MatrixDiagPart", MatrixDiagPartOp); + +} // namespace +} // namespace tensorflow |