diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/diag_op.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/diag_op.cc | 142 |
1 files changed, 20 insertions, 122 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 17bf0c069c..22cda27567 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" @@ -26,10 +28,10 @@ namespace tensorflow { namespace { // Create a diagonal / batch diagonal matrix with 'input' on the diagonal. -xla::StatusOr<xla::XlaOp> CreateDiagonal( - const xla::XlaOp& input, int64 last_dim_size, - tensorflow::gtl::ArraySlice<int64> other_dims, XlaOpKernelContext* ctx, - xla::XlaBuilder* builder) { +xla::XlaOp CreateDiagonal(xla::XlaOp input, int64 last_dim_size, + gtl::ArraySlice<int64> other_dims, + xla::PrimitiveType element_type) { + xla::XlaBuilder* builder = input.builder(); // Create two matrices that have the following forms, and compare them: // // [[0, 0, 0, 0] [[0, 1, 2, 3] @@ -39,9 +41,7 @@ xla::StatusOr<xla::XlaOp> CreateDiagonal( // // This produces a predicate matrix of the right size, with "true" on the // diagonal. - xla::XlaOp iota; - TF_RETURN_IF_ERROR( - XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota)); + xla::XlaOp iota = xla::Iota(builder, xla::S32, last_dim_size); xla::XlaOp iota_broadcast = xla::Broadcast(iota, {last_dim_size}); xla::XlaOp mask = xla::Eq(iota_broadcast, iota, {0}); @@ -68,12 +68,9 @@ xla::StatusOr<xla::XlaOp> CreateDiagonal( xla::XlaOp input_broadcast = xla::Reshape(input, broadcast_dims); broadcast_dims[broadcast_dims.size() - 2] = last_dim_size; - xla::PrimitiveType element_type; - TF_RETURN_IF_ERROR( - DataTypeToPrimitiveType(ctx->input_type(0), &element_type)); auto broadcast_shape = xla::ShapeUtil::MakeShape(element_type, broadcast_dims); - xla::XlaOp zeros = Zeros(builder, broadcast_shape); + xla::XlaOp zeros = xla::Zeros(builder, broadcast_shape); input_broadcast = xla::Add(input_broadcast, zeros); return xla::Select(mask, input_broadcast, zeros); @@ -84,8 +81,6 @@ class DiagOp : public XlaOpKernel { explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* builder = ctx->builder(); - OP_REQUIRES(ctx, ctx->num_inputs() >= 1, errors::InvalidArgument("Diag op must have at an input")); const TensorShape input_shape = ctx->InputShape(0); @@ -108,10 +103,8 @@ class DiagOp : public XlaOpKernel { input = xla::Reshape(input, {size}); // Create an R2 with the R1 diagonal. - auto diag_or_status = - CreateDiagonal(input, size, /*other_dims=*/{}, ctx, builder); - OP_REQUIRES_OK(ctx, diag_or_status.status()); - xla::XlaOp diag = diag_or_status.ValueOrDie(); + xla::XlaOp diag = + CreateDiagonal(input, size, /*other_dims=*/{}, ctx->input_xla_type(0)); // Reshapes to the final shape. std::vector<int64> new_dims(dims.size() * 2); @@ -130,8 +123,6 @@ class DiagPartOp : public XlaOpKernel { explicit DiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* builder = ctx->builder(); - const TensorShape input_shape = ctx->InputShape(0); auto dims = input_shape.dim_sizes(); @@ -157,37 +148,13 @@ class DiagPartOp : public XlaOpKernel { new_dims.push_back(dims[i]); } - xla::XlaOp diag = ctx->Input(0); - - // TODO(b/30878775): use Slice with strides when supported, in place of - // the Pad -> Reshape -> Slice. - - // Picture: - // [[1, 0, 0, 0] pad and reshape to [[1, 0, 0, 0, 0], - // [0, 2, 0, 0] =================> [2, 0, 0, 0, 0], - // [0, 0, 3, 0] [3, 0, 0, 0, 0], - // [0, 0, 0, 4]] [4, 0, 0, 0, 0]] - // and then slice out the first column. - - // Flattens the input to 1D. - int64 size = input_shape.num_elements(); - diag = xla::Reshape(diag, {size}); - - // Adds padding after the last element of 'new_size'. - xla::PaddingConfig config; - auto* dim = config.add_dimensions(); - dim->set_edge_padding_high(new_size); - auto zero = XlaHelpers::Zero(builder, input_type(0)); - diag = xla::Pad(diag, zero, config); - - // Reshapes so the diagonal is now in the first column. - diag = xla::Reshape(diag, {new_size, new_size + 1}); + xla::XlaOp input = ctx->Input(0); - // Slices out the first column and reshapes to the final shape. - diag = xla::Slice(diag, {0, 0}, {new_size, 1}, {1, 1}); - diag = xla::Reshape(diag, new_dims); + xla::XlaOp output = xla::Reshape( + xla::GetMatrixDiagonal(xla::Reshape(input, {new_size, new_size})), + new_dims); - ctx->SetOutput(0, diag); + ctx->SetOutput(0, output); } }; @@ -198,8 +165,6 @@ class MatrixDiagOp : public XlaOpKernel { explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* builder = ctx->builder(); - OP_REQUIRES(ctx, ctx->num_inputs() >= 1, errors::InvalidArgument("MatrixDiag op must have at an input")); const TensorShape input_shape = ctx->InputShape(0); @@ -209,17 +174,15 @@ class MatrixDiagOp : public XlaOpKernel { errors::InvalidArgument("Expected 1 <= dims, got shape ", input_shape.DebugString())); - xla::XlaOp diag = ctx->Input(0); int last_dim = dims.size() - 1; int64 last_dim_size = input_shape.dim_size(last_dim); tensorflow::gtl::ArraySlice<int64> other_dims(dims); other_dims.pop_back(); - auto diag_or_status = - CreateDiagonal(diag, last_dim_size, other_dims, ctx, builder); - OP_REQUIRES_OK(ctx, diag_or_status.status()); - diag = diag_or_status.ValueOrDie(); + xla::XlaOp input = ctx->Input(0); + xla::XlaOp diag = CreateDiagonal(input, last_dim_size, other_dims, + ctx->input_xla_type(0)); ctx->SetOutput(0, diag); } }; @@ -231,8 +194,6 @@ class MatrixDiagPartOp : public XlaOpKernel { explicit MatrixDiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* builder = ctx->builder(); - const TensorShape input_shape = ctx->InputShape(0); auto dims = input_shape.dim_sizes(); @@ -240,71 +201,8 @@ class MatrixDiagPartOp : public XlaOpKernel { errors::InvalidArgument("Expected 2 <= dims, got shape ", input_shape.DebugString())); - xla::XlaOp diag = ctx->Input(0); - - int last_dim = dims.size() - 1; - int64 last_dim_size = dims[last_dim]; - - // The smaller of the last two dimension sizes. - int64 smaller_dim_size = std::min(dims[last_dim - 1], dims[last_dim]); - - // TODO(b/30878775): use Slice with strides when supported, in place of - // the Pad -> Reshape -> Slice. - - // Picture: for each 2D matrix in the tensor's last two dimensions: - // [[1, 0, 0, 0] pad and reshape to [[1, 0, 0, 0, 0], - // [0, 2, 0, 0] =================> [2, 0, 0, 0, 0], - // [0, 0, 3, 0]] [3, 0, 0, 0, 0], - // and then slice out the first column. - // - // Another example, with tall and narrow input. - // [[1, 0] pad and reshape to [[1, 0, 0], - // [0, 2] =================> [2, 0, 0]] - // [0, 0] - // [0, 0]] - - // Collapses the last two dimensions. - std::vector<int64> flattened_dims(dims.begin(), dims.end() - 1); - flattened_dims.back() *= dims.back(); - diag = xla::Reshape(diag, flattened_dims); - - // Slices or pads the last dimension to 'target_size'. - int64 actual_size = flattened_dims.back(); - int64 target_size = smaller_dim_size * (last_dim_size + 1); - if (actual_size < target_size) { - xla::PaddingConfig config = - xla::MakeNoPaddingConfig(flattened_dims.size()); - auto* dim = config.mutable_dimensions(flattened_dims.size() - 1); - dim->set_edge_padding_high(target_size - actual_size); - auto zero = XlaHelpers::Zero(builder, input_type(0)); - diag = xla::Pad(diag, zero, config); - } else if (actual_size > target_size) { - std::vector<int64> start(flattened_dims.size(), 0); - std::vector<int64> limits(flattened_dims.begin(), flattened_dims.end()); - std::vector<int64> strides(flattened_dims.size(), 1); - limits[flattened_dims.size() - 1] = target_size; - diag = xla::Slice(diag, start, limits, strides); - } - - // Reshape so the target values are in the first position of the last - // dimension. - std::vector<int64> unflattened_dims(dims.begin(), dims.end()); - dims[last_dim - 1] = smaller_dim_size; - dims[last_dim] = last_dim_size + 1; - diag = xla::Reshape(diag, dims); - - // Slices out the first column and reshapes to the final shape. - std::vector<int64> start(dims.size(), 0); - std::vector<int64> limits(dims.begin(), dims.end()); - std::vector<int64> strides(dims.size(), 1); - limits[last_dim] = 1; - diag = xla::Slice(diag, start, limits, strides); - - // Collapses away the last dimension. - dims.pop_back(); - diag = xla::Reshape(diag, dims); - - ctx->SetOutput(0, diag); + xla::XlaOp input = ctx->Input(0); + ctx->SetOutput(0, xla::GetMatrixDiagonal(input)); } }; |