Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/diag_op.cc')
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/diag_op.cc  142
1 file changed, 20 insertions(+), 122 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
index 17bf0c069c..22cda27567 100644
--- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -26,10 +28,10 @@ namespace tensorflow {
namespace {
// Create a diagonal / batch diagonal matrix with 'input' on the diagonal.
-xla::StatusOr<xla::XlaOp> CreateDiagonal(
- const xla::XlaOp& input, int64 last_dim_size,
- tensorflow::gtl::ArraySlice<int64> other_dims, XlaOpKernelContext* ctx,
- xla::XlaBuilder* builder) {
+xla::XlaOp CreateDiagonal(xla::XlaOp input, int64 last_dim_size,
+ gtl::ArraySlice<int64> other_dims,
+ xla::PrimitiveType element_type) {
+ xla::XlaBuilder* builder = input.builder();
// Create two matrices that have the following forms, and compare them:
//
// [[0, 0, 0, 0] [[0, 1, 2, 3]
@@ -39,9 +41,7 @@ xla::StatusOr<xla::XlaOp> CreateDiagonal(
//
// This produces a predicate matrix of the right size, with "true" on the
// diagonal.
- xla::XlaOp iota;
- TF_RETURN_IF_ERROR(
- XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota));
+ xla::XlaOp iota = xla::Iota(builder, xla::S32, last_dim_size);
xla::XlaOp iota_broadcast = xla::Broadcast(iota, {last_dim_size});
xla::XlaOp mask = xla::Eq(iota_broadcast, iota, {0});
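The new xla::Iota call builds the same comparison mask the deleted XlaHelpers::Iota path produced. As a rough standalone illustration of that mask trick (plain C++ loops standing in for the device-side Iota, Broadcast, and Eq ops; n = 4 is an assumed example size):

#include <cstdio>
#include <vector>

int main() {
  const int n = 4;
  std::vector<int> iota(n);  // [0, 1, 2, 3]
  for (int i = 0; i < n; ++i) iota[i] = i;
  // Broadcasting iota across rows gives [[0,1,2,3], ...]; comparing that
  // against iota broadcast down columns (broadcast dim {0}) is true only
  // where row index == column index, i.e. on the diagonal.
  for (int r = 0; r < n; ++r) {
    for (int c = 0; c < n; ++c)
      std::printf("%d ", iota[r] == iota[c] ? 1 : 0);
    std::printf("\n");
  }
}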
@@ -68,12 +68,9 @@ xla::StatusOr<xla::XlaOp> CreateDiagonal(
xla::XlaOp input_broadcast = xla::Reshape(input, broadcast_dims);
broadcast_dims[broadcast_dims.size() - 2] = last_dim_size;
- xla::PrimitiveType element_type;
- TF_RETURN_IF_ERROR(
- DataTypeToPrimitiveType(ctx->input_type(0), &element_type));
auto broadcast_shape =
xla::ShapeUtil::MakeShape(element_type, broadcast_dims);
- xla::XlaOp zeros = Zeros(builder, broadcast_shape);
+ xla::XlaOp zeros = xla::Zeros(builder, broadcast_shape);
input_broadcast = xla::Add(input_broadcast, zeros);
return xla::Select(mask, input_broadcast, zeros);
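The Add with zeros is a broadcast in disguise: it expands the reshaped input to the full output shape so Select can pick input values on the diagonal and zeros elsewhere. A minimal sketch of that dataflow, assuming a rank-1 input of length 4:

#include <cstdio>
#include <vector>

int main() {
  const int n = 4;
  std::vector<float> input = {1, 2, 3, 4};
  // After Reshape + Add(zeros), input_broadcast[r][c] == input[c];
  // Select(mask, input_broadcast, zeros) keeps it only where r == c.
  for (int r = 0; r < n; ++r) {
    for (int c = 0; c < n; ++c)
      std::printf("%g ", r == c ? input[c] : 0.0f);
    std::printf("\n");
  }
}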
@@ -84,8 +81,6 @@ class DiagOp : public XlaOpKernel {
explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* builder = ctx->builder();
-
OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
errors::InvalidArgument("Diag op must have at an input"));
const TensorShape input_shape = ctx->InputShape(0);
@@ -108,10 +103,8 @@ class DiagOp : public XlaOpKernel {
input = xla::Reshape(input, {size});
// Create an R2 with the R1 diagonal.
- auto diag_or_status =
- CreateDiagonal(input, size, /*other_dims=*/{}, ctx, builder);
- OP_REQUIRES_OK(ctx, diag_or_status.status());
- xla::XlaOp diag = diag_or_status.ValueOrDie();
+ xla::XlaOp diag =
+ CreateDiagonal(input, size, /*other_dims=*/{}, ctx->input_xla_type(0));
// Reshapes to the final shape.
std::vector<int64> new_dims(dims.size() * 2);
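Diag's result shape is the input shape repeated twice, e.g. a [2, 3] input yields a [2, 3, 2, 3] output. A small sketch of how new_dims is presumably filled from dims (the loop body below is an assumption; the hunk above only shows the allocation):

#include <cstdio>
#include <vector>

int main() {
  std::vector<long long> dims = {2, 3};  // hypothetical input shape
  std::vector<long long> new_dims(dims.size() * 2);
  for (size_t i = 0; i < dims.size(); ++i) {
    new_dims[i] = dims[i];                // leading copy of the shape
    new_dims[i + dims.size()] = dims[i];  // trailing copy of the shape
  }
  for (long long d : new_dims) std::printf("%lld ", d);  // 2 3 2 3
  std::printf("\n");
}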
@@ -130,8 +123,6 @@ class DiagPartOp : public XlaOpKernel {
explicit DiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* builder = ctx->builder();
-
const TensorShape input_shape = ctx->InputShape(0);
auto dims = input_shape.dim_sizes();
@@ -157,37 +148,13 @@ class DiagPartOp : public XlaOpKernel {
new_dims.push_back(dims[i]);
}
- xla::XlaOp diag = ctx->Input(0);
-
- // TODO(b/30878775): use Slice with strides when supported, in place of
- // the Pad -> Reshape -> Slice.
-
- // Picture:
- // [[1, 0, 0, 0] pad and reshape to [[1, 0, 0, 0, 0],
- // [0, 2, 0, 0] =================> [2, 0, 0, 0, 0],
- // [0, 0, 3, 0] [3, 0, 0, 0, 0],
- // [0, 0, 0, 4]] [4, 0, 0, 0, 0]]
- // and then slice out the first column.
-
- // Flattens the input to 1D.
- int64 size = input_shape.num_elements();
- diag = xla::Reshape(diag, {size});
-
- // Adds padding after the last element of 'new_size'.
- xla::PaddingConfig config;
- auto* dim = config.add_dimensions();
- dim->set_edge_padding_high(new_size);
- auto zero = XlaHelpers::Zero(builder, input_type(0));
- diag = xla::Pad(diag, zero, config);
-
- // Reshapes so the diagonal is now in the first column.
- diag = xla::Reshape(diag, {new_size, new_size + 1});
+ xla::XlaOp input = ctx->Input(0);
- // Slices out the first column and reshapes to the final shape.
- diag = xla::Slice(diag, {0, 0}, {new_size, 1}, {1, 1});
- diag = xla::Reshape(diag, new_dims);
+ xla::XlaOp output = xla::Reshape(
+ xla::GetMatrixDiagonal(xla::Reshape(input, {new_size, new_size})),
+ new_dims);
- ctx->SetOutput(0, diag);
+ ctx->SetOutput(0, output);
}
};
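The deleted pad -> reshape -> slice trick is still worth a note, since it explains why GetMatrixDiagonal is a safe drop-in: flattening an n x n matrix, padding n trailing zeros, and viewing the buffer as n x (n+1) puts element (i, i), at flat index i*(n+1), in column 0 of every row. A plain-C++ sketch of that index arithmetic (not the XLA calls), n = 4 assumed:

#include <cstdio>
#include <vector>

int main() {
  const int n = 4;
  std::vector<float> flat(n * n + n, 0.0f);  // n*n elements + n pad zeros
  for (int i = 0; i < n; ++i) flat[i * n + i] = i + 1;  // diagonal 1..4
  // Element (i, i) sits at flat index i*n + i == i*(n + 1), which is
  // exactly row i, column 0 of the padded n x (n+1) view.
  for (int i = 0; i < n; ++i) std::printf("%g ", flat[i * (n + 1)]);
  std::printf("\n");  // prints 1 2 3 4
}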
@@ -198,8 +165,6 @@ class MatrixDiagOp : public XlaOpKernel {
explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* builder = ctx->builder();
-
OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
errors::InvalidArgument("MatrixDiag op must have at an input"));
const TensorShape input_shape = ctx->InputShape(0);
@@ -209,17 +174,15 @@ class MatrixDiagOp : public XlaOpKernel {
errors::InvalidArgument("Expected 1 <= dims, got shape ",
input_shape.DebugString()));
- xla::XlaOp diag = ctx->Input(0);
int last_dim = dims.size() - 1;
int64 last_dim_size = input_shape.dim_size(last_dim);
tensorflow::gtl::ArraySlice<int64> other_dims(dims);
other_dims.pop_back();
- auto diag_or_status =
- CreateDiagonal(diag, last_dim_size, other_dims, ctx, builder);
- OP_REQUIRES_OK(ctx, diag_or_status.status());
- diag = diag_or_status.ValueOrDie();
+ xla::XlaOp input = ctx->Input(0);
+ xla::XlaOp diag = CreateDiagonal(input, last_dim_size, other_dims,
+ ctx->input_xla_type(0));
ctx->SetOutput(0, diag);
}
};
@@ -231,8 +194,6 @@ class MatrixDiagPartOp : public XlaOpKernel {
explicit MatrixDiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* builder = ctx->builder();
-
const TensorShape input_shape = ctx->InputShape(0);
auto dims = input_shape.dim_sizes();
@@ -240,71 +201,8 @@ class MatrixDiagPartOp : public XlaOpKernel {
errors::InvalidArgument("Expected 2 <= dims, got shape ",
input_shape.DebugString()));
- xla::XlaOp diag = ctx->Input(0);
-
- int last_dim = dims.size() - 1;
- int64 last_dim_size = dims[last_dim];
-
- // The smaller of the last two dimension sizes.
- int64 smaller_dim_size = std::min(dims[last_dim - 1], dims[last_dim]);
-
- // TODO(b/30878775): use Slice with strides when supported, in place of
- // the Pad -> Reshape -> Slice.
-
- // Picture: for each 2D matrix in the tensor's last two dimensions:
- // [[1, 0, 0, 0] pad and reshape to [[1, 0, 0, 0, 0],
- // [0, 2, 0, 0] =================> [2, 0, 0, 0, 0],
- // [0, 0, 3, 0]] [3, 0, 0, 0, 0],
- // and then slice out the first column.
- //
- // Another example, with tall and narrow input.
- // [[1, 0] pad and reshape to [[1, 0, 0],
- // [0, 2] =================> [2, 0, 0]]
- // [0, 0]
- // [0, 0]]
-
- // Collapses the last two dimensions.
- std::vector<int64> flattened_dims(dims.begin(), dims.end() - 1);
- flattened_dims.back() *= dims.back();
- diag = xla::Reshape(diag, flattened_dims);
-
- // Slices or pads the last dimension to 'target_size'.
- int64 actual_size = flattened_dims.back();
- int64 target_size = smaller_dim_size * (last_dim_size + 1);
- if (actual_size < target_size) {
- xla::PaddingConfig config =
- xla::MakeNoPaddingConfig(flattened_dims.size());
- auto* dim = config.mutable_dimensions(flattened_dims.size() - 1);
- dim->set_edge_padding_high(target_size - actual_size);
- auto zero = XlaHelpers::Zero(builder, input_type(0));
- diag = xla::Pad(diag, zero, config);
- } else if (actual_size > target_size) {
- std::vector<int64> start(flattened_dims.size(), 0);
- std::vector<int64> limits(flattened_dims.begin(), flattened_dims.end());
- std::vector<int64> strides(flattened_dims.size(), 1);
- limits[flattened_dims.size() - 1] = target_size;
- diag = xla::Slice(diag, start, limits, strides);
- }
-
- // Reshape so the target values are in the first position of the last
- // dimension.
- std::vector<int64> unflattened_dims(dims.begin(), dims.end());
- dims[last_dim - 1] = smaller_dim_size;
- dims[last_dim] = last_dim_size + 1;
- diag = xla::Reshape(diag, dims);
-
- // Slices out the first column and reshapes to the final shape.
- std::vector<int64> start(dims.size(), 0);
- std::vector<int64> limits(dims.begin(), dims.end());
- std::vector<int64> strides(dims.size(), 1);
- limits[last_dim] = 1;
- diag = xla::Slice(diag, start, limits, strides);
-
- // Collapses away the last dimension.
- dims.pop_back();
- diag = xla::Reshape(diag, dims);
-
- ctx->SetOutput(0, diag);
+ xla::XlaOp input = ctx->Input(0);
+ ctx->SetOutput(0, xla::GetMatrixDiagonal(input));
}
};
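For the non-square case, the deleted code relied on target_size = smaller_dim_size * (last_dim_size + 1): consecutive diagonal elements of the flattened matrix sit last_dim_size + 1 apart, so padding or slicing the flat buffer to target_size and viewing it as {smaller_dim_size, last_dim_size + 1} again lands the diagonal in column 0. A hedged sketch using the tall-and-narrow 4 x 2 picture from the deleted comment:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  const long long rows = 4, cols = 2;              // tall and narrow
  const long long smaller = std::min(rows, cols);  // 2
  std::vector<float> flat(rows * cols, 0.0f);
  for (long long i = 0; i < smaller; ++i) flat[i * cols + i] = i + 1;
  // actual = 8 > target = smaller * (cols + 1) = 6, so slice down (a
  // short-and-wide matrix would instead pad up to target with zeros).
  flat.resize(smaller * (cols + 1), 0.0f);
  // In the {smaller, cols + 1} view, the diagonal is column 0.
  for (long long i = 0; i < smaller; ++i)
    std::printf("%g ", flat[i * (cols + 1)]);  // prints 1 2
  std::printf("\n");
}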