aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/lib/triangular_solve.cc')
-rw-r--r--tensorflow/compiler/tf2xla/lib/triangular_solve.cc936
1 files changed, 344 insertions, 592 deletions
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
index b9f695ac4b..05dad759df 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
@@ -20,631 +20,383 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/math/math_util.h"
namespace tensorflow {
-xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
- const xla::XlaOp& a, xla::XlaOp b,
- bool left_side, bool lower,
- bool transpose_a, bool conjugate_a,
- int64 block_size) {
- TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
- TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
- if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) {
- return errors::InvalidArgument(
- "Arguments to TriangularSolve have different ranks: ",
- xla::ShapeUtil::HumanString(a_shape), " vs. ",
- xla::ShapeUtil::HumanString(b_shape));
- }
- const int ndims = xla::ShapeUtil::Rank(a_shape);
- if (ndims < 2) {
- return errors::InvalidArgument(
- "Arguments to TriangularSolve must have rank >= 2: ", ndims);
- }
- // The batch dimensions must be equal.
- std::vector<int64> batch_dimensions;
- for (int i = 0; i < ndims - 2; ++i) {
- int64 a_size = a_shape.dimensions(i);
- int64 b_size = b_shape.dimensions(i);
- if (a_size != b_size) {
- return errors::InvalidArgument(
- "Batch dimensions of arguments to TriangularSolve must be equal: ",
- xla::ShapeUtil::HumanString(a_shape), " vs ",
- xla::ShapeUtil::HumanString(b_shape));
+// Get the diagonal blocks of the coefficient matrix
+xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) {
+ xla::XlaBuilder* builder = a.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(a));
+ int ndims = xla::ShapeUtil::Rank(shape);
+ int64 n = xla::ShapeUtil::GetDimension(shape, -1);
+ int64 num_blocks = n / block_size;
+
+ xla::XlaOp diag_blocks;
+
+ // If the coefficient matrix is exactly the block size, we just add a
+ // singleton dimension i.e. [..., n, n] -> [..., 1, n, n]
+ if (n == block_size) {
+ std::vector<int64> permutation(ndims);
+ std::iota(permutation.begin(), permutation.end(), 1);
+ permutation.insert(permutation.end() - 2, 0);
+ return Transpose(Broadcast(a, /*broadcast_sizes=*/{1}), permutation);
}
- batch_dimensions.push_back(a_size);
- }
-
- if (xla::ShapeUtil::GetDimension(a_shape, -1) !=
- xla::ShapeUtil::GetDimension(a_shape, -2)) {
- return errors::InvalidArgument(
- "The 'a' arguments to TriangularSolve must be square matrices: ",
- xla::ShapeUtil::HumanString(a_shape));
- }
- const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2);
- const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1);
- if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) {
- return errors::InvalidArgument(
- "Arguments to TriangularSolve have incompatible matrix shapes: ",
- xla::ShapeUtil::HumanString(a_shape), " vs ",
- xla::ShapeUtil::HumanString(b_shape));
- }
-
- if (block_size < 1) {
- return errors::InvalidArgument(
- "block_size argument to TriangularSolve must be >= 1; got ",
- block_size);
- }
-
- std::map<int, xla::XlaComputation> base_computations;
- auto get_base_triangular_solve =
- [&](int k) -> xla::StatusOr<xla::XlaComputation*> {
- xla::XlaComputation& computation = base_computations[k];
- if (computation.IsNull()) {
- std::unique_ptr<xla::XlaBuilder> sub = builder->CreateSubBuilder(
- tensorflow::strings::StrCat("trsm_base_", k));
-
- auto a_param = xla::Parameter(
- sub.get(), 0,
- xla::ShapeUtil::MakeShape(
- b_shape.element_type(),
- PrependMajorDims(sub.get(), batch_dimensions, {k, k})),
- "a");
-
- std::array<int64, 2> b_lastd;
- if (left_side) {
- b_lastd = {k, n};
- } else {
- b_lastd = {m, k};
- }
- auto b_param = xla::Parameter(
- sub.get(), 1,
- xla::ShapeUtil::MakeShape(
- b_shape.element_type(),
- PrependMajorDims(sub.get(), batch_dimensions, b_lastd)),
- "b");
-
- // We use a left-looking or right-looking subroutine on the block diagonal
- // in the lower=true cases, while falling back to a recursive call in
- // others. The left-looking and right-looking subroutines are written with
- // a While loop and so yields much faster compile times. Moreover, they
- // can give higher performance on smaller (sub)problems.
- if (left_side && lower) {
- TF_RETURN_IF_ERROR(TriangularSolveLeftLooking(sub.get(), a_param,
- b_param, transpose_a,
- conjugate_a)
- .status());
- } else if (!left_side && lower) {
- TF_RETURN_IF_ERROR(TriangularSolveRightLooking(sub.get(), a_param,
- b_param, transpose_a,
- conjugate_a)
- .status());
- } else {
- TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param,
- left_side, lower, transpose_a,
- conjugate_a,
- /*block_size=*/1)
- .status());
- }
- TF_ASSIGN_OR_RETURN(computation, sub->Build());
+ // We can grab entire blocks using gather
+ if (n > block_size) {
+ // Construct the starting indices of the diagonal blocks
+ auto gather_indices =
+ Transpose(Broadcast(Mul(Iota(builder, xla::S32, num_blocks),
+ xla::ConstantR0<int32>(builder, block_size)),
+ /*broadcast_sizes=*/{2}),
+ /*permutation=*/{1, 0});
+
+ // Gather the diagonal blocks
+ xla::GatherDimensionNumbers dim_numbers;
+ dim_numbers.add_output_window_dims(ndims - 1);
+ dim_numbers.add_output_window_dims(ndims);
+ dim_numbers.add_gather_dims_to_operand_dims(ndims - 2);
+ dim_numbers.add_gather_dims_to_operand_dims(ndims - 1);
+ dim_numbers.set_index_vector_dim(1);
+ diag_blocks = Gather(a, gather_indices, dim_numbers,
+ /*window_bounds=*/{block_size, block_size});
}
- return &computation;
- };
-
- xla::XlaOp output = Zeros(builder, b_shape);
-
- // Right-looking blocked triangular solve.
- // For an explanation of the algorithm, see the TRSM discussion in:
- // Goto, Kazushige, and Robert Van De Geijn. "High-performance implementation
- // of the level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1
- // (2008): 4.
-
- // In the code comments below, T = lambda x: np.swapaxes(x, -1, -2) if
- // conjugate_a is False, or T = lambda x: np.conj(np.swapaxes(x, -1, -2)) if
- // conjugate_a is True.
-
- if (!left_side && lower == transpose_a) {
- // for i in range(0, a.shape[-1], block_size):
- for (int64 i = 0; i < n; i += block_size) {
- int64 k = std::min(block_size, n - i);
-
- // output[..., :, i:i+k] = triangular_solve(
- // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1)
- TF_ASSIGN_OR_RETURN(auto a_slice,
- SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
- TF_ASSIGN_OR_RETURN(auto b_slice,
- SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
- xla::XlaOp update;
- if (k > 1) {
- TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
- get_base_triangular_solve(k));
- update = xla::Call(builder, *solve, {a_slice, b_slice});
- } else {
- TF_ASSIGN_OR_RETURN(auto a_slice_conj,
- MaybeConjugate(builder, a_slice, conjugate_a));
- update = xla::Div(b_slice, a_slice_conj);
- }
- TF_ASSIGN_OR_RETURN(
- output, UpdateSliceInMinorDims(builder, output, update, {0, i}));
-
- // if i + k < a.shape[-1]:
- // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:]
- // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
- // b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2)
- if (i + k < n) {
- xla::XlaOp a_slice_2;
- if (lower) {
- TF_ASSIGN_OR_RETURN(
- a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {n, i + k}));
- } else {
- TF_ASSIGN_OR_RETURN(
- a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, n}));
- }
- TF_ASSIGN_OR_RETURN(auto b_update,
- BatchDot(builder, update, a_slice_2,
- /*transpose_x=*/false,
- /*transpose_y=*/transpose_a,
- /*conjugate_x=*/false,
- /*conjugate_y=*/conjugate_a));
- TF_ASSIGN_OR_RETURN(auto b_slice_2,
- SliceInMinorDims(builder, b, {0, i + k}, {m, n}));
- b_update = xla::Sub(b_slice_2, b_update);
- TF_ASSIGN_OR_RETURN(
- b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k}));
+ // The last block might be smaller than the block size,
+ // so we will need to pad it
+ if (n % block_size != 0) {
+ // Pad with zeros
+ auto last_blocks =
+ SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n});
+ xla::PaddingConfig config = xla::MakeNoPaddingConfig(ndims);
+ int64 padding = block_size - n % block_size;
+ config.mutable_dimensions(ndims - 1)->set_edge_padding_high(padding);
+ config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding);
+ last_blocks =
+ Pad(last_blocks, Zero(builder, shape.element_type()), config);
+
+ // Add a singleton dimension
+ // i.e. [..., block_size, block_size] -> [..., 1, block_size, block_size]
+ TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape,
+ builder->GetShape(last_blocks));
+ auto shape_dims = xla::AsInt64Slice(blocks_shape.dimensions());
+ auto last_blocks_dims = std::vector<int64>(ndims);
+ std::copy(shape_dims.begin(), shape_dims.end(), last_blocks_dims.begin());
+ last_blocks_dims.insert(last_blocks_dims.end() - 2, 1);
+ last_blocks = Reshape(last_blocks, last_blocks_dims);
+
+ // Concatenate with the other blocks if necessary
+ if (n > block_size) {
+ diag_blocks =
+ xla::ConcatInDim(builder, {diag_blocks, last_blocks}, ndims - 2);
+ } else {
+ diag_blocks = last_blocks;
}
}
- } else if (left_side && lower != transpose_a) {
- // for i in range(0, a.shape[-1], block_size):
- for (int64 i = 0; i < m; i += block_size) {
- int64 k = std::min(block_size, m - i);
-
- // output[..., i:i+k, :] = triangular_solve(
- // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1)
- TF_ASSIGN_OR_RETURN(auto a_slice,
- SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
- TF_ASSIGN_OR_RETURN(auto b_slice,
- SliceInMinorDims(builder, b, {i, 0}, {i + k, n}));
- xla::XlaOp update;
- if (k > 1) {
- TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
- get_base_triangular_solve(k));
- update = xla::Call(builder, *solve, {a_slice, b_slice});
- } else {
- TF_ASSIGN_OR_RETURN(auto a_slice_conj,
- MaybeConjugate(builder, a_slice, conjugate_a));
- update = xla::Div(b_slice, a_slice_conj);
- }
- TF_ASSIGN_OR_RETURN(
- output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
-
- // if i + k < a.shape[-1]:
- // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:]
- // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
- // b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :])
- if (i + k < m) {
- xla::XlaOp a_slice_2;
- if (lower) {
- TF_ASSIGN_OR_RETURN(
- a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {m, i + k}));
- } else {
- TF_ASSIGN_OR_RETURN(
- a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, m}));
- }
+ return diag_blocks;
+ });
+}
- TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update,
- /*transpose_x=*/transpose_a,
- /*transpose_y=*/false,
- /*conjugate_x=*/conjugate_a,
- /*conjugate_y=*/false));
- TF_ASSIGN_OR_RETURN(auto b_slice_2,
- SliceInMinorDims(builder, b, {i + k, 0}, {m, n}));
- b_update = xla::Sub(b_slice_2, b_update);
- TF_ASSIGN_OR_RETURN(
- b, UpdateSliceInMinorDims(builder, b, b_update, {i + k, 0}));
- }
+xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower,
+ bool transpose_a, bool conjugate_a) {
+ xla::XlaBuilder* builder = diag_blocks.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ // Input is a batch of square lower triangular square matrices. Its shape is
+ // (..., size, size). We resize this to (num_blocks, size, size).
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(diag_blocks));
+ int64 block_size = xla::ShapeUtil::GetDimension(shape, -1);
+ int64 num_blocks = xla::ShapeUtil::ElementsIn(shape) /
+ tensorflow::MathUtil::IPow(block_size, 2);
+ diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size});
+
+ // The input must be triangular because we rely on that when doing
+ // multiplications later on
+ diag_blocks = Triangle(diag_blocks, /*lower=*/lower);
+
+ // Rescale blocks to be unit triangular, but avoid dividing by
+ // zero (which can happen if the last block was padded) otherwise it will
+ // introduce nans which will propagate
+ auto diags = GetMatrixDiagonal(diag_blocks);
+ TF_ASSIGN_OR_RETURN(xla::Shape diags_shape, builder->GetShape(diags));
+ auto one = ScalarLike(diags, 1);
+ auto ones = Broadcast(one, xla::AsInt64Slice(diags_shape.dimensions()));
+ diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags);
+ auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2});
+
+ // We can now use the fact that for an upper triangular matrix
+ // [[L11, 0], [L21, L22]], given the inverses L11' and L22', we have
+ // L22' = -L22' * L21 * L11'. In our case, L21 is a vector and our blocks
+ // have been rescaled to be unit triangular, so L22 = L22' = 1.
+
+ // Initialize the output matrix with -1s on the diagonal. We use -1 instead
+ // of 1 because we cannot do matrix-vector multiplies with variable shapes
+ // inside of a loop, or do irregularly shaped in-place updates. Hence,
+ // L21 <- -L22 * L21 * L11 cannot be done naively. Instead, we update the
+ // entire row i.e. we calculate
+ // [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I])
+ // which means [L21 L22 0] <- [-L21 * L11', L22, 0].
+ auto identity =
+ IdentityMatrix(builder, shape.element_type(), block_size, block_size);
+ auto neg_identity = -identity;
+
+ // The first or last diagonal element should be set to 1 instead of -1
+ // though, since we never update it
+ auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1});
+ auto start_index = (lower) ? 0 : block_size - 1;
+ auto output_block = DynamicUpdateSlice(
+ neg_identity, pos_one,
+ /*start_indices=*/xla::ConstantR1<int>(builder, 2, start_index));
+
+ // Broadcast diag([1, -1, -1, ...]) to every block
+ xla::XlaOp output = Broadcast(output_block,
+ /*broadcast_sizes=*/{num_blocks});
+
+ // Now we construct a loop that performs matrix-vector multiplications
+ // inverting the blocks one row at a time
+ std::vector<xla::Shape> tuple_shapes = {
+ // The loop iteration counter is a scalar, incremented each iteration.
+ xla::ShapeUtil::MakeShape(xla::S32, {}),
+ // The output has the shape of A, with one row updated each iteration.
+ xla::ShapeUtil::MakeShape(shape.element_type(),
+ {num_blocks, block_size, block_size}),
+ // The input is a loop invariant.
+ xla::ShapeUtil::MakeShape(shape.element_type(),
+ {num_blocks, block_size, block_size})};
+ xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
+
+ auto init_i = One(builder, xla::S32);
+ auto init = xla::Tuple(builder, {init_i, output, scaled_diag_blocks});
+
+ // Construct the loop condition function.
+ std::unique_ptr<xla::XlaBuilder> condb =
+ builder->CreateSubBuilder("InvertDiagCond");
+ {
+ auto i = GetTupleElement(
+ Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0);
+ Lt(i, xla::ConstantR0<int32>(condb.get(), block_size));
}
- } else if (!left_side && lower != transpose_a) {
- // for i in reversed(range(0, a.shape[-1], block_size)):
- const int64 last_blk_ix = xla::RoundUpToNearest(n, block_size) - block_size;
- for (int64 i = last_blk_ix; i >= 0; i -= block_size) {
- int64 k = std::min(block_size, n - i);
-
- // output[..., :, i:i+k] triangular_solve(
- // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1)
- TF_ASSIGN_OR_RETURN(auto a_slice,
- SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
- TF_ASSIGN_OR_RETURN(auto b_slice,
- SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
- xla::XlaOp update;
- if (k > 1) {
- TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
- get_base_triangular_solve(k));
- update = xla::Call(builder, *solve, {a_slice, b_slice});
- } else {
- TF_ASSIGN_OR_RETURN(auto a_slice_conj,
- MaybeConjugate(builder, a_slice, conjugate_a));
- update = xla::Div(b_slice, a_slice_conj);
- }
- TF_ASSIGN_OR_RETURN(
- output, UpdateSliceInMinorDims(builder, output, update, {0, i}));
-
- // if i - k >= 0:
- // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k]
- // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
- // b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2)
- if (i - k >= 0) {
- xla::XlaOp a_slice_2;
- if (lower) {
- TF_ASSIGN_OR_RETURN(a_slice_2,
- SliceInMinorDims(builder, a, {i, 0}, {i + k, i}));
- } else {
- TF_ASSIGN_OR_RETURN(a_slice_2,
- SliceInMinorDims(builder, a, {0, i}, {i, i + k}));
- }
+ TF_ASSIGN_OR_RETURN(auto cond, condb->Build());
+
+ // Construct the loop body function.
+ std::unique_ptr<xla::XlaBuilder> bodyb =
+ builder->CreateSubBuilder("InvertDiagBody");
+ {
+ auto input_tuple =
+ Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple");
+
+ auto i = GetTupleElement(input_tuple, 0);
+ auto body_out = GetTupleElement(input_tuple, 1);
+ auto body_input = GetTupleElement(input_tuple, 2);
+
+ auto zero = xla::ConstantR1<int32>(bodyb.get(), 1, 0);
+ auto j = (lower) ? i : ScalarLike(i, block_size - 1) - i;
+ auto start_indices =
+ xla::ConcatInDim(bodyb.get(), {zero, Reshape(j, {1}), zero}, 0);
+ auto input_row =
+ DynamicSlice(body_input, start_indices,
+ /*slice_sizes=*/{num_blocks, 1, block_size});
+
+ // We want -L21 L11^{-1}
+ xla::DotDimensionNumbers dnums;
+ dnums.add_lhs_batch_dimensions(0);
+ dnums.add_rhs_batch_dimensions(0);
+ dnums.add_lhs_contracting_dimensions(2);
+ dnums.add_rhs_contracting_dimensions(1);
+ auto update = -DotGeneral(input_row, body_out, dnums);
+
+ body_out = DynamicUpdateSlice(body_out, update, start_indices);
+
+ auto next_i = i + ScalarLike(i, 1);
+ xla::Tuple(bodyb.get(), {next_i, body_out, body_input});
+ }
+ TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());
+
+ // Construct the While loop and return the result,
+ // return while_loop(cond_fun, body_fun, init)[1]
+ auto invert_while = While(cond, body, init);
+ auto inv_diag_blocks = GetTupleElement(invert_while, 1);
+
+ // Undo the scaling
+ inv_diag_blocks = Div(inv_diag_blocks, diags,
+ /*broadcast_dimensions=*/{0, 1});
+
+ // Reshape back to original batch major dimensions
+ return Reshape(inv_diag_blocks, xla::AsInt64Slice(shape.dimensions()));
+ });
+}
- TF_ASSIGN_OR_RETURN(auto b_update,
- BatchDot(builder, update, a_slice_2,
- /*transpose_x=*/false,
- /*transpose_y=*/transpose_a,
- /*conjugate_x=*/false,
- /*conjugate_y=*/conjugate_a));
- TF_ASSIGN_OR_RETURN(auto b_slice_2,
- SliceInMinorDims(builder, b, {0, 0}, {m, i}));
- b_update = xla::Sub(b_slice_2, b_update);
- TF_ASSIGN_OR_RETURN(
- b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0}));
+xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b,
+ xla::XlaOp inv_diag_blocks,
+ bool left_side, bool lower,
+ bool transpose_a, bool conjugate_a) {
+ xla::XlaBuilder* builder = a.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape,
+ builder->GetShape(inv_diag_blocks));
+ TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
+ int64 block_size = xla::ShapeUtil::GetDimension(blocks_shape, -1);
+
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ int64 ndims = xla::ShapeUtil::Rank(a_shape);
+ int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
+ int64 num_blocks = n / block_size + (n % block_size != 0);
+ int64 m_dim = (left_side) ? -1 : -2;
+ int64 m = xla::ShapeUtil::GetDimension(b_shape, m_dim);
+
+ // Initialize the solution
+ auto x = ZerosLike(b);
+
+ // This loop is unrolled for performance reasons, but it could be expressed
+ // rolled as well since the matrices are of the same size each iteration
+ for (int i = 0; i < num_blocks; i++) {
+ // High-level intuition: We have B[i] = L[i] @ X. Since L is upper
+ // triangular this means B[i] = L[i, :i + 1] @ X[:i + 1]. We can split
+ // this into two parts: B[i] = L[i, :i] @ X[:i] + L[i, i] @ X[i] which
+ // can be solved for X[i] as X[i] = inv(L[i, i]) @ B[i] - L[i, :i] @ X[:i]
+
+ // Decide whether we go from first block to last or vice versa
+ auto j = (left_side ^ lower ^ transpose_a) ? num_blocks - 1 - i : i;
+
+ // Get the size of the inverse blocks (the last one might be smaller)
+ int64 block = (n % block_size != 0 && j + 1 == num_blocks)
+ ? n % block_size
+ : block_size;
+ auto inv_block =
+ MaybeConjugate(Collapse(SliceInMinorDims(inv_diag_blocks, {j, 0, 0},
+ {j + 1, block, block}),
+ /*dimensions=*/{ndims - 2, ndims - 1}),
+ conjugate_a);
+
+ // Get the corresponding row of B
+ int64 k = std::min((j + 1) * block_size, n);
+ std::vector<int64> start = {j * block_size, 0};
+ std::vector<int64> end = {k, m};
+ if (!left_side) {
+ std::swap(start[0], start[1]);
+ std::swap(end[0], end[1]);
}
- }
- } else { // left_side && lower == transpose_a
- // for i in reversed(range(0, a.shape[-1], block_size)):
- const int64 last_blk_ix = xla::RoundUpToNearest(m, block_size) - block_size;
- for (int64 i = last_blk_ix; i >= 0; i -= block_size) {
- int64 k = std::min(block_size, m - i);
-
- // output[..., i:i+k, :] triangular_solve(
- // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1)
- TF_ASSIGN_OR_RETURN(auto a_slice,
- SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
- TF_ASSIGN_OR_RETURN(auto b_slice,
- SliceInMinorDims(builder, b, {i, 0}, {i + k, n}));
- xla::XlaOp update;
- if (k > 1) {
- TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
- get_base_triangular_solve(k));
- update = xla::Call(builder, *solve, {a_slice, b_slice});
+ auto b_row = SliceInMinorDims(b, start, end);
+
+ xla::XlaOp remainder;
+ if (i == 0) {
+ remainder = b_row;
} else {
- TF_ASSIGN_OR_RETURN(auto a_slice_conj,
- MaybeConjugate(builder, a_slice, conjugate_a));
- update = xla::Div(b_slice, a_slice_conj);
- }
- TF_ASSIGN_OR_RETURN(
- output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
-
- // if i - k >= 0:
- // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k]
- // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
- // b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :])
- if (i - k >= 0) {
- xla::XlaOp a_slice_2;
- if (lower) {
- TF_ASSIGN_OR_RETURN(a_slice_2,
- SliceInMinorDims(builder, a, {i, 0}, {i + k, i}));
+ // This matrix multiply involves a lot of multiplying with zero (namely,
+ // X[i * block_size:] = 0), but this is faster than slicing...
+ end = {k, n};
+ if (!left_side) {
+ std::swap(end[0], end[1]);
+ }
+ if (transpose_a) {
+ std::swap(start[0], start[1]);
+ std::swap(end[0], end[1]);
+ }
+ auto a_row =
+ MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a);
+ if (left_side) {
+ remainder = b_row - BatchDot(a_row, x, transpose_a, false);
} else {
- TF_ASSIGN_OR_RETURN(a_slice_2,
- SliceInMinorDims(builder, a, {0, i}, {i, i + k}));
+ remainder = b_row - BatchDot(x, a_row, false, transpose_a);
}
+ }
- TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update,
- /*transpose_x=*/transpose_a,
- /*transpose_y=*/false,
- /*conjugate_x=*/conjugate_a,
- /*conjugate_y=*/false));
- TF_ASSIGN_OR_RETURN(auto b_slice_2,
- SliceInMinorDims(builder, b, {0, 0}, {i, n}));
- b_update = xla::Sub(b_slice_2, b_update);
- TF_ASSIGN_OR_RETURN(
- b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0}));
+ xla::XlaOp x_update;
+ auto zero = Zero(builder, xla::S32);
+ auto start_index =
+ xla::ConstantR0WithType(builder, xla::S32, j * block_size);
+ std::vector<xla::XlaOp> update_starts = {start_index, zero};
+ if (left_side) {
+ x_update = BatchDot(inv_block, remainder, transpose_a, false);
+ } else {
+ x_update = BatchDot(remainder, inv_block, false, transpose_a);
+ std::swap(update_starts[0], update_starts[1]);
}
+ x = DynamicUpdateSliceInMinorDims(x, x_update, /*starts=*/update_starts);
}
- }
- return output;
+ return x;
+ });
}
-xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder,
- const xla::XlaOp& a,
- const xla::XlaOp& b,
- bool transpose_a,
- bool conjugate_a) {
- TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
- TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
- const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2);
- const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1);
- const int64 ndims = xla::ShapeUtil::Rank(a_shape);
-
- std::vector<int64> batch_dimensions;
- for (int i = 0; i < ndims - 2; ++i) {
- int64 a_size = a_shape.dimensions(i);
- batch_dimensions.push_back(a_size);
- }
-
- // The main computation is performed in a While loop.
-
- // Allocate the output and set its first or last row,
- // output = np.zeros_like(b)
- // if transpose_a:
- // output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:]
- // else:
- // output[..., :1, :] = b[..., :1, :] / a[..., :1, :1]
- xla::XlaOp output = Zeros(builder, b_shape);
- {
- auto i = transpose_a ? m - 1 : 0;
- TF_ASSIGN_OR_RETURN(auto a_slice,
- SliceInMinorDims(builder, a, {i, i}, {i + 1, i + 1}));
- TF_ASSIGN_OR_RETURN(auto b_slice,
- SliceInMinorDims(builder, b, {i, 0}, {i + 1, n}));
- TF_ASSIGN_OR_RETURN(auto a_slice_conj,
- MaybeConjugate(builder, a_slice, conjugate_a));
- auto update = xla::Div(b_slice, a_slice_conj);
- TF_ASSIGN_OR_RETURN(
- output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
- }
-
- // Construct the initial loop carry tuple,
- // if transpose_a:
- // init = (m-2, output, a, b)
- // else:
- // init = (1, output, a, b)
- std::vector<xla::Shape> tuple_shapes = {
- // The loop iteration counter is a scalar, incremented each iteration.
- xla::ShapeUtil::MakeShape(xla::S32, {}),
- // The output has the shape of b, with one row updated each iteration.
- b_shape,
- // The coefficient matrix a is a loop invariant.
- a_shape,
- // The right-hand-side matrix b is a loop invariant.
- b_shape};
- xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
- auto init_i = xla::ConstantR0<int32>(builder, transpose_a ? m - 2 : 1);
- auto init = xla::Tuple(builder, {init_i, output, a, b});
-
- // Construct the loop condition function,
- // def cond_fun(loop_carry):
- // i, output, a, b = loop_carry
- // return i >= 0 if transpose_a else i < m
- std::unique_ptr<xla::XlaBuilder> condb =
- builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond");
- {
- auto i = xla::GetTupleElement(
- xla::Parameter(condb.get(), 0, tuple_shape,
- "TriangularSolveLeftLookingWhileTuple"),
- 0);
- if (transpose_a) {
- xla::Ge(i, xla::ConstantR0<int32>(condb.get(), 0));
- } else {
- xla::Lt(i, xla::ConstantR0<int32>(condb.get(), m));
+xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side,
+ bool lower, bool transpose_a, bool conjugate_a,
+ int64 block_size) {
+ xla::XlaBuilder* builder = a.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
+ if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) {
+ return errors::InvalidArgument(
+ "Arguments to TriangularSolve have different ranks: ",
+ xla::ShapeUtil::HumanString(a_shape), " vs. ",
+ xla::ShapeUtil::HumanString(b_shape));
}
- }
- TF_ASSIGN_OR_RETURN(auto cond, condb->Build());
-
- // Construct the loop body function,
- // def body_fun(loop_carry):
- // i, output, a, b = loop_carry
- // if transpose_a:
- // a_row = np.swapaxes(a[..., i+1:, i:i+1], -1 -2)
- // else:
- // a_row = a[..., i:i+1, :i]
- // result_row = b[..., i:i+1, :] - np.matmul(a_row, output[..., :, :])
- // output[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
- // if transpose_a:
- // return (i - 1, output, a, b)
- // else:
- // return (i + 1, output, a, b)
- // We have to do some extra FLOPs propagating zeros in the matrix multiply
- // because we can't have the size of its arguments depend on the loop counter.
- std::unique_ptr<xla::XlaBuilder> bodyb =
- builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody");
- {
- auto input_tuple = xla::Parameter(bodyb.get(), 0, tuple_shape,
- "TriangularSolveLeftLookingWhileTuple");
-
- // i, output, a, b = loop_carry
- auto i = xla::GetTupleElement(input_tuple, 0);
- auto body_out = xla::GetTupleElement(input_tuple, 1);
- auto body_a = xla::GetTupleElement(input_tuple, 2);
- auto body_b = xla::GetTupleElement(input_tuple, 3);
- auto zero = xla::ConstantR0<int32>(bodyb.get(), 0);
-
- // We'd like to implement this:
- // if transpose_a:
- // a_row = T(a[..., i+1:, i:i+1])
- // result_row = (b[..., i:i+1, :]
- // - np.matmul(a_row, body_out[..., i+1:, :]))
- // else:
- // result_row = (b[..., i:i+1, :]
- // - np.matmul(a[..., i:i+1, :i], body_out[..., :i, :]))
- // But since we can't have intermediate array sizes depend on the loop
- // counter, we instead exploit the fact that we initialized the output to
- // all zeros and use that as zero-padding (doing unnecessary FLOPs).
- xla::XlaOp a_row;
- if (transpose_a) {
- TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a,
- {zero, i}, {m, 1}));
- } else {
- TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a,
- {i, zero}, {1, m}));
+ const int64 ndims = xla::ShapeUtil::Rank(a_shape);
+ if (ndims < 2) {
+ return errors::InvalidArgument(
+ "Arguments to TriangularSolve must have rank >= 2: ", ndims);
+ }
+ // The batch dimensions must be equal.
+ std::vector<int64> batch_dimensions;
+ for (int i = 0; i < ndims - 2; ++i) {
+ int64 a_size = a_shape.dimensions(i);
+ int64 b_size = b_shape.dimensions(i);
+ if (a_size != b_size) {
+ return errors::InvalidArgument(
+ "Batch dimensions of arguments to TriangularSolve must be equal: ",
+ xla::ShapeUtil::HumanString(a_shape), " vs ",
+ xla::ShapeUtil::HumanString(b_shape));
+ }
+ batch_dimensions.push_back(a_size);
}
- TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), a_row, body_out,
- /*transpose_x=*/transpose_a,
- /*transpose_y=*/false,
- /*conjugate_x=*/conjugate_a,
- /*conjugate_y=*/false));
- TF_ASSIGN_OR_RETURN(
- auto result_row_slice,
- DynamicSliceInMinorDims(bodyb.get(), body_b, {i, zero}, {1, n}));
- auto result_row = xla::Sub(result_row_slice, b_update);
-
- // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
- TF_ASSIGN_OR_RETURN(auto a_elt, DynamicSliceInMinorDims(bodyb.get(), body_a,
- {i, i}, {1, 1}));
- TF_ASSIGN_OR_RETURN(auto a_elt_conj,
- MaybeConjugate(bodyb.get(), a_elt, conjugate_a));
- auto div_result = xla::Div(result_row, a_elt_conj);
- TF_ASSIGN_OR_RETURN(body_out,
- DynamicUpdateSliceInMinorDims(bodyb.get(), body_out,
- div_result, {i, zero}));
-
- // if transpose_a:
- // return (i - 1, body_out, a, b)
- // else:
- // return (i + 1, body_out, a, b)
- auto next_i =
- xla::Add(i, xla::ConstantR0<int32>(bodyb.get(), transpose_a ? -1 : 1));
- xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b});
- }
- TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());
-
- // Construct the While loop and return the result,
- // return while_loop(cond_fun, body_fun, init)[1]
- auto triangular_solve_left_looking_while = xla::While(cond, body, init);
- return xla::GetTupleElement(triangular_solve_left_looking_while, 1);
-}
-xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder,
- const xla::XlaOp& a,
- const xla::XlaOp& b,
- bool transpose_a,
- bool conjugate_a) {
- TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
- TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
- const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2);
- const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1);
- const int64 ndims = xla::ShapeUtil::Rank(a_shape);
-
- std::vector<int64> batch_dimensions;
- for (int i = 0; i < ndims - 2; ++i) {
- int64 a_size = a_shape.dimensions(i);
- batch_dimensions.push_back(a_size);
- }
-
- // The main computation is performed in a While loop.
- xla::XlaOp output = Zeros(builder, b_shape);
-
- // Construct the initial loop carry tuple,
- // if transpose_a:
- // init = (0, output, a, b)
- // else:
- // init = (n-1, output, a, b)
- std::vector<xla::Shape> tuple_shapes = {
- // The loop iteration counter is a scalar, incremented each iteration.
- xla::ShapeUtil::MakeShape(xla::S32, {}),
- // The output has the shape of b, with one row updated each iteration.
- b_shape,
- // The coefficient matrix a is a loop invariant.
- a_shape,
- // The right-hand-side matrix b is a loop invariant.
- b_shape};
- xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
- auto init_i = xla::ConstantR0<int32>(builder, transpose_a ? 0 : n - 1);
- auto init = xla::Tuple(builder, {init_i, output, a, b});
-
- // Construct the loop condition function,
- // def cond_fun(loop_carry):
- // i, output, a, b = loop_carry
- // return i < n if transpose_a else i >= 0
- std::unique_ptr<xla::XlaBuilder> condb =
- builder->CreateSubBuilder("TriangularSolveRightLookingWhileCond");
- {
- auto i = xla::GetTupleElement(
- xla::Parameter(condb.get(), 0, tuple_shape,
- "TriangularSolveRightLookingWhileTuple"),
- 0);
- if (transpose_a) {
- xla::Lt(i, xla::ConstantR0<int32>(condb.get(), n));
- } else {
- xla::Ge(i, xla::ConstantR0<int32>(condb.get(), 0));
+ if (xla::ShapeUtil::GetDimension(a_shape, -1) !=
+ xla::ShapeUtil::GetDimension(a_shape, -2)) {
+ return errors::InvalidArgument(
+ "The 'a' arguments to TriangularSolve must be square matrices: ",
+ xla::ShapeUtil::HumanString(a_shape));
}
- }
- TF_ASSIGN_OR_RETURN(auto cond, condb->Build());
-
- // Construct the loop body function,
- // def body_fun(loop_carry):
- // i, output, a, b = loop_carry
- // if transpose_a:
- // a_row = np.swapaxes(a[..., :, i:i+1], -1 -2)
- // else:
- // a_row = a[..., :, i:i+1]
- // result_row = b[..., :, i:i+1] - np.matmul(output, a_row)
- // output[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1]
- // if transpose_a:
- // return (i - 1, output, a, b)
- // else:
- // return (i + 1, output, a, b)
- // We have to do some extra FLOPs propagating zeros in the matrix multiply
- // because we can't have the size of its arguments depend on the loop counter.
- std::unique_ptr<xla::XlaBuilder> bodyb =
- builder->CreateSubBuilder("TriangularSolveRightLookingWhileBody");
- {
- auto input_tuple = xla::Parameter(bodyb.get(), 0, tuple_shape,
- "TriangularSolveRightLookingWhileTuple");
-
- // i, output, a, b = loop_carry
- auto i = xla::GetTupleElement(input_tuple, 0);
- auto body_out = xla::GetTupleElement(input_tuple, 1);
- auto body_a = xla::GetTupleElement(input_tuple, 2);
- auto body_b = xla::GetTupleElement(input_tuple, 3);
- auto zero = xla::ConstantR0<int32>(bodyb.get(), 0);
-
- // We'd like to implement b[..., :, i:i+1] - np.matmul(output, a[..., :,
- // i:i+1]) But since we can't have intermediate array sizes depend on the
- // loop counter, we instead exploit the fact that we initialized the output
- // to all zeros and use that as zero-padding (doing unnecessary FLOPs).
- TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), body_out, body_a,
- /*transpose_x=*/false,
- /*transpose_y=*/transpose_a,
- /*conjugate_x=*/false,
- /*conjugate_y=*/conjugate_a));
- // result = b - np.matmul(output, a)
- auto result = xla::Sub(body_b, b_update);
- // result_row = result[..., :, i:i+1]
- TF_ASSIGN_OR_RETURN(
- auto result_row,
- DynamicSliceInMinorDims(bodyb.get(), result, {zero, i}, {m, 1}));
-
- // body_out[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1]
- TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(bodyb.get(), body_a,
- {i, i}, {1, 1}));
- TF_ASSIGN_OR_RETURN(auto a_ii_conj,
- MaybeConjugate(bodyb.get(), a_ii, conjugate_a));
- auto div_result = xla::Div(result_row, a_ii_conj);
- TF_ASSIGN_OR_RETURN(body_out,
- DynamicUpdateSliceInMinorDims(bodyb.get(), body_out,
- div_result, {zero, i}));
-
- // if transpose_a:
- // return (i + 1, body_out, a, b)
- // else:
- // return (i - 1, body_out, a, b)
- auto next_i =
- xla::Add(i, xla::ConstantR0<int32>(bodyb.get(), transpose_a ? 1 : -1));
- xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b});
- }
- TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());
-
- // Construct the While loop and return the result,
- // return while_loop(cond_fun, body_fun, init)[1]
- auto triangular_solve_left_looking_while = xla::While(cond, body, init);
- return xla::GetTupleElement(triangular_solve_left_looking_while, 1);
+ const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2);
+ const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1);
+ if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) {
+ return errors::InvalidArgument(
+ "Arguments to TriangularSolve have incompatible matrix shapes: ",
+ xla::ShapeUtil::HumanString(a_shape), " vs ",
+ xla::ShapeUtil::HumanString(b_shape));
+ }
+
+ if (block_size < 1) {
+ return errors::InvalidArgument(
+ "block_size argument to TriangularSolve must be >= 1; got ",
+ block_size);
+ }
+
+ // We find the diagonal blocks of the coefficient matrix
+ auto diag_blocks = DiagonalBlocks(a, block_size);
+
+ // We invert these blocks in parallel using batched matrix-vector products
+ auto inv_diag_blocks =
+ InvertDiagonalBlocks(diag_blocks, lower, transpose_a, conjugate_a);
+
+ // We now find the solution using GEMMs
+ auto x = SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side,
+ lower, transpose_a, conjugate_a);
+
+ return x;
+ });
}
} // namespace tensorflow