diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/lib/triangular_solve.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/lib/triangular_solve.cc | 936 |
1 files changed, 344 insertions, 592 deletions
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index b9f695ac4b..05dad759df 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -20,631 +20,383 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/math/math_util.h" namespace tensorflow { -xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder, - const xla::XlaOp& a, xla::XlaOp b, - bool left_side, bool lower, - bool transpose_a, bool conjugate_a, - int64 block_size) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) { - return errors::InvalidArgument( - "Arguments to TriangularSolve have different ranks: ", - xla::ShapeUtil::HumanString(a_shape), " vs. ", - xla::ShapeUtil::HumanString(b_shape)); - } - const int ndims = xla::ShapeUtil::Rank(a_shape); - if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to TriangularSolve must have rank >= 2: ", ndims); - } - // The batch dimensions must be equal. - std::vector<int64> batch_dimensions; - for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape.dimensions(i); - int64 b_size = b_shape.dimensions(i); - if (a_size != b_size) { - return errors::InvalidArgument( - "Batch dimensions of arguments to TriangularSolve must be equal: ", - xla::ShapeUtil::HumanString(a_shape), " vs ", - xla::ShapeUtil::HumanString(b_shape)); +// Get the diagonal blocks of the coefficient matrix +xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(a)); + int ndims = xla::ShapeUtil::Rank(shape); + int64 n = xla::ShapeUtil::GetDimension(shape, -1); + int64 num_blocks = n / block_size; + + xla::XlaOp diag_blocks; + + // If the coefficient matrix is exactly the block size, we just add a + // singleton dimension i.e. [..., n, n] -> [..., 1, n, n] + if (n == block_size) { + std::vector<int64> permutation(ndims); + std::iota(permutation.begin(), permutation.end(), 1); + permutation.insert(permutation.end() - 2, 0); + return Transpose(Broadcast(a, /*broadcast_sizes=*/{1}), permutation); } - batch_dimensions.push_back(a_size); - } - - if (xla::ShapeUtil::GetDimension(a_shape, -1) != - xla::ShapeUtil::GetDimension(a_shape, -2)) { - return errors::InvalidArgument( - "The 'a' arguments to TriangularSolve must be square matrices: ", - xla::ShapeUtil::HumanString(a_shape)); - } - const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); - if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) { - return errors::InvalidArgument( - "Arguments to TriangularSolve have incompatible matrix shapes: ", - xla::ShapeUtil::HumanString(a_shape), " vs ", - xla::ShapeUtil::HumanString(b_shape)); - } - - if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to TriangularSolve must be >= 1; got ", - block_size); - } - - std::map<int, xla::XlaComputation> base_computations; - auto get_base_triangular_solve = - [&](int k) -> xla::StatusOr<xla::XlaComputation*> { - xla::XlaComputation& computation = base_computations[k]; - if (computation.IsNull()) { - std::unique_ptr<xla::XlaBuilder> sub = builder->CreateSubBuilder( - tensorflow::strings::StrCat("trsm_base_", k)); - - auto a_param = xla::Parameter( - sub.get(), 0, - xla::ShapeUtil::MakeShape( - b_shape.element_type(), - PrependMajorDims(sub.get(), batch_dimensions, {k, k})), - "a"); - - std::array<int64, 2> b_lastd; - if (left_side) { - b_lastd = {k, n}; - } else { - b_lastd = {m, k}; - } - auto b_param = xla::Parameter( - sub.get(), 1, - xla::ShapeUtil::MakeShape( - b_shape.element_type(), - PrependMajorDims(sub.get(), batch_dimensions, b_lastd)), - "b"); - - // We use a left-looking or right-looking subroutine on the block diagonal - // in the lower=true cases, while falling back to a recursive call in - // others. The left-looking and right-looking subroutines are written with - // a While loop and so yields much faster compile times. Moreover, they - // can give higher performance on smaller (sub)problems. - if (left_side && lower) { - TF_RETURN_IF_ERROR(TriangularSolveLeftLooking(sub.get(), a_param, - b_param, transpose_a, - conjugate_a) - .status()); - } else if (!left_side && lower) { - TF_RETURN_IF_ERROR(TriangularSolveRightLooking(sub.get(), a_param, - b_param, transpose_a, - conjugate_a) - .status()); - } else { - TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param, - left_side, lower, transpose_a, - conjugate_a, - /*block_size=*/1) - .status()); - } - TF_ASSIGN_OR_RETURN(computation, sub->Build()); + // We can grab entire blocks using gather + if (n > block_size) { + // Construct the starting indices of the diagonal blocks + auto gather_indices = + Transpose(Broadcast(Mul(Iota(builder, xla::S32, num_blocks), + xla::ConstantR0<int32>(builder, block_size)), + /*broadcast_sizes=*/{2}), + /*permutation=*/{1, 0}); + + // Gather the diagonal blocks + xla::GatherDimensionNumbers dim_numbers; + dim_numbers.add_output_window_dims(ndims - 1); + dim_numbers.add_output_window_dims(ndims); + dim_numbers.add_gather_dims_to_operand_dims(ndims - 2); + dim_numbers.add_gather_dims_to_operand_dims(ndims - 1); + dim_numbers.set_index_vector_dim(1); + diag_blocks = Gather(a, gather_indices, dim_numbers, + /*window_bounds=*/{block_size, block_size}); } - return &computation; - }; - - xla::XlaOp output = Zeros(builder, b_shape); - - // Right-looking blocked triangular solve. - // For an explanation of the algorithm, see the TRSM discussion in: - // Goto, Kazushige, and Robert Van De Geijn. "High-performance implementation - // of the level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 - // (2008): 4. - - // In the code comments below, T = lambda x: np.swapaxes(x, -1, -2) if - // conjugate_a is False, or T = lambda x: np.conj(np.swapaxes(x, -1, -2)) if - // conjugate_a is True. - - if (!left_side && lower == transpose_a) { - // for i in range(0, a.shape[-1], block_size): - for (int64 i = 0; i < n; i += block_size) { - int64 k = std::min(block_size, n - i); - - // output[..., :, i:i+k] = triangular_solve( - // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {0, i}, {m, i + k})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = xla::Call(builder, *solve, {a_slice, b_slice}); - } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = xla::Div(b_slice, a_slice_conj); - } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {0, i})); - - // if i + k < a.shape[-1]: - // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2) - if (i + k < n) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); - } else { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, n})); - } - TF_ASSIGN_OR_RETURN(auto b_update, - BatchDot(builder, update, a_slice_2, - /*transpose_x=*/false, - /*transpose_y=*/transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/conjugate_a)); - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {0, i + k}, {m, n})); - b_update = xla::Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k})); + // The last block might be smaller than the block size, + // so we will need to pad it + if (n % block_size != 0) { + // Pad with zeros + auto last_blocks = + SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n}); + xla::PaddingConfig config = xla::MakeNoPaddingConfig(ndims); + int64 padding = block_size - n % block_size; + config.mutable_dimensions(ndims - 1)->set_edge_padding_high(padding); + config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding); + last_blocks = + Pad(last_blocks, Zero(builder, shape.element_type()), config); + + // Add a singleton dimension + // i.e. [..., block_size, block_size] -> [..., 1, block_size, block_size] + TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, + builder->GetShape(last_blocks)); + auto shape_dims = xla::AsInt64Slice(blocks_shape.dimensions()); + auto last_blocks_dims = std::vector<int64>(ndims); + std::copy(shape_dims.begin(), shape_dims.end(), last_blocks_dims.begin()); + last_blocks_dims.insert(last_blocks_dims.end() - 2, 1); + last_blocks = Reshape(last_blocks, last_blocks_dims); + + // Concatenate with the other blocks if necessary + if (n > block_size) { + diag_blocks = + xla::ConcatInDim(builder, {diag_blocks, last_blocks}, ndims - 2); + } else { + diag_blocks = last_blocks; } } - } else if (left_side && lower != transpose_a) { - // for i in range(0, a.shape[-1], block_size): - for (int64 i = 0; i < m; i += block_size) { - int64 k = std::min(block_size, m - i); - - // output[..., i:i+k, :] = triangular_solve( - // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {i, 0}, {i + k, n})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = xla::Call(builder, *solve, {a_slice, b_slice}); - } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = xla::Div(b_slice, a_slice_conj); - } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); - - // if i + k < a.shape[-1]: - // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :]) - if (i + k < m) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {m, i + k})); - } else { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, m})); - } + return diag_blocks; + }); +} - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update, - /*transpose_x=*/transpose_a, - /*transpose_y=*/false, - /*conjugate_x=*/conjugate_a, - /*conjugate_y=*/false)); - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {i + k, 0}, {m, n})); - b_update = xla::Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {i + k, 0})); - } +xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, + bool transpose_a, bool conjugate_a) { + xla::XlaBuilder* builder = diag_blocks.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { + // Input is a batch of square lower triangular square matrices. Its shape is + // (..., size, size). We resize this to (num_blocks, size, size). + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(diag_blocks)); + int64 block_size = xla::ShapeUtil::GetDimension(shape, -1); + int64 num_blocks = xla::ShapeUtil::ElementsIn(shape) / + tensorflow::MathUtil::IPow(block_size, 2); + diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size}); + + // The input must be triangular because we rely on that when doing + // multiplications later on + diag_blocks = Triangle(diag_blocks, /*lower=*/lower); + + // Rescale blocks to be unit triangular, but avoid dividing by + // zero (which can happen if the last block was padded) otherwise it will + // introduce nans which will propagate + auto diags = GetMatrixDiagonal(diag_blocks); + TF_ASSIGN_OR_RETURN(xla::Shape diags_shape, builder->GetShape(diags)); + auto one = ScalarLike(diags, 1); + auto ones = Broadcast(one, xla::AsInt64Slice(diags_shape.dimensions())); + diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags); + auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2}); + + // We can now use the fact that for an upper triangular matrix + // [[L11, 0], [L21, L22]], given the inverses L11' and L22', we have + // L22' = -L22' * L21 * L11'. In our case, L21 is a vector and our blocks + // have been rescaled to be unit triangular, so L22 = L22' = 1. + + // Initialize the output matrix with -1s on the diagonal. We use -1 instead + // of 1 because we cannot do matrix-vector multiplies with variable shapes + // inside of a loop, or do irregularly shaped in-place updates. Hence, + // L21 <- -L22 * L21 * L11 cannot be done naively. Instead, we update the + // entire row i.e. we calculate + // [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I]) + // which means [L21 L22 0] <- [-L21 * L11', L22, 0]. + auto identity = + IdentityMatrix(builder, shape.element_type(), block_size, block_size); + auto neg_identity = -identity; + + // The first or last diagonal element should be set to 1 instead of -1 + // though, since we never update it + auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1}); + auto start_index = (lower) ? 0 : block_size - 1; + auto output_block = DynamicUpdateSlice( + neg_identity, pos_one, + /*start_indices=*/xla::ConstantR1<int>(builder, 2, start_index)); + + // Broadcast diag([1, -1, -1, ...]) to every block + xla::XlaOp output = Broadcast(output_block, + /*broadcast_sizes=*/{num_blocks}); + + // Now we construct a loop that performs matrix-vector multiplications + // inverting the blocks one row at a time + std::vector<xla::Shape> tuple_shapes = { + // The loop iteration counter is a scalar, incremented each iteration. + xla::ShapeUtil::MakeShape(xla::S32, {}), + // The output has the shape of A, with one row updated each iteration. + xla::ShapeUtil::MakeShape(shape.element_type(), + {num_blocks, block_size, block_size}), + // The input is a loop invariant. + xla::ShapeUtil::MakeShape(shape.element_type(), + {num_blocks, block_size, block_size})}; + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); + + auto init_i = One(builder, xla::S32); + auto init = xla::Tuple(builder, {init_i, output, scaled_diag_blocks}); + + // Construct the loop condition function. + std::unique_ptr<xla::XlaBuilder> condb = + builder->CreateSubBuilder("InvertDiagCond"); + { + auto i = GetTupleElement( + Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0); + Lt(i, xla::ConstantR0<int32>(condb.get(), block_size)); } - } else if (!left_side && lower != transpose_a) { - // for i in reversed(range(0, a.shape[-1], block_size)): - const int64 last_blk_ix = xla::RoundUpToNearest(n, block_size) - block_size; - for (int64 i = last_blk_ix; i >= 0; i -= block_size) { - int64 k = std::min(block_size, n - i); - - // output[..., :, i:i+k] triangular_solve( - // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {0, i}, {m, i + k})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = xla::Call(builder, *solve, {a_slice, b_slice}); - } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = xla::Div(b_slice, a_slice_conj); - } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {0, i})); - - // if i - k >= 0: - // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2) - if (i - k >= 0) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); - } else { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {0, i}, {i, i + k})); - } + TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); + + // Construct the loop body function. + std::unique_ptr<xla::XlaBuilder> bodyb = + builder->CreateSubBuilder("InvertDiagBody"); + { + auto input_tuple = + Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple"); + + auto i = GetTupleElement(input_tuple, 0); + auto body_out = GetTupleElement(input_tuple, 1); + auto body_input = GetTupleElement(input_tuple, 2); + + auto zero = xla::ConstantR1<int32>(bodyb.get(), 1, 0); + auto j = (lower) ? i : ScalarLike(i, block_size - 1) - i; + auto start_indices = + xla::ConcatInDim(bodyb.get(), {zero, Reshape(j, {1}), zero}, 0); + auto input_row = + DynamicSlice(body_input, start_indices, + /*slice_sizes=*/{num_blocks, 1, block_size}); + + // We want -L21 L11^{-1} + xla::DotDimensionNumbers dnums; + dnums.add_lhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(0); + dnums.add_lhs_contracting_dimensions(2); + dnums.add_rhs_contracting_dimensions(1); + auto update = -DotGeneral(input_row, body_out, dnums); + + body_out = DynamicUpdateSlice(body_out, update, start_indices); + + auto next_i = i + ScalarLike(i, 1); + xla::Tuple(bodyb.get(), {next_i, body_out, body_input}); + } + TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); + + // Construct the While loop and return the result, + // return while_loop(cond_fun, body_fun, init)[1] + auto invert_while = While(cond, body, init); + auto inv_diag_blocks = GetTupleElement(invert_while, 1); + + // Undo the scaling + inv_diag_blocks = Div(inv_diag_blocks, diags, + /*broadcast_dimensions=*/{0, 1}); + + // Reshape back to original batch major dimensions + return Reshape(inv_diag_blocks, xla::AsInt64Slice(shape.dimensions())); + }); +} - TF_ASSIGN_OR_RETURN(auto b_update, - BatchDot(builder, update, a_slice_2, - /*transpose_x=*/false, - /*transpose_y=*/transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/conjugate_a)); - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {0, 0}, {m, i})); - b_update = xla::Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0})); +xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b, + xla::XlaOp inv_diag_blocks, + bool left_side, bool lower, + bool transpose_a, bool conjugate_a) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { + TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, + builder->GetShape(inv_diag_blocks)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + int64 block_size = xla::ShapeUtil::GetDimension(blocks_shape, -1); + + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + int64 ndims = xla::ShapeUtil::Rank(a_shape); + int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + int64 num_blocks = n / block_size + (n % block_size != 0); + int64 m_dim = (left_side) ? -1 : -2; + int64 m = xla::ShapeUtil::GetDimension(b_shape, m_dim); + + // Initialize the solution + auto x = ZerosLike(b); + + // This loop is unrolled for performance reasons, but it could be expressed + // rolled as well since the matrices are of the same size each iteration + for (int i = 0; i < num_blocks; i++) { + // High-level intuition: We have B[i] = L[i] @ X. Since L is upper + // triangular this means B[i] = L[i, :i + 1] @ X[:i + 1]. We can split + // this into two parts: B[i] = L[i, :i] @ X[:i] + L[i, i] @ X[i] which + // can be solved for X[i] as X[i] = inv(L[i, i]) @ B[i] - L[i, :i] @ X[:i] + + // Decide whether we go from first block to last or vice versa + auto j = (left_side ^ lower ^ transpose_a) ? num_blocks - 1 - i : i; + + // Get the size of the inverse blocks (the last one might be smaller) + int64 block = (n % block_size != 0 && j + 1 == num_blocks) + ? n % block_size + : block_size; + auto inv_block = + MaybeConjugate(Collapse(SliceInMinorDims(inv_diag_blocks, {j, 0, 0}, + {j + 1, block, block}), + /*dimensions=*/{ndims - 2, ndims - 1}), + conjugate_a); + + // Get the corresponding row of B + int64 k = std::min((j + 1) * block_size, n); + std::vector<int64> start = {j * block_size, 0}; + std::vector<int64> end = {k, m}; + if (!left_side) { + std::swap(start[0], start[1]); + std::swap(end[0], end[1]); } - } - } else { // left_side && lower == transpose_a - // for i in reversed(range(0, a.shape[-1], block_size)): - const int64 last_blk_ix = xla::RoundUpToNearest(m, block_size) - block_size; - for (int64 i = last_blk_ix; i >= 0; i -= block_size) { - int64 k = std::min(block_size, m - i); - - // output[..., i:i+k, :] triangular_solve( - // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {i, 0}, {i + k, n})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = xla::Call(builder, *solve, {a_slice, b_slice}); + auto b_row = SliceInMinorDims(b, start, end); + + xla::XlaOp remainder; + if (i == 0) { + remainder = b_row; } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = xla::Div(b_slice, a_slice_conj); - } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); - - // if i - k >= 0: - // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :]) - if (i - k >= 0) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); + // This matrix multiply involves a lot of multiplying with zero (namely, + // X[i * block_size:] = 0), but this is faster than slicing... + end = {k, n}; + if (!left_side) { + std::swap(end[0], end[1]); + } + if (transpose_a) { + std::swap(start[0], start[1]); + std::swap(end[0], end[1]); + } + auto a_row = + MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a); + if (left_side) { + remainder = b_row - BatchDot(a_row, x, transpose_a, false); } else { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {0, i}, {i, i + k})); + remainder = b_row - BatchDot(x, a_row, false, transpose_a); } + } - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update, - /*transpose_x=*/transpose_a, - /*transpose_y=*/false, - /*conjugate_x=*/conjugate_a, - /*conjugate_y=*/false)); - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {0, 0}, {i, n})); - b_update = xla::Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0})); + xla::XlaOp x_update; + auto zero = Zero(builder, xla::S32); + auto start_index = + xla::ConstantR0WithType(builder, xla::S32, j * block_size); + std::vector<xla::XlaOp> update_starts = {start_index, zero}; + if (left_side) { + x_update = BatchDot(inv_block, remainder, transpose_a, false); + } else { + x_update = BatchDot(remainder, inv_block, false, transpose_a); + std::swap(update_starts[0], update_starts[1]); } + x = DynamicUpdateSliceInMinorDims(x, x_update, /*starts=*/update_starts); } - } - return output; + return x; + }); } -xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder, - const xla::XlaOp& a, - const xla::XlaOp& b, - bool transpose_a, - bool conjugate_a) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); - const int64 ndims = xla::ShapeUtil::Rank(a_shape); - - std::vector<int64> batch_dimensions; - for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape.dimensions(i); - batch_dimensions.push_back(a_size); - } - - // The main computation is performed in a While loop. - - // Allocate the output and set its first or last row, - // output = np.zeros_like(b) - // if transpose_a: - // output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:] - // else: - // output[..., :1, :] = b[..., :1, :] / a[..., :1, :1] - xla::XlaOp output = Zeros(builder, b_shape); - { - auto i = transpose_a ? m - 1 : 0; - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + 1, i + 1})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {i, 0}, {i + 1, n})); - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - auto update = xla::Div(b_slice, a_slice_conj); - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); - } - - // Construct the initial loop carry tuple, - // if transpose_a: - // init = (m-2, output, a, b) - // else: - // init = (1, output, a, b) - std::vector<xla::Shape> tuple_shapes = { - // The loop iteration counter is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(xla::S32, {}), - // The output has the shape of b, with one row updated each iteration. - b_shape, - // The coefficient matrix a is a loop invariant. - a_shape, - // The right-hand-side matrix b is a loop invariant. - b_shape}; - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); - auto init_i = xla::ConstantR0<int32>(builder, transpose_a ? m - 2 : 1); - auto init = xla::Tuple(builder, {init_i, output, a, b}); - - // Construct the loop condition function, - // def cond_fun(loop_carry): - // i, output, a, b = loop_carry - // return i >= 0 if transpose_a else i < m - std::unique_ptr<xla::XlaBuilder> condb = - builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond"); - { - auto i = xla::GetTupleElement( - xla::Parameter(condb.get(), 0, tuple_shape, - "TriangularSolveLeftLookingWhileTuple"), - 0); - if (transpose_a) { - xla::Ge(i, xla::ConstantR0<int32>(condb.get(), 0)); - } else { - xla::Lt(i, xla::ConstantR0<int32>(condb.get(), m)); +xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, + bool lower, bool transpose_a, bool conjugate_a, + int64 block_size) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) { + return errors::InvalidArgument( + "Arguments to TriangularSolve have different ranks: ", + xla::ShapeUtil::HumanString(a_shape), " vs. ", + xla::ShapeUtil::HumanString(b_shape)); } - } - TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); - - // Construct the loop body function, - // def body_fun(loop_carry): - // i, output, a, b = loop_carry - // if transpose_a: - // a_row = np.swapaxes(a[..., i+1:, i:i+1], -1 -2) - // else: - // a_row = a[..., i:i+1, :i] - // result_row = b[..., i:i+1, :] - np.matmul(a_row, output[..., :, :]) - // output[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] - // if transpose_a: - // return (i - 1, output, a, b) - // else: - // return (i + 1, output, a, b) - // We have to do some extra FLOPs propagating zeros in the matrix multiply - // because we can't have the size of its arguments depend on the loop counter. - std::unique_ptr<xla::XlaBuilder> bodyb = - builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody"); - { - auto input_tuple = xla::Parameter(bodyb.get(), 0, tuple_shape, - "TriangularSolveLeftLookingWhileTuple"); - - // i, output, a, b = loop_carry - auto i = xla::GetTupleElement(input_tuple, 0); - auto body_out = xla::GetTupleElement(input_tuple, 1); - auto body_a = xla::GetTupleElement(input_tuple, 2); - auto body_b = xla::GetTupleElement(input_tuple, 3); - auto zero = xla::ConstantR0<int32>(bodyb.get(), 0); - - // We'd like to implement this: - // if transpose_a: - // a_row = T(a[..., i+1:, i:i+1]) - // result_row = (b[..., i:i+1, :] - // - np.matmul(a_row, body_out[..., i+1:, :])) - // else: - // result_row = (b[..., i:i+1, :] - // - np.matmul(a[..., i:i+1, :i], body_out[..., :i, :])) - // But since we can't have intermediate array sizes depend on the loop - // counter, we instead exploit the fact that we initialized the output to - // all zeros and use that as zero-padding (doing unnecessary FLOPs). - xla::XlaOp a_row; - if (transpose_a) { - TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a, - {zero, i}, {m, 1})); - } else { - TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a, - {i, zero}, {1, m})); + const int64 ndims = xla::ShapeUtil::Rank(a_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to TriangularSolve must have rank >= 2: ", ndims); + } + // The batch dimensions must be equal. + std::vector<int64> batch_dimensions; + for (int i = 0; i < ndims - 2; ++i) { + int64 a_size = a_shape.dimensions(i); + int64 b_size = b_shape.dimensions(i); + if (a_size != b_size) { + return errors::InvalidArgument( + "Batch dimensions of arguments to TriangularSolve must be equal: ", + xla::ShapeUtil::HumanString(a_shape), " vs ", + xla::ShapeUtil::HumanString(b_shape)); + } + batch_dimensions.push_back(a_size); } - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), a_row, body_out, - /*transpose_x=*/transpose_a, - /*transpose_y=*/false, - /*conjugate_x=*/conjugate_a, - /*conjugate_y=*/false)); - TF_ASSIGN_OR_RETURN( - auto result_row_slice, - DynamicSliceInMinorDims(bodyb.get(), body_b, {i, zero}, {1, n})); - auto result_row = xla::Sub(result_row_slice, b_update); - - // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] - TF_ASSIGN_OR_RETURN(auto a_elt, DynamicSliceInMinorDims(bodyb.get(), body_a, - {i, i}, {1, 1})); - TF_ASSIGN_OR_RETURN(auto a_elt_conj, - MaybeConjugate(bodyb.get(), a_elt, conjugate_a)); - auto div_result = xla::Div(result_row, a_elt_conj); - TF_ASSIGN_OR_RETURN(body_out, - DynamicUpdateSliceInMinorDims(bodyb.get(), body_out, - div_result, {i, zero})); - - // if transpose_a: - // return (i - 1, body_out, a, b) - // else: - // return (i + 1, body_out, a, b) - auto next_i = - xla::Add(i, xla::ConstantR0<int32>(bodyb.get(), transpose_a ? -1 : 1)); - xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b}); - } - TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); - - // Construct the While loop and return the result, - // return while_loop(cond_fun, body_fun, init)[1] - auto triangular_solve_left_looking_while = xla::While(cond, body, init); - return xla::GetTupleElement(triangular_solve_left_looking_while, 1); -} -xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder, - const xla::XlaOp& a, - const xla::XlaOp& b, - bool transpose_a, - bool conjugate_a) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); - const int64 ndims = xla::ShapeUtil::Rank(a_shape); - - std::vector<int64> batch_dimensions; - for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape.dimensions(i); - batch_dimensions.push_back(a_size); - } - - // The main computation is performed in a While loop. - xla::XlaOp output = Zeros(builder, b_shape); - - // Construct the initial loop carry tuple, - // if transpose_a: - // init = (0, output, a, b) - // else: - // init = (n-1, output, a, b) - std::vector<xla::Shape> tuple_shapes = { - // The loop iteration counter is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(xla::S32, {}), - // The output has the shape of b, with one row updated each iteration. - b_shape, - // The coefficient matrix a is a loop invariant. - a_shape, - // The right-hand-side matrix b is a loop invariant. - b_shape}; - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); - auto init_i = xla::ConstantR0<int32>(builder, transpose_a ? 0 : n - 1); - auto init = xla::Tuple(builder, {init_i, output, a, b}); - - // Construct the loop condition function, - // def cond_fun(loop_carry): - // i, output, a, b = loop_carry - // return i < n if transpose_a else i >= 0 - std::unique_ptr<xla::XlaBuilder> condb = - builder->CreateSubBuilder("TriangularSolveRightLookingWhileCond"); - { - auto i = xla::GetTupleElement( - xla::Parameter(condb.get(), 0, tuple_shape, - "TriangularSolveRightLookingWhileTuple"), - 0); - if (transpose_a) { - xla::Lt(i, xla::ConstantR0<int32>(condb.get(), n)); - } else { - xla::Ge(i, xla::ConstantR0<int32>(condb.get(), 0)); + if (xla::ShapeUtil::GetDimension(a_shape, -1) != + xla::ShapeUtil::GetDimension(a_shape, -2)) { + return errors::InvalidArgument( + "The 'a' arguments to TriangularSolve must be square matrices: ", + xla::ShapeUtil::HumanString(a_shape)); } - } - TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); - - // Construct the loop body function, - // def body_fun(loop_carry): - // i, output, a, b = loop_carry - // if transpose_a: - // a_row = np.swapaxes(a[..., :, i:i+1], -1 -2) - // else: - // a_row = a[..., :, i:i+1] - // result_row = b[..., :, i:i+1] - np.matmul(output, a_row) - // output[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] - // if transpose_a: - // return (i - 1, output, a, b) - // else: - // return (i + 1, output, a, b) - // We have to do some extra FLOPs propagating zeros in the matrix multiply - // because we can't have the size of its arguments depend on the loop counter. - std::unique_ptr<xla::XlaBuilder> bodyb = - builder->CreateSubBuilder("TriangularSolveRightLookingWhileBody"); - { - auto input_tuple = xla::Parameter(bodyb.get(), 0, tuple_shape, - "TriangularSolveRightLookingWhileTuple"); - - // i, output, a, b = loop_carry - auto i = xla::GetTupleElement(input_tuple, 0); - auto body_out = xla::GetTupleElement(input_tuple, 1); - auto body_a = xla::GetTupleElement(input_tuple, 2); - auto body_b = xla::GetTupleElement(input_tuple, 3); - auto zero = xla::ConstantR0<int32>(bodyb.get(), 0); - - // We'd like to implement b[..., :, i:i+1] - np.matmul(output, a[..., :, - // i:i+1]) But since we can't have intermediate array sizes depend on the - // loop counter, we instead exploit the fact that we initialized the output - // to all zeros and use that as zero-padding (doing unnecessary FLOPs). - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), body_out, body_a, - /*transpose_x=*/false, - /*transpose_y=*/transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/conjugate_a)); - // result = b - np.matmul(output, a) - auto result = xla::Sub(body_b, b_update); - // result_row = result[..., :, i:i+1] - TF_ASSIGN_OR_RETURN( - auto result_row, - DynamicSliceInMinorDims(bodyb.get(), result, {zero, i}, {m, 1})); - - // body_out[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] - TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(bodyb.get(), body_a, - {i, i}, {1, 1})); - TF_ASSIGN_OR_RETURN(auto a_ii_conj, - MaybeConjugate(bodyb.get(), a_ii, conjugate_a)); - auto div_result = xla::Div(result_row, a_ii_conj); - TF_ASSIGN_OR_RETURN(body_out, - DynamicUpdateSliceInMinorDims(bodyb.get(), body_out, - div_result, {zero, i})); - - // if transpose_a: - // return (i + 1, body_out, a, b) - // else: - // return (i - 1, body_out, a, b) - auto next_i = - xla::Add(i, xla::ConstantR0<int32>(bodyb.get(), transpose_a ? 1 : -1)); - xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b}); - } - TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); - - // Construct the While loop and return the result, - // return while_loop(cond_fun, body_fun, init)[1] - auto triangular_solve_left_looking_while = xla::While(cond, body, init); - return xla::GetTupleElement(triangular_solve_left_looking_while, 1); + const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); + if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) { + return errors::InvalidArgument( + "Arguments to TriangularSolve have incompatible matrix shapes: ", + xla::ShapeUtil::HumanString(a_shape), " vs ", + xla::ShapeUtil::HumanString(b_shape)); + } + + if (block_size < 1) { + return errors::InvalidArgument( + "block_size argument to TriangularSolve must be >= 1; got ", + block_size); + } + + // We find the diagonal blocks of the coefficient matrix + auto diag_blocks = DiagonalBlocks(a, block_size); + + // We invert these blocks in parallel using batched matrix-vector products + auto inv_diag_blocks = + InvertDiagonalBlocks(diag_blocks, lower, transpose_a, conjugate_a); + + // We now find the solution using GEMMs + auto x = SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, + lower, transpose_a, conjugate_a); + + return x; + }); } } // namespace tensorflow |