From 45bafe9a3589fc735c22c3c703f8689ea9c1e71e Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 30 Apr 2018 17:41:33 -0700
Subject: [XLA] Redesign: migrate tensorflow/compiler/tf2xla,
 tensorflow/compiler/aot:

- xla::ComputationBuilder -> xla::XlaBuilder
- xla::ComputationDataHandle -> xla::XlaOp
- xla::Computation -> xla::XlaComputation
- xla::CompileOnlyClient::AotComputationInstance ->
  xla::CompileOnlyClient::AotXlaComputationInstance
- xla::SessionModule -> xla::HloSnapshot

PiperOrigin-RevId: 194874462
---
 tensorflow/compiler/tf2xla/lib/BUILD               |  26 ++--
 tensorflow/compiler/tf2xla/lib/batch_dot.cc        |  50 ++++----
 tensorflow/compiler/tf2xla/lib/batch_dot.h         |  12 +-
 tensorflow/compiler/tf2xla/lib/cholesky.cc         |  50 ++++----
 tensorflow/compiler/tf2xla/lib/cholesky.h          |   9 +-
 tensorflow/compiler/tf2xla/lib/scatter.cc          |  58 ++++-----
 tensorflow/compiler/tf2xla/lib/scatter.h           |  18 ++-
 tensorflow/compiler/tf2xla/lib/triangular_solve.cc | 131 ++++++++++-----------
 tensorflow/compiler/tf2xla/lib/triangular_solve.h  |  21 ++--
 .../compiler/tf2xla/lib/triangular_solve_test.cc   |  50 ++++----
 tensorflow/compiler/tf2xla/lib/util.cc             |  92 +++++++--------
 tensorflow/compiler/tf2xla/lib/util.h              |  67 ++++++-----
 tensorflow/compiler/tf2xla/lib/util_test.cc        |  17 ++-
 tensorflow/compiler/tf2xla/lib/while_loop.cc       |  52 ++++----
 tensorflow/compiler/tf2xla/lib/while_loop.h        |  29 +++--
 15 files changed, 332 insertions(+), 350 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index 12fdfb605d..04ad3694a0 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -25,8 +25,8 @@ cc_library(
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla/client:computation",
-        "//tensorflow/compiler/xla/client:computation_builder",
+        "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+        "//tensorflow/compiler/xla/client/xla_client:xla_computation",
         "//tensorflow/core:lib",
     ],
 )
@@ -44,8 +44,8 @@ cc_library(
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla/client:computation",
-        "//tensorflow/compiler/xla/client:computation_builder",
+        "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+        "//tensorflow/compiler/xla/client/xla_client:xla_computation",
         "//tensorflow/core:lib",
     ],
 )
@@ -62,9 +62,9 @@ cc_library(
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:util",
-        "//tensorflow/compiler/xla/client:computation",
-        "//tensorflow/compiler/xla/client:computation_builder",
         "//tensorflow/compiler/xla/client/lib:arithmetic",
+        "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+        "//tensorflow/compiler/xla/client/xla_client:xla_computation",
         "//tensorflow/core:lib",
     ],
 )
@@ -82,8 +82,8 @@ cc_library(
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
-        "//tensorflow/compiler/xla/client:computation",
-        "//tensorflow/compiler/xla/client:computation_builder",
+        "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+        "//tensorflow/compiler/xla/client/xla_client:xla_computation",
         "//tensorflow/core:lib",
     ],
 )
@@ -101,9 +101,9 @@ xla_test(
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -122,8 +122,8 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -161,8 +161,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index 798f0fa780..526694d5a0 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -25,24 +25,22 @@ limitations under the License. namespace tensorflow { -xla::StatusOr BatchDot( - xla::ComputationBuilder* builder, xla::ComputationDataHandle x, - xla::ComputationDataHandle y, bool transpose_x, bool transpose_y, - bool conjugate_x, bool conjugate_y) { - TF_ASSIGN_OR_RETURN(std::unique_ptr x_shape, - builder->GetShape(x)); - TF_ASSIGN_OR_RETURN(std::unique_ptr y_shape, - builder->GetShape(y)); +xla::StatusOr BatchDot(xla::XlaBuilder* builder, xla::XlaOp x, + xla::XlaOp y, bool transpose_x, + bool transpose_y, bool conjugate_x, + bool conjugate_y) { + TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y)); // Check that both tensors have the same number of dimensions. There must be // at least two (the batch dimensions can be empty). - if (xla::ShapeUtil::Rank(*x_shape) != xla::ShapeUtil::Rank(*y_shape)) { + if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) { return errors::InvalidArgument( "Arguments to BatchedDot have different ranks: ", - xla::ShapeUtil::HumanString(*x_shape), " vs. ", - xla::ShapeUtil::HumanString(*y_shape)); + xla::ShapeUtil::HumanString(x_shape), " vs. ", + xla::ShapeUtil::HumanString(y_shape)); } - const int ndims = xla::ShapeUtil::Rank(*x_shape); + const int ndims = xla::ShapeUtil::Rank(x_shape); if (ndims < 2) { return errors::InvalidArgument( "Arguments to BatchedDot must have rank >= 2: ", ndims); @@ -52,46 +50,46 @@ xla::StatusOr BatchDot( // valid. std::vector batch_dimension_numbers; for (int i = 0; i < ndims - 2; ++i) { - if (x_shape->dimensions(i) != y_shape->dimensions(i)) { + if (x_shape.dimensions(i) != y_shape.dimensions(i)) { return errors::InvalidArgument( "Dimension ", i, " of inputs to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(*x_shape), " vs ", - xla::ShapeUtil::HumanString(*y_shape)); + xla::ShapeUtil::HumanString(x_shape), " vs ", + xla::ShapeUtil::HumanString(y_shape)); } batch_dimension_numbers.push_back(i); } int x_inner_dim = transpose_x ? 
   int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1);
   int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2);
-  if (x_shape->dimensions(x_inner_dim) != y_shape->dimensions(y_inner_dim)) {
+  if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) {
     return errors::InvalidArgument(
         "Dimensions ", x_inner_dim, " and ", y_inner_dim,
         " of arguments to BatchedDot must be equal: ",
-        xla::ShapeUtil::HumanString(*x_shape), " transpose: ", transpose_x,
-        " vs. ", xla::ShapeUtil::HumanString(*y_shape),
+        xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x,
+        " vs. ", xla::ShapeUtil::HumanString(y_shape),
         " transpose: ", transpose_y);
   }
 
   // Check for zero lhs/rhs dim size.
-  if (xla::ShapeUtil::HasZeroElements(*x_shape) ||
-      xla::ShapeUtil::HasZeroElements(*y_shape)) {
+  if (xla::ShapeUtil::HasZeroElements(x_shape) ||
+      xla::ShapeUtil::HasZeroElements(y_shape)) {
     std::vector<int64> dimensions(batch_dimension_numbers.size());
     for (int i = 0; i < batch_dimension_numbers.size(); ++i) {
-      dimensions[i] = x_shape->dimensions(batch_dimension_numbers[i]);
+      dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]);
     }
     int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2);
     int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1);
-    dimensions.push_back(x_shape->dimensions(x_outer_dim));
-    dimensions.push_back(y_shape->dimensions(y_outer_dim));
+    dimensions.push_back(x_shape.dimensions(x_outer_dim));
+    dimensions.push_back(y_shape.dimensions(y_outer_dim));
     return builder->Broadcast(
-        builder->ConstantLiteral(xla::Literal::Zero(x_shape->element_type())),
+        builder->ConstantLiteral(xla::Literal::Zero(x_shape.element_type())),
         dimensions);
   }
 
-  if (x_shape->element_type() == xla::C64 && conjugate_x) {
+  if (x_shape.element_type() == xla::C64 && conjugate_x) {
     x = builder->Conj(x);
   }
-  if (y_shape->element_type() == xla::C64 && conjugate_y) {
+  if (y_shape.element_type() == xla::C64 && conjugate_y) {
     y = builder->Conj(y);
   }
 
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h
index b230e885f1..1acc72033b 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.h
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h
@@ -16,8 +16,8 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_
 #define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_
 
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
 
 namespace tensorflow {
 
@@ -43,10 +43,10 @@ namespace tensorflow {
 // It is computed as:
 //
 //   output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
-xla::StatusOr<xla::ComputationDataHandle> BatchDot(
-    xla::ComputationBuilder* builder, xla::ComputationDataHandle x,
-    xla::ComputationDataHandle y, bool transpose_x, bool transpose_y,
-    bool conjugate_x = false, bool conjugate_y = false);
+xla::StatusOr<xla::XlaOp> BatchDot(xla::XlaBuilder* builder, xla::XlaOp x,
+                                   xla::XlaOp y, bool transpose_x,
+                                   bool transpose_y, bool conjugate_x = false,
+                                   bool conjugate_y = false);
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc
index 203365e2ab..83e7382786 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.cc
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc
@@ -47,23 +47,21 @@ namespace {
 //     l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) /
 //                       l[..., j, j]
 // return l
-xla::StatusOr<xla::ComputationDataHandle> CholeskyUnblocked(
-    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a) {
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
-                      builder->GetShape(a));
-  const int n_dims = xla::ShapeUtil::Rank(*a_shape);
-  const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1);
-  gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(a_shape->dimensions()),
+xla::StatusOr<xla::XlaOp> CholeskyUnblocked(xla::XlaBuilder* builder,
+                                            const xla::XlaOp& a) {
+  TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+  const int n_dims = xla::ShapeUtil::Rank(a_shape);
+  const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
+  gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(a_shape.dimensions()),
                                     /*pos=*/0,
                                     /*len=*/n_dims - 2);
-  xla::ComputationDataHandle l = Zeros(builder, *a_shape);
+  xla::XlaOp l = Zeros(builder, a_shape);
 
   // Construct the for loop body to iterate over rows.
-  auto body_fn = [&](xla::ComputationDataHandle i,
-                     gtl::ArraySlice<xla::ComputationDataHandle> loop_vars,
-                     xla::ComputationBuilder* body_builder)
-      -> xla::StatusOr<std::vector<xla::ComputationDataHandle>> {
+  auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
+                     xla::XlaBuilder* body_builder)
+      -> xla::StatusOr<std::vector<xla::XlaOp>> {
     xla::Shape col_shape;
     xla::Shape row_shape;
     for (int64 d : major_dims) {
@@ -72,12 +70,12 @@ xla::StatusOr<xla::ComputationDataHandle> CholeskyUnblocked(
     }
     row_shape.add_dimensions(1);
     row_shape.add_dimensions(n);
-    row_shape.set_element_type(a_shape->element_type());
+    row_shape.set_element_type(a_shape.element_type());
     auto mask_zeros_row = Zeros(body_builder, row_shape);
 
     col_shape.add_dimensions(n);
     col_shape.add_dimensions(1);
-    col_shape.set_element_type(a_shape->element_type());
+    col_shape.set_element_type(a_shape.element_type());
     auto mask_zeros_col = Zeros(body_builder, col_shape);
 
     std::vector<int32> mask_vector(n);
@@ -101,7 +99,7 @@ xla::StatusOr<xla::ComputationDataHandle> CholeskyUnblocked(
     TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(body_builder, body_a,
                                                            {i, i}, {1, 1}));
     // np.dot(row, np.swapaxes(row, -1, -2))
-    xla::ComputationDataHandle diag_dot;
+    xla::XlaOp diag_dot;
     TF_ASSIGN_OR_RETURN(diag_dot, BatchDot(body_builder, row, row,
                                            /*transpose_x=*/false,
                                            /*transpose_y=*/true));
@@ -109,7 +107,7 @@ xla::StatusOr<xla::ComputationDataHandle> CholeskyUnblocked(
     // np.swapaxes(row, -1, -2)))
     auto l_ii = body_builder->Pow(
         body_builder->Sub(a_ii, diag_dot),
-        FloatLiteral(body_builder, a_shape->element_type(), 0.5));
+        FloatLiteral(body_builder, a_shape.element_type(), 0.5));
 
     // a[..., i+1:, i]
     auto ip1 = body_builder->Add(i, body_builder->ConstantR0<int32>(1));
@@ -140,7 +138,7 @@ xla::StatusOr<xla::ComputationDataHandle> CholeskyUnblocked(
     TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims(
                                     body_builder, body_l, l_ii, {i, i}));
 
-    return std::vector<xla::ComputationDataHandle>{body_a, body_l};
+    return std::vector<xla::XlaOp>{body_a, body_l};
   };
 
   TF_ASSIGN_OR_RETURN(
@@ -152,22 +150,20 @@
 
 }  // namespace
 
-xla::StatusOr<xla::ComputationDataHandle> Cholesky(
-    xla::ComputationBuilder* builder, xla::ComputationDataHandle a,
-    int64 block_size) {
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
-                      builder->GetShape(a));
-  const int ndims = xla::ShapeUtil::Rank(*a_shape);
+xla::StatusOr<xla::XlaOp> Cholesky(xla::XlaBuilder* builder, xla::XlaOp a,
+                                   int64 block_size) {
+  TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+  const int ndims = xla::ShapeUtil::Rank(a_shape);
   if (ndims < 2) {
     return errors::InvalidArgument(
         "Arguments to Cholesky must have rank >= 2: ", ndims);
   }
 
-  const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1);
-  if (n != xla::ShapeUtil::GetDimension(*a_shape, -2)) {
+  const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
+  if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) {
     return errors::InvalidArgument(
         "Arguments to Cholesky must be square matrices: ",
-        xla::ShapeUtil::HumanString(*a_shape));
+        xla::ShapeUtil::HumanString(a_shape));
   }
 
   if (block_size < 1) {
@@ -179,7 +175,7 @@ xla::StatusOr<xla::ComputationDataHandle> Cholesky(
   // Algorithm 1 from
   // Haidar, Azzam, et al. "High-performance Cholesky factorization for GPU-only
   // execution." Proceedings of General Purpose GPUs. ACM, 2017.
-  xla::ComputationDataHandle l = Zeros(builder, *a_shape);
+  xla::XlaOp l = Zeros(builder, a_shape);
   for (int64 i = 0; i < n; i += block_size) {
     int64 k = std::min(block_size, n - i);
     if (i > 0) {
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h
index 17da8d8b22..20fca7969e 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.h
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.h
@@ -16,8 +16,8 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
 #define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
 
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
 
 namespace tensorflow {
 
@@ -30,9 +30,8 @@ namespace tensorflow {
 // TODO(phawkins): check for negative values on the diagonal and return an
 // error, instead of silently yielding NaNs.
 // TODO(znado): handle the complex Hermitian case
-xla::StatusOr<xla::ComputationDataHandle> Cholesky(
-    xla::ComputationBuilder* builder, xla::ComputationDataHandle a,
-    int64 block_size = 256);
+xla::StatusOr<xla::XlaOp> Cholesky(xla::XlaBuilder* builder, xla::XlaOp a,
+                                   int64 block_size = 256);
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc
index 45699233ea..d5a27abb25 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.cc
+++ b/tensorflow/compiler/tf2xla/lib/scatter.cc
@@ -30,24 +30,19 @@ limitations under the License.
 
 namespace tensorflow {
 
-xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
-    const xla::ComputationDataHandle& buffer,
-    const xla::ComputationDataHandle& updates,
-    const xla::ComputationDataHandle& indices, bool indices_are_vectors,
-    const std::function<xla::ComputationDataHandle(
-        xla::ComputationDataHandle, xla::ComputationDataHandle,
-        xla::ComputationBuilder*)>& combiner,
-    xla::ComputationBuilder* builder) {
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> buffer_shape,
-                      builder->GetShape(buffer));
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> updates_shape,
-                      builder->GetShape(updates));
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> indices_shape,
-                      builder->GetShape(indices));
+xla::StatusOr<xla::XlaOp> XlaScatter(
+    const xla::XlaOp& buffer, const xla::XlaOp& updates,
+    const xla::XlaOp& indices, bool indices_are_vectors,
+    const std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp, xla::XlaBuilder*)>&
+        combiner,
+    xla::XlaBuilder* builder) {
+  TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer));
+  TF_RETURN_IF_ERROR(builder->GetShape(updates).status());
+  TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices));
   gtl::ArraySlice<int64> indices_dims =
-      xla::AsInt64Slice(indices_shape->dimensions());
+      xla::AsInt64Slice(indices_shape.dimensions());
   gtl::ArraySlice<int64> buffer_dims =
-      xla::AsInt64Slice(buffer_shape->dimensions());
+      xla::AsInt64Slice(buffer_shape.dimensions());
 
   // If the indices are N-dimensional, the minor dimension of indices contains
   // the indices to update. Otherwise the indices are all scalars.
@@ -55,12 +50,12 @@ xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
   if (indices_are_vectors) {
     TF_RET_CHECK(!indices_dims.empty());
     num_index_dims = indices_dims.back();
-    if (num_index_dims > xla::ShapeUtil::Rank(*buffer_shape)) {
+    if (num_index_dims > xla::ShapeUtil::Rank(buffer_shape)) {
       return errors::InvalidArgument(
           "The size of the minor dimension of the indices (shape: ",
-          xla::ShapeUtil::HumanString(*indices_shape),
+          xla::ShapeUtil::HumanString(indices_shape),
          ") must be <= the rank of the buffer (shape: ",
-          xla::ShapeUtil::HumanString(*buffer_shape), ")");
+          xla::ShapeUtil::HumanString(buffer_shape), ")");
     }
     indices_dims.pop_back();
   }
 
@@ -78,10 +73,10 @@ xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
 
   // If any of the indexed dimensions are zero in the buffer, the update cannot
   // succeed since it updates a slice of size 1.
   for (int64 i = 0; i < num_index_dims; ++i) {
-    if (xla::ShapeUtil::GetDimension(*buffer_shape, i) == 0) {
-      return errors::InvalidArgument(
-          "Scatter dimension ", i, " is of size zero in tensor with shape ",
-          xla::ShapeUtil::HumanString(*buffer_shape));
+    if (xla::ShapeUtil::GetDimension(buffer_shape, i) == 0) {
+      return errors::InvalidArgument("Scatter dimension ", i,
+                                     " is of size zero in tensor with shape ",
+                                     xla::ShapeUtil::HumanString(buffer_shape));
     }
   }
 
@@ -111,18 +106,17 @@ xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
   //     index = dynamic-slice(indices, i)
   //     update = dynamic-slice(updates, i)
   //     buffer = dynamic-update-slice(buffer, update, index)
-  auto body_fn = [&](xla::ComputationDataHandle i,
-                     gtl::ArraySlice<xla::ComputationDataHandle> loop_vars,
-                     xla::ComputationBuilder* body_builder) {
+  auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
+                     xla::XlaBuilder* body_builder) {
     auto indices = loop_vars[0];
     auto updates = loop_vars[1];
    auto buffer = loop_vars[2];
 
     auto zero_index = body_builder->ConstantLiteral(
-        xla::Literal::Zero(indices_shape->element_type()));
+        xla::Literal::Zero(indices_shape.element_type()));
 
     // Slice the i-th index from the indices array.
-    xla::ComputationDataHandle index;
+    xla::XlaOp index;
     auto indices_offset = body_builder->Reshape(i, {1});
     if (indices_are_vectors) {
       indices_offset = body_builder->Pad(indices_offset, zero_index,
@@ -180,12 +174,12 @@ xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
     // Apply the update.
     buffer = body_builder->DynamicUpdateSlice(buffer, update, index);
 
-    return std::vector<xla::ComputationDataHandle>{indices, updates, buffer};
+    return std::vector<xla::XlaOp>{indices, updates, buffer};
   };
 
-  TF_ASSIGN_OR_RETURN(
-      auto outputs, XlaForEachIndex(num_indices, indices_shape->element_type(),
-                                    body_fn, init, "scatter", builder));
+  TF_ASSIGN_OR_RETURN(auto outputs,
+                      XlaForEachIndex(num_indices, indices_shape.element_type(),
+                                      body_fn, init, "scatter", builder));
   return outputs[2];
 }
 
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h
index 41e6d3b195..87309e10ed 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.h
+++ b/tensorflow/compiler/tf2xla/lib/scatter.h
@@ -18,8 +18,8 @@ limitations under the License.
 
 #include <functional>
 
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
 #include "tensorflow/compiler/xla/statusor.h"
 
 namespace tensorflow {
 
@@ -39,14 +39,12 @@ namespace tensorflow {
 // If a `combiner` is provided, updates are combined with the existing values in
 // the buffer using the combiner function. Otherwise, the updates replace the
 // existing values. The order of updates is implementation-defined.
-xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
-    const xla::ComputationDataHandle& buffer,
-    const xla::ComputationDataHandle& updates,
-    const xla::ComputationDataHandle& indices, bool indices_are_vectors,
-    const std::function<xla::ComputationDataHandle(
-        xla::ComputationDataHandle, xla::ComputationDataHandle,
-        xla::ComputationBuilder*)>& combiner,
-    xla::ComputationBuilder* builder);
+xla::StatusOr<xla::XlaOp> XlaScatter(
+    const xla::XlaOp& buffer, const xla::XlaOp& updates,
+    const xla::XlaOp& indices, bool indices_are_vectors,
+    const std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp, xla::XlaBuilder*)>&
+        combiner,
+    xla::XlaBuilder* builder);
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
index 9bf5821b54..d0279d4412 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
@@ -29,21 +29,20 @@ limitations under the License.
 
 namespace tensorflow {
 
-xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
-    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
-    xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a,
-    bool conjugate_a, int64 block_size) {
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
-                      builder->GetShape(a));
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> b_shape,
-                      builder->GetShape(b));
-  if (xla::ShapeUtil::Rank(*a_shape) != xla::ShapeUtil::Rank(*b_shape)) {
+xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
+                                          const xla::XlaOp& a, xla::XlaOp b,
+                                          bool left_side, bool lower,
+                                          bool transpose_a, bool conjugate_a,
+                                          int64 block_size) {
+  TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+  TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
+  if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) {
     return errors::InvalidArgument(
         "Arguments to TriangularSolve have different ranks: ",
-        xla::ShapeUtil::HumanString(*a_shape), " vs. ",
-        xla::ShapeUtil::HumanString(*b_shape));
+        xla::ShapeUtil::HumanString(a_shape), " vs. ",
+        xla::ShapeUtil::HumanString(b_shape));
   }
-  const int ndims = xla::ShapeUtil::Rank(*a_shape);
+  const int ndims = xla::ShapeUtil::Rank(a_shape);
   if (ndims < 2) {
     return errors::InvalidArgument(
         "Arguments to TriangularSolve must have rank >= 2: ", ndims);
@@ -51,30 +50,30 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
   // The batch dimensions must be equal.
   std::vector<int64> batch_dimensions;
   for (int i = 0; i < ndims - 2; ++i) {
-    int64 a_size = a_shape->dimensions(i);
-    int64 b_size = b_shape->dimensions(i);
+    int64 a_size = a_shape.dimensions(i);
+    int64 b_size = b_shape.dimensions(i);
     if (a_size != b_size) {
       return errors::InvalidArgument(
           "Batch dimensions of arguments to TriangularSolve must be equal: ",
-          xla::ShapeUtil::HumanString(*a_shape), " vs ",
-          xla::ShapeUtil::HumanString(*b_shape));
+          xla::ShapeUtil::HumanString(a_shape), " vs ",
+          xla::ShapeUtil::HumanString(b_shape));
     }
     batch_dimensions.push_back(a_size);
   }
 
-  if (xla::ShapeUtil::GetDimension(*a_shape, -1) !=
-      xla::ShapeUtil::GetDimension(*a_shape, -2)) {
+  if (xla::ShapeUtil::GetDimension(a_shape, -1) !=
+      xla::ShapeUtil::GetDimension(a_shape, -2)) {
    return errors::InvalidArgument(
         "The 'a' arguments to TriangularSolve must be square matrices: ",
-        xla::ShapeUtil::HumanString(*a_shape));
+        xla::ShapeUtil::HumanString(a_shape));
   }
-  const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2);
-  const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1);
-  if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(*a_shape, -1)) {
+  const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2);
+  const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1);
+  if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) {
     return errors::InvalidArgument(
         "Arguments to TriangularSolve have incompatible matrix shapes: ",
-        xla::ShapeUtil::HumanString(*a_shape), " vs ",
-        xla::ShapeUtil::HumanString(*b_shape));
+        xla::ShapeUtil::HumanString(a_shape), " vs ",
+        xla::ShapeUtil::HumanString(b_shape));
   }
 
   if (block_size < 1) {
@@ -85,24 +84,23 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
 
   // Applies a complex conjugation operation if `a` is complex and `conjugate_a`
   // is true, otherwise returns its argument.
-  auto maybe_conj = [&](xla::ComputationBuilder* builder,
-                        xla::ComputationDataHandle x) {
-    auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a;
+  auto maybe_conj = [&](xla::XlaBuilder* builder, xla::XlaOp x) {
+    auto perform_conj = a_shape.element_type() == xla::C64 && conjugate_a;
     return perform_conj ? builder->Conj(x) : x;
   };
 
-  std::map<int, xla::Computation> base_computations;
+  std::map<int, xla::XlaComputation> base_computations;
   auto get_base_triangular_solve =
-      [&](int k) -> xla::StatusOr<xla::Computation*> {
-    xla::Computation& computation = base_computations[k];
+      [&](int k) -> xla::StatusOr<xla::XlaComputation*> {
+    xla::XlaComputation& computation = base_computations[k];
     if (computation.IsNull()) {
-      std::unique_ptr<xla::ComputationBuilder> sub = builder->CreateSubBuilder(
+      std::unique_ptr<xla::XlaBuilder> sub = builder->CreateSubBuilder(
           tensorflow::strings::StrCat("trsm_base_", k));
 
       auto a_param = sub->Parameter(
          0,
           xla::ShapeUtil::MakeShape(
-              b_shape->element_type(),
+              b_shape.element_type(),
               PrependMajorDims(sub.get(), batch_dimensions, {k, k})),
           "a");
 
@@ -115,7 +113,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
       auto b_param = sub->Parameter(
           1,
           xla::ShapeUtil::MakeShape(
-              b_shape->element_type(),
+              b_shape.element_type(),
               PrependMajorDims(sub.get(), batch_dimensions, b_lastd)),
           "b");
 
@@ -142,7 +140,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
     return &computation;
   };
 
-  xla::ComputationDataHandle output = Zeros(builder, *b_shape);
+  xla::XlaOp output = Zeros(builder, b_shape);
 
   // Right-looking blocked triangular solve.
   // For an explanation of the algorithm, see the TRSM discussion in:
@@ -165,9 +163,9 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
                           SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
       TF_ASSIGN_OR_RETURN(auto b_slice,
                           SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
-      xla::ComputationDataHandle update;
+      xla::XlaOp update;
       if (k > 1) {
-        TF_ASSIGN_OR_RETURN(xla::Computation * solve,
+        TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
                             get_base_triangular_solve(k));
         update = builder->Call(*solve, {a_slice, b_slice});
       } else {
@@ -181,7 +179,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
       //   a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
       //   b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2)
       if (i + k < n) {
-        xla::ComputationDataHandle a_slice_2;
+        xla::XlaOp a_slice_2;
         if (lower) {
           TF_ASSIGN_OR_RETURN(
               a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {n, i + k}));
@@ -215,9 +213,9 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
                           SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
       TF_ASSIGN_OR_RETURN(auto b_slice,
                           SliceInMinorDims(builder, b, {i, 0}, {i + k, n}));
-      xla::ComputationDataHandle update;
+      xla::XlaOp update;
       if (k > 1) {
-        TF_ASSIGN_OR_RETURN(xla::Computation * solve,
+        TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
                             get_base_triangular_solve(k));
         update = builder->Call(*solve, {a_slice, b_slice});
       } else {
@@ -231,7 +229,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
       //   a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
      //   b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :])
       if (i + k < m) {
-        xla::ComputationDataHandle a_slice_2;
+        xla::XlaOp a_slice_2;
         if (lower) {
           TF_ASSIGN_OR_RETURN(
               a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {m, i + k}));
@@ -264,9 +262,9 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
                           SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
       TF_ASSIGN_OR_RETURN(auto b_slice,
                           SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
-      xla::ComputationDataHandle update;
+      xla::XlaOp update;
       if (k > 1) {
-        TF_ASSIGN_OR_RETURN(xla::Computation * solve,
+        TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
                             get_base_triangular_solve(k));
         update = builder->Call(*solve, {a_slice, b_slice});
       } else {
@@ -280,7 +278,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
       //   a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
       //   b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2)
       if (i - k >= 0) {
-        xla::ComputationDataHandle a_slice_2;
+        xla::XlaOp a_slice_2;
         if (lower) {
           TF_ASSIGN_OR_RETURN(a_slice_2,
                               SliceInMinorDims(builder, a, {i, 0}, {i + k, i}));
@@ -314,9 +312,9 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
                           SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
       TF_ASSIGN_OR_RETURN(auto b_slice,
                           SliceInMinorDims(builder, b, {i, 0}, {i + k, n}));
-      xla::ComputationDataHandle update;
+      xla::XlaOp update;
       if (k > 1) {
-        TF_ASSIGN_OR_RETURN(xla::Computation * solve,
+        TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
                             get_base_triangular_solve(k));
         update = builder->Call(*solve, {a_slice, b_slice});
       } else {
@@ -330,7 +328,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
       //   a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
       //   b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :])
       if (i - k >= 0) {
-        xla::ComputationDataHandle a_slice_2;
+        xla::XlaOp a_slice_2;
         if (lower) {
           TF_ASSIGN_OR_RETURN(a_slice_2,
                               SliceInMinorDims(builder, a, {i, 0}, {i + k, i}));
@@ -356,26 +354,25 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
   return output;
 }
 
-xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
-    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
-    const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a) {
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
-                      builder->GetShape(a));
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> b_shape,
-                      builder->GetShape(b));
-  const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2);
-  const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1);
-  const int64 ndims = xla::ShapeUtil::Rank(*a_shape);
+xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder,
+                                                     const xla::XlaOp& a,
+                                                     const xla::XlaOp& b,
+                                                     bool transpose_a,
+                                                     bool conjugate_a) {
+  TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+  TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
+  const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2);
+  const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1);
+  const int64 ndims = xla::ShapeUtil::Rank(a_shape);
   std::vector<int64> batch_dimensions;
   for (int i = 0; i < ndims - 2; ++i) {
-    int64 a_size = a_shape->dimensions(i);
+    int64 a_size = a_shape.dimensions(i);
     batch_dimensions.push_back(a_size);
   }
 
-  auto maybe_conj = [&](xla::ComputationBuilder* builder,
-                        xla::ComputationDataHandle x) {
-    auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a;
+  auto maybe_conj = [&](xla::XlaBuilder* builder, xla::XlaOp x) {
+    auto perform_conj = a_shape.element_type() == xla::C64 && conjugate_a;
     return perform_conj ? builder->Conj(x) : x;
   };
 
@@ -387,7 +384,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
   //   output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:]
   // else:
   //   output[..., :1, :] = b[..., :1, :] / a[..., :1, :1]
-  xla::ComputationDataHandle output = Zeros(builder, *b_shape);
+  xla::XlaOp output = Zeros(builder, b_shape);
   {
     auto i = transpose_a ? m - 1 : 0;
     TF_ASSIGN_OR_RETURN(auto a_slice,
@@ -408,11 +405,11 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
       // The loop iteration counter is a scalar, incremented each iteration.
       xla::ShapeUtil::MakeShape(xla::S32, {}),
       // The output has the shape of b, with one row updated each iteration.
-      *b_shape,
+      b_shape,
       // The coefficient matrix a is a loop invariant.
-      *a_shape,
+      a_shape,
       // The right-hand-side matrix b is a loop invariant.
-      *b_shape};
+      b_shape};
   xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
   auto init_i = builder->ConstantR0<int32>(transpose_a ? m - 2 : 1);
   auto init = builder->Tuple({init_i, output, a, b});
 
   // def cond_fun(loop_carry):
   //   i, output, a, b = loop_carry
   //   return i >= 0 if transpose_a else i < m
-  std::unique_ptr<xla::ComputationBuilder> condb =
+  std::unique_ptr<xla::XlaBuilder> condb =
       builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond");
   {
     auto i = condb->GetTupleElement(
         condb->Parameter(0, tuple_shape,
@@ -451,7 +448,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
   //   return (i + 1, output, a, b)
   // We have to do some extra FLOPs propagating zeros in the matrix multiply
   // because we can't have the size of its arguments depend on the loop counter.
-  std::unique_ptr<xla::ComputationBuilder> bodyb =
+  std::unique_ptr<xla::XlaBuilder> bodyb =
       builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody");
   {
     auto input_tuple = bodyb->Parameter(0, tuple_shape,
@@ -475,7 +472,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
     // But since we can't have intermediate array sizes depend on the loop
     // counter, we instead exploit the fact that we initialized the output to
     // all zeros and use that as zero-padding (doing unnecessary FLOPs).
-    xla::ComputationDataHandle a_row;
+    xla::XlaOp a_row;
     if (transpose_a) {
       TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a,
                                                          {zero, i}, {m, 1}));
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
index e32223bfdd..fd8f2489d1 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
@@ -16,8 +16,8 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_
 #define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_
 
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
 
 namespace tensorflow {
 
@@ -57,14 +57,17 @@ namespace tensorflow {
 //
 // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no
 // blocking is used.
-xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
-    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
-    xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a,
-    bool conjugate_a, int64 block_size = 256);
+xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
+                                          const xla::XlaOp& a, xla::XlaOp b,
+                                          bool left_side, bool lower,
+                                          bool transpose_a, bool conjugate_a,
+                                          int64 block_size = 256);
 
-xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
-    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
-    const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a);
+xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder,
+                                                     const xla::XlaOp& a,
+                                                     const xla::XlaOp& b,
+                                                     bool transpose_a,
+                                                     bool conjugate_a);
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
index 6617070629..87ea4763f7 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
@@ -20,7 +20,7 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/test.h"
@@ -80,9 +80,9 @@ xla::Array2D<float> AValsFull() {
 }
 
 XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) {
-  xla::ComputationBuilder builder(client_, TestName());
+  xla::XlaBuilder builder(TestName());
 
-  xla::ComputationDataHandle a, b;
+  xla::XlaOp a, b;
   auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
   auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
   auto result = TriangularSolve(&builder, a, b,
@@ -102,9 +102,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) {
 }
 
 XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) {
-  xla::ComputationBuilder builder(client_, TestName());
+  xla::XlaBuilder builder(TestName());
 
-  xla::ComputationDataHandle a, b;
+  xla::XlaOp a, b;
   auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
   auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
   auto result = TriangularSolve(&builder, a, b,
@@ -124,9 +124,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) {
 }
 
 XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) {
-  xla::ComputationBuilder builder(client_, TestName());
+  xla::XlaBuilder builder(TestName());
 
-  xla::ComputationDataHandle a, b;
+  xla::XlaOp a, b;
   auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
   auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
   auto result = TriangularSolve(&builder, a, b,
@@ -146,9 +146,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) {
 }
 
 XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) {
-  xla::ComputationBuilder builder(client_, TestName());
+  xla::XlaBuilder builder(TestName());
 
-  xla::ComputationDataHandle a, b;
+  xla::XlaOp a, b;
   auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
   auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
   auto result = TriangularSolve(&builder, a, b,
@@ -168,9 +168,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) {
 }
 
 XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) {
-  xla::ComputationBuilder builder(client_, TestName());
+  xla::XlaBuilder builder(TestName());
 
-  xla::ComputationDataHandle a, b;
+  xla::XlaOp a, b;
   auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
   auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
   auto result = TriangularSolve(&builder, a, b,
@@ -191,9 +191,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) {
 }
 
 XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) {
-  xla::ComputationBuilder builder(client_, TestName());
+  xla::XlaBuilder builder(TestName());
 
-  xla::ComputationDataHandle a, b;
+  xla::XlaOp a, b;
   auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
   auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
   auto result = TriangularSolve(&builder, a, b,
@@ -214,9 +214,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) {
 }
 
 XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) {
-  xla::ComputationBuilder builder(client_, TestName());
+  xla::XlaBuilder builder(TestName());
 
-  xla::ComputationDataHandle a, b;
+  xla::XlaOp a, b;
   auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
   auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
   auto result = TriangularSolve(&builder, a, b,
@@ -237,9 +237,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) {
 }
 
 XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) {
-  xla::ComputationBuilder builder(client_, TestName());
+  xla::XlaBuilder builder(TestName());
 
-  xla::ComputationDataHandle a, b;
+  xla::XlaOp a, b;
   auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
   auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
   auto result = TriangularSolve(&builder, a, b,
@@ -260,9 +260,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) {
 }
 
 XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) {
-  xla::ComputationBuilder builder(client_, TestName());
+  xla::XlaBuilder builder(TestName());
 
-  xla::ComputationDataHandle a, b;
+  xla::XlaOp a, b;
   auto a_data =
       CreateR2Parameter<xla::complex64>(AValsLowerComplex(), 0, "a", &builder, &a);
   auto b_data =
@@ -288,9 +288,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) {
 }
 
 XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) {
-  xla::ComputationBuilder builder(client_, TestName());
+  xla::XlaBuilder builder(TestName());
 
-  xla::ComputationDataHandle a, b;
+  xla::XlaOp a, b;
   auto a_data =
       CreateR2Parameter<xla::complex64>(AValsUpperComplex(), 0, "a", &builder, &a);
   auto b_data =
@@ -318,9 +318,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) {
 }
 
 XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) {
-  xla::ComputationBuilder builder(client_, TestName());
+  xla::XlaBuilder builder(TestName());
 
-  xla::ComputationDataHandle a, b;
+  xla::XlaOp a, b;
   auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
   auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
   auto result = TriangularSolveLeftLooking(&builder, a, b,
@@ -340,9 +340,9 @@ XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) {
 }
 
 XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) {
-  xla::ComputationBuilder builder(client_, TestName());
+  xla::XlaBuilder builder(TestName());
 
-  xla::ComputationDataHandle a, b;
+  xla::XlaOp a, b;
   auto a_data = CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &a);
   auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
   auto result = TriangularSolveLeftLooking(&builder, a, b,
diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc
index 31d823ca33..cc7b13571c 100644
--- a/tensorflow/compiler/tf2xla/lib/util.cc
+++ b/tensorflow/compiler/tf2xla/lib/util.cc
@@ -27,15 +27,14 @@ limitations under the License.
 
 namespace tensorflow {
 
-xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder,
-                                 const xla::Shape& shape) {
+xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape) {
   return builder->Broadcast(
       builder->ConstantLiteral(xla::Literal::Zero(shape.element_type())),
       xla::AsInt64Slice(shape.dimensions()));
 }
 
-xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder,
-                                        xla::PrimitiveType type, double value) {
+xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
+                        double value) {
   switch (type) {
     case xla::F16:
       return builder->ConstantR0<xla::half>(static_cast<xla::half>(value));
@@ -57,9 +56,8 @@ xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder,
   }
 }
 
-xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder,
-                                          xla::PrimitiveType type,
-                                          int64 value) {
+xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
+                          int64 value) {
   xla::Literal literal;
   switch (type) {
     case xla::U8:
@@ -112,17 +110,18 @@ xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder,
   return builder->ConstantLiteral(literal);
 }
 
-xla::StatusOr<xla::ComputationDataHandle> SliceInMinorDims(
-    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
-    gtl::ArraySlice<int64> start, gtl::ArraySlice<int64> end) {
+xla::StatusOr<xla::XlaOp> SliceInMinorDims(xla::XlaBuilder* builder,
+                                           const xla::XlaOp& x,
+                                           gtl::ArraySlice<int64> start,
+                                           gtl::ArraySlice<int64> end) {
   TF_RET_CHECK(start.size() == end.size());
   int64 n_minor_dims = start.size();
 
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
+  TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
 
-  const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+  const int64 n_dims = xla::ShapeUtil::Rank(shape);
   TF_RET_CHECK(n_minor_dims <= n_dims);
-  gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape->dimensions()),
+  gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()),
                                     /*pos=*/0,
                                     /*len=*/n_dims - n_minor_dims);
@@ -140,7 +139,7 @@ xla::StatusOr<xla::ComputationDataHandle> SliceInMinorDims(
   return builder->Slice(x, padded_start, padded_end, strides);
 }
 
-std::vector<int64> PrependMajorDims(xla::ComputationBuilder* builder,
+std::vector<int64> PrependMajorDims(xla::XlaBuilder* builder,
                                     const gtl::ArraySlice<int64>& major_dims,
                                     const gtl::ArraySlice<int64>& indices) {
   std::vector<int64> output(indices.size() + major_dims.size());
@@ -149,16 +148,16 @@ std::vector<int64> PrependMajorDims(xla::ComputationBuilder* builder,
   return output;
 }
 
-xla::StatusOr<xla::ComputationDataHandle> DynamicSliceInMinorDims(
-    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
-    const std::vector<xla::ComputationDataHandle>& starts,
+xla::StatusOr<xla::XlaOp> DynamicSliceInMinorDims(
+    xla::XlaBuilder* builder, const xla::XlaOp& x,
+    const std::vector<xla::XlaOp>& starts,
     const gtl::ArraySlice<int64>& sizes) {
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
-  const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+  TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+  const int64 n_dims = xla::ShapeUtil::Rank(shape);
   int64 n_minor_dims = starts.size();
   TF_RET_CHECK(n_minor_dims == sizes.size());
   TF_RET_CHECK(n_minor_dims <= n_dims);
-  gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape->dimensions()),
+  gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()),
                                     /*pos=*/0,
                                     /*len=*/n_dims - sizes.size());
   TF_ASSIGN_OR_RETURN(auto padded_starts,
@@ -167,27 +166,29 @@ xla::StatusOr<xla::ComputationDataHandle> DynamicSliceInMinorDims(
   return builder->DynamicSlice(x, padded_starts, padded_sizes);
 }
 
-xla::StatusOr<xla::ComputationDataHandle> UpdateSlice(
-    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
-    const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start) {
+xla::StatusOr<xla::XlaOp> UpdateSlice(xla::XlaBuilder* builder,
+                                      const xla::XlaOp& x,
+                                      const xla::XlaOp& update,
+                                      gtl::ArraySlice<int64> start) {
   // TODO(phawkins): make int64 work on all backends, remove the int32 cast.
   std::vector<int32> start_as_int32(start.begin(), start.end());
   auto start_constant = builder->ConstantR1<int32>(start_as_int32);
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
-  const int64 n_dims = xla::ShapeUtil::Rank(*shape);
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> start_constant_shape,
+  TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+  const int64 n_dims = xla::ShapeUtil::Rank(shape);
+  TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape,
                       builder->GetShape(start_constant));
   const int64 start_length =
-      xla::ShapeUtil::GetDimension(*start_constant_shape, -1);
+      xla::ShapeUtil::GetDimension(start_constant_shape, -1);
   TF_RET_CHECK(start_length == n_dims);
   return builder->DynamicUpdateSlice(x, update, start_constant);
 }
 
-xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
-    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
-    const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start) {
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
-  const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+xla::StatusOr<xla::XlaOp> UpdateSliceInMinorDims(xla::XlaBuilder* builder,
+                                                 const xla::XlaOp& x,
+                                                 const xla::XlaOp& update,
+                                                 gtl::ArraySlice<int64> start) {
+  TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+  const int64 n_dims = xla::ShapeUtil::Rank(shape);
   const int64 n_minor_dims = start.size();
   TF_RET_CHECK(n_minor_dims <= n_dims);
   std::vector<int64> padded_start(n_dims, 0);
@@ -196,22 +197,21 @@ xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
   return UpdateSlice(builder, x, update, padded_start);
 }
 
-xla::StatusOr<xla::ComputationDataHandle> DynamicUpdateSliceInMinorDims(
-    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
-    const xla::ComputationDataHandle& update,
-    const std::vector<xla::ComputationDataHandle>& starts) {
+xla::StatusOr<xla::XlaOp> DynamicUpdateSliceInMinorDims(
+    xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update,
+    const std::vector<xla::XlaOp>& starts) {
   TF_ASSIGN_OR_RETURN(auto padded_starts,
                       PrependZerosInMajorDims(builder, x, starts));
   return builder->DynamicUpdateSlice(x, update, padded_starts);
 }
 
-xla::StatusOr<xla::ComputationDataHandle> PrependZerosInMajorDims(
-    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
-    const std::vector<xla::ComputationDataHandle>& starts) {
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
-  const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+xla::StatusOr<xla::XlaOp> PrependZerosInMajorDims(
+    xla::XlaBuilder* builder, const xla::XlaOp& x,
+    const std::vector<xla::XlaOp>& starts) {
+  TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+  const int64 n_dims = xla::ShapeUtil::Rank(shape);
   auto zero = builder->Reshape(builder->ConstantR0<int32>(0), {1});
-  std::vector<xla::ComputationDataHandle> padded_starts(n_dims, zero);
+  std::vector<xla::XlaOp> padded_starts(n_dims, zero);
   for (int i = 0; i < starts.size(); ++i) {
     padded_starts[n_dims - starts.size() + i] =
         builder->Reshape(starts[i], {1});
@@ -219,10 +219,10 @@ xla::StatusOr<xla::ComputationDataHandle> PrependZerosInMajorDims(
   return builder->ConcatInDim(padded_starts, 0);
 }
 
-xla::StatusOr<xla::ComputationDataHandle> TransposeInMinorDims(
-    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x) {
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
-  const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+xla::StatusOr<xla::XlaOp> TransposeInMinorDims(xla::XlaBuilder* builder,
+                                               const xla::XlaOp& x) {
+  TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+  const int64 n_dims = xla::ShapeUtil::Rank(shape);
   TF_RET_CHECK(n_dims >= 2);
   std::vector<int64> permutation(n_dims);
   std::iota(permutation.begin(), permutation.end(), 0);
diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h
index b684123f13..3df44ef035 100644
--- a/tensorflow/compiler/tf2xla/lib/util.h
+++ b/tensorflow/compiler/tf2xla/lib/util.h
@@ -16,75 +16,74 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_
 #define TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_
 
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
 
 namespace tensorflow {
 
 // Returns a zero-filled tensor with shape `shape`.
-xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder,
-                                 const xla::Shape& shape);
+xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape);
 
 // Returns a floating point scalar constant of 'type' with 'value'.
 // If 'type' is complex, returns a real value with zero imaginary component.
-xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder,
-                                        xla::PrimitiveType type, double value);
+xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
+                        double value);
 
 // Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros
 // prepended until the array is length n_dims.
-xla::ComputationDataHandle PrependZerosInMajorDims(
-    xla::ComputationBuilder* builder,
-    gtl::ArraySlice<xla::ComputationDataHandle> starts);
+xla::XlaOp PrependZerosInMajorDims(xla::XlaBuilder* builder,
+                                   gtl::ArraySlice<xla::XlaOp> starts);
 
 // Returns a integer scalar constant of 'type' with 'value'.
 // If 'type' is complex, returns a real value with zero imaginary component.
-xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder,
-                                          xla::PrimitiveType type, int64 value);
+xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
+                          int64 value);
 
 // Builds a vector of zeros of length rank(x) with the last two values being
 // those in `starts`.
-xla::StatusOr<xla::ComputationDataHandle> PrependZerosInMajorDims(
-    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
-    const std::vector<xla::ComputationDataHandle>& starts);
+xla::StatusOr<xla::XlaOp> PrependZerosInMajorDims(
+    xla::XlaBuilder* builder, const xla::XlaOp& x,
+    const std::vector<xla::XlaOp>& starts);
 
 // Performs a slice in the minor dimensions of a Tensor.
-xla::StatusOr<xla::ComputationDataHandle> SliceInMinorDims(
-    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
-    gtl::ArraySlice<int64> start, gtl::ArraySlice<int64> end);
+xla::StatusOr<xla::XlaOp> SliceInMinorDims(xla::XlaBuilder* builder,
+                                           const xla::XlaOp& x,
+                                           gtl::ArraySlice<int64> start,
+                                           gtl::ArraySlice<int64> end);
 
 // Builds a 1-d vector out of a concatenation of `major_dims` and `starts`.
-std::vector<int64> PrependMajorDims(xla::ComputationBuilder* builder,
+std::vector<int64> PrependMajorDims(xla::XlaBuilder* builder,
                                     const gtl::ArraySlice<int64>& major_dims,
                                     const gtl::ArraySlice<int64>& indices);
 
 // Performs a dynamic slice in the minor dimensions of a Tensor.
-xla::StatusOr<xla::ComputationDataHandle> DynamicSliceInMinorDims(
-    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
-    const std::vector<xla::ComputationDataHandle>& starts,
-    const gtl::ArraySlice<int64>& sizes);
+xla::StatusOr<xla::XlaOp> DynamicSliceInMinorDims(
+    xla::XlaBuilder* builder, const xla::XlaOp& x,
+    const std::vector<xla::XlaOp>& starts, const gtl::ArraySlice<int64>& sizes);
 
 // Updates a slice of 'x', i.e.,
 // x[start[0], ..., start[n]] = update
-xla::StatusOr<xla::ComputationDataHandle> UpdateSlice(
-    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
-    const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start);
+xla::StatusOr<xla::XlaOp> UpdateSlice(xla::XlaBuilder* builder,
+                                      const xla::XlaOp& x,
+                                      const xla::XlaOp& update,
+                                      gtl::ArraySlice<int64> start);
 
 // Updates a slice of 'x', where 'start' contains a list of minor dimensions:
 // x[..., start[0], ..., start[n]] = update
-xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
-    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
-    const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start);
+xla::StatusOr<xla::XlaOp> UpdateSliceInMinorDims(xla::XlaBuilder* builder,
+                                                 const xla::XlaOp& x,
+                                                 const xla::XlaOp& update,
+                                                 gtl::ArraySlice<int64> start);
 
-xla::StatusOr<xla::ComputationDataHandle> DynamicUpdateSliceInMinorDims(
-    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
-    const xla::ComputationDataHandle& update,
-    const std::vector<xla::ComputationDataHandle>& starts);
+xla::StatusOr<xla::XlaOp> DynamicUpdateSliceInMinorDims(
+    xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update,
+    const std::vector<xla::XlaOp>& starts);
 
 // Transposes a stack of matrices `x` by swapping the last two dimensions.
-xla::StatusOr<xla::ComputationDataHandle> TransposeInMinorDims(
-    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x);
+xla::StatusOr<xla::XlaOp> TransposeInMinorDims(xla::XlaBuilder* builder,
+                                               const xla::XlaOp& x);
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc
index b6bd33af2e..265b39402c 100644
--- a/tensorflow/compiler/tf2xla/lib/util_test.cc
+++ b/tensorflow/compiler/tf2xla/lib/util_test.cc
@@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -65,9 +64,9 @@ xla::Array3D BatchedAValsFull() { } XLA_TEST_F(UtilTest, Simple2dLookup) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, x, y; + xla::XlaOp a, x, y; auto a_data = CreateR2Parameter(BValsRight(), 0, "a", &builder, &a); auto x_data = CreateR0Parameter(2, 1, "x", &builder, &x); auto y_data = CreateR0Parameter(1, 2, "y", &builder, &y); @@ -80,9 +79,9 @@ XLA_TEST_F(UtilTest, Simple2dLookup) { } XLA_TEST_F(UtilTest, Simple3dLookup) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, index; + xla::XlaOp a, index; auto a_data = CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); auto index_data = CreateR0Parameter(1, 1, "index", &builder, &index); @@ -97,9 +96,9 @@ XLA_TEST_F(UtilTest, Simple3dLookup) { } XLA_TEST_F(UtilTest, SimpleSliceUpdate) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b, x, y; + xla::XlaOp a, b, x, y; auto a_data = CreateR2Parameter(AValsFull(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter({{9, 1, -10}}, 1, "b", &builder, &b); auto x_data = CreateR0Parameter(2, 2, "x", &builder, &x); @@ -117,11 +116,11 @@ XLA_TEST_F(UtilTest, SimpleSliceUpdate) { } XLA_TEST_F(UtilTest, RowBatchDot) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); int n = 4; - xla::ComputationDataHandle a, row, index; + xla::XlaOp a, row, index; auto a_data = CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); auto row_data = CreateR3Parameter({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc index 495d9c6078..09ce594930 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -20,24 +20,24 @@ limitations under the License. namespace tensorflow { -xla::StatusOr> XlaWhileLoop( +xla::StatusOr> XlaWhileLoop( const LoopConditionFunction& condition_function, const LoopBodyFunction& body_function, - gtl::ArraySlice initial_values, - StringPiece name, xla::ComputationBuilder* builder) { + gtl::ArraySlice initial_values, StringPiece name, + xla::XlaBuilder* builder) { int arity = initial_values.size(); std::vector var_shapes; var_shapes.reserve(arity); - for (const xla::ComputationDataHandle& input : initial_values) { + for (const xla::XlaOp& input : initial_values) { TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input)); - var_shapes.push_back(std::move(*shape)); + var_shapes.push_back(std::move(shape)); } xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes); // Unpacks a tuple into its component parts. - auto unpack_tuple = [](xla::ComputationDataHandle tuple, int arity, - xla::ComputationBuilder* builder) { - std::vector elements(arity); + auto unpack_tuple = [](xla::XlaOp tuple, int arity, + xla::XlaBuilder* builder) { + std::vector elements(arity); for (int i = 0; i < arity; ++i) { elements[i] = builder->GetTupleElement(tuple, i); } @@ -45,20 +45,20 @@ xla::StatusOr> XlaWhileLoop( }; // Build the condition. 
-  std::unique_ptr<xla::ComputationBuilder> cond_builder =
+  std::unique_ptr<xla::XlaBuilder> cond_builder =
       builder->CreateSubBuilder(strings::StrCat(name, "_condition"));
   {
     auto parameter = cond_builder->Parameter(0, tuple_shape, "parameter");
-    TF_ASSIGN_OR_RETURN(
-        auto result,
+    TF_RETURN_IF_ERROR(
         condition_function(unpack_tuple(parameter, arity, cond_builder.get()),
-                           cond_builder.get()));
+                           cond_builder.get())
+            .status());
   }
   TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build());
 
   // Build the body.
-  std::unique_ptr<xla::ComputationBuilder> body_builder =
+  std::unique_ptr<xla::XlaBuilder> body_builder =
      builder->CreateSubBuilder(strings::StrCat(name, "_body"));
   {
     auto parameter = body_builder->Parameter(0, tuple_shape, "parameter");
@@ -78,38 +78,38 @@ xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaWhileLoop(
   return unpack_tuple(outputs, arity, builder);
 }
 
-xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaForEachIndex(
+xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
     int64 num_iterations, xla::PrimitiveType num_iterations_type,
     const ForEachIndexBodyFunction& body_function,
-    gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
-    StringPiece name, xla::ComputationBuilder* builder) {
-  auto while_cond_fn = [&](gtl::ArraySlice<xla::ComputationDataHandle> values,
-                           xla::ComputationBuilder* cond_builder)
-      -> xla::StatusOr<xla::ComputationDataHandle> {
+    gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+    xla::XlaBuilder* builder) {
+  auto while_cond_fn =
+      [&](gtl::ArraySlice<xla::XlaOp> values,
+          xla::XlaBuilder* cond_builder) -> xla::StatusOr<xla::XlaOp> {
     return cond_builder->Lt(
        values[0],
         IntegerLiteral(cond_builder, num_iterations_type, num_iterations));
   };
-  auto while_body_fn = [&](gtl::ArraySlice<xla::ComputationDataHandle> values,
-                           xla::ComputationBuilder* body_builder)
-      -> xla::StatusOr<std::vector<xla::ComputationDataHandle>> {
-    xla::ComputationDataHandle iteration = values[0];
+  auto while_body_fn = [&](gtl::ArraySlice<xla::XlaOp> values,
+                           xla::XlaBuilder* body_builder)
+      -> xla::StatusOr<std::vector<xla::XlaOp>> {
+    xla::XlaOp iteration = values[0];
 
-    std::vector<xla::ComputationDataHandle> updated_values;
+    std::vector<xla::XlaOp> updated_values;
     updated_values.reserve(values.size());
     updated_values.push_back(body_builder->Add(
         iteration,
        body_builder->ConstantLiteral(xla::Literal::One(num_iterations_type))));
 
     values.remove_prefix(1);
-    TF_ASSIGN_OR_RETURN(std::vector<xla::ComputationDataHandle> body_outputs,
+    TF_ASSIGN_OR_RETURN(std::vector<xla::XlaOp> body_outputs,
                         body_function(iteration, values, body_builder));
     updated_values.insert(updated_values.end(), body_outputs.begin(),
                           body_outputs.end());
     return updated_values;
   };
 
-  std::vector<xla::ComputationDataHandle> values;
+  std::vector<xla::XlaOp> values;
   values.reserve(initial_values.size() + 1);
   values.push_back(
       builder->ConstantLiteral(xla::Literal::Zero(num_iterations_type)));
diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/tf2xla/lib/while_loop.h
index 2e67a0c99b..5b6684c995 100644
--- a/tensorflow/compiler/tf2xla/lib/while_loop.h
+++ b/tensorflow/compiler/tf2xla/lib/while_loop.h
@@ -19,8 +19,8 @@ limitations under the License.
 
 #include <functional>
 #include <vector>
 
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
@@ -29,14 +29,14 @@ namespace tensorflow {
 
 // Function that builds a loop condition. Takes as input a sequence of input
 // values, and returns a boolean value representing if the condition succeeds.
-typedef std::function<xla::StatusOr<xla::ComputationDataHandle>(
-    gtl::ArraySlice<xla::ComputationDataHandle>, xla::ComputationBuilder*)>
+typedef std::function<xla::StatusOr<xla::XlaOp>(gtl::ArraySlice<xla::XlaOp>,
+                                                xla::XlaBuilder*)>
     LoopConditionFunction;
 
 // Function that builds a loop body. Takes as input a sequence of input values
 // and returns a sequence of output values.
-typedef std::function<xla::StatusOr<std::vector<xla::ComputationDataHandle>>(
-    gtl::ArraySlice<xla::ComputationDataHandle>, xla::ComputationBuilder*)>
+typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
+    gtl::ArraySlice<xla::XlaOp>, xla::XlaBuilder*)>
     LoopBodyFunction;
 
 // Helper function for building an XLA while loop, where the values carried by
@@ -47,27 +47,26 @@ typedef std::function<xla::StatusOr<std::vector<xla::ComputationDataHandle>>(
 //   init: (a, b, c)
 // )
 // 'name' is a descriptive name for the loop.
-xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaWhileLoop(
+xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
     const LoopConditionFunction& condition_function,
     const LoopBodyFunction& body_function,
-    gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
-    StringPiece name, xla::ComputationBuilder* builder);
+    gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+    xla::XlaBuilder* builder);
 
 // Builds an XLA loop that repeats a computation `num_iterations` times.
 //
 // The body function (ForEachIndexBodyFunction) takes as input a pair of
 // (current iteration number, loop-carried values), and returns an updated
 // vector of the loop-carried values.
-typedef std::function<xla::StatusOr<std::vector<xla::ComputationDataHandle>>(
-    xla::ComputationDataHandle, gtl::ArraySlice<xla::ComputationDataHandle>,
-    xla::ComputationBuilder*)>
+typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
+    xla::XlaOp, gtl::ArraySlice<xla::XlaOp>, xla::XlaBuilder*)>
    ForEachIndexBodyFunction;
 
-xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaForEachIndex(
+xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
     int64 num_iterations, xla::PrimitiveType num_iterations_type,
     const ForEachIndexBodyFunction& body_function,
-    gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
-    StringPiece name, xla::ComputationBuilder* builder);
+    gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+    xla::XlaBuilder* builder);
 
 }  // namespace tensorflow
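Editor's note: the following sketch is not part of the commit. It illustrates what the migration in this patch looks like at a typical call site, using only APIs that appear in the diff above (XlaBuilder's name-only constructor, Parameter, Add, GetShape returning xla::Shape by value, and Build producing an xla::XlaComputation). The function name BuildExample and the parameter shape are invented for the example.

// A minimal sketch of a migrated call site, assuming the headers introduced
// by this patch are available on the include path.
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"

// Before this change, the equivalent code would have constructed an
// xla::ComputationBuilder from a client plus a name, held ops as
// xla::ComputationDataHandle values, received GetShape results as
// std::unique_ptr<xla::Shape>, and built an xla::Computation.
xla::StatusOr<xla::XlaComputation> BuildExample() {
  // XlaBuilder is constructed from a name alone; no client is required.
  xla::XlaBuilder builder("example");
  // Ops are now represented by the lightweight value type xla::XlaOp.
  xla::XlaOp x =
      builder.Parameter(0, xla::ShapeUtil::MakeShape(xla::F32, {2}), "x");
  xla::XlaOp y = builder.Add(x, x);
  // GetShape now yields an xla::Shape by value rather than a
  // std::unique_ptr<xla::Shape>, so no dereference is needed.
  TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder.GetShape(y));
  (void)y_shape;  // Shape checks would go here.
  // Build() produces an xla::XlaComputation instead of an xla::Computation.
  return builder.Build();
}

This by-value GetShape is why the patch replaces every `*a_shape` dereference with a plain `a_shape` member access, and why the `std::map<int, xla::Computation>` cache in triangular_solve.cc becomes a `std::map<int, xla::XlaComputation>` holding the new computation type.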