aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/lib
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-30 17:41:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-30 17:43:59 -0700
commit45bafe9a3589fc735c22c3c703f8689ea9c1e71e (patch)
treee39723521a1ca68e9c2c74d1a9d3ac5ef2e8abc4 /tensorflow/compiler/tf2xla/lib
parentc89a1d9605427d74079774af7da37933f9ca153c (diff)
[XLA] Redesign: migrate tensorflow/compiler/tf2xla, tensorflow/compiler/aot:
- xla::ComputationBuilder -> xla::XlaBuilder - xla::ComputationDataHandle -> xla::XlaOp - xla::Computation -> xla::XlaComputation - xla::CompileOnlyClient::AotComputationInstance -> xla::CompileOnlyClient::AotXlaComputationInstance - xla::SessionModule -> xla::HloSnapshot PiperOrigin-RevId: 194874462
Diffstat (limited to 'tensorflow/compiler/tf2xla/lib')
-rw-r--r--tensorflow/compiler/tf2xla/lib/BUILD26
-rw-r--r--tensorflow/compiler/tf2xla/lib/batch_dot.cc50
-rw-r--r--tensorflow/compiler/tf2xla/lib/batch_dot.h12
-rw-r--r--tensorflow/compiler/tf2xla/lib/cholesky.cc50
-rw-r--r--tensorflow/compiler/tf2xla/lib/cholesky.h9
-rw-r--r--tensorflow/compiler/tf2xla/lib/scatter.cc58
-rw-r--r--tensorflow/compiler/tf2xla/lib/scatter.h18
-rw-r--r--tensorflow/compiler/tf2xla/lib/triangular_solve.cc131
-rw-r--r--tensorflow/compiler/tf2xla/lib/triangular_solve.h21
-rw-r--r--tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc50
-rw-r--r--tensorflow/compiler/tf2xla/lib/util.cc92
-rw-r--r--tensorflow/compiler/tf2xla/lib/util.h67
-rw-r--r--tensorflow/compiler/tf2xla/lib/util_test.cc17
-rw-r--r--tensorflow/compiler/tf2xla/lib/while_loop.cc52
-rw-r--r--tensorflow/compiler/tf2xla/lib/while_loop.h29
15 files changed, 332 insertions, 350 deletions
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index 12fdfb605d..04ad3694a0 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -25,8 +25,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
],
)
@@ -44,8 +44,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
],
)
@@ -62,9 +62,9 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
],
)
@@ -82,8 +82,8 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
],
)
@@ -101,9 +101,9 @@ xla_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -122,8 +122,8 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
],
)
@@ -161,8 +161,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
index 798f0fa780..526694d5a0 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
@@ -25,24 +25,22 @@ limitations under the License.
namespace tensorflow {
-xla::StatusOr<xla::ComputationDataHandle> BatchDot(
- xla::ComputationBuilder* builder, xla::ComputationDataHandle x,
- xla::ComputationDataHandle y, bool transpose_x, bool transpose_y,
- bool conjugate_x, bool conjugate_y) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> x_shape,
- builder->GetShape(x));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> y_shape,
- builder->GetShape(y));
+xla::StatusOr<xla::XlaOp> BatchDot(xla::XlaBuilder* builder, xla::XlaOp x,
+ xla::XlaOp y, bool transpose_x,
+ bool transpose_y, bool conjugate_x,
+ bool conjugate_y) {
+ TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x));
+ TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y));
// Check that both tensors have the same number of dimensions. There must be
// at least two (the batch dimensions can be empty).
- if (xla::ShapeUtil::Rank(*x_shape) != xla::ShapeUtil::Rank(*y_shape)) {
+ if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) {
return errors::InvalidArgument(
"Arguments to BatchedDot have different ranks: ",
- xla::ShapeUtil::HumanString(*x_shape), " vs. ",
- xla::ShapeUtil::HumanString(*y_shape));
+ xla::ShapeUtil::HumanString(x_shape), " vs. ",
+ xla::ShapeUtil::HumanString(y_shape));
}
- const int ndims = xla::ShapeUtil::Rank(*x_shape);
+ const int ndims = xla::ShapeUtil::Rank(x_shape);
if (ndims < 2) {
return errors::InvalidArgument(
"Arguments to BatchedDot must have rank >= 2: ", ndims);
@@ -52,46 +50,46 @@ xla::StatusOr<xla::ComputationDataHandle> BatchDot(
// valid.
std::vector<int64> batch_dimension_numbers;
for (int i = 0; i < ndims - 2; ++i) {
- if (x_shape->dimensions(i) != y_shape->dimensions(i)) {
+ if (x_shape.dimensions(i) != y_shape.dimensions(i)) {
return errors::InvalidArgument(
"Dimension ", i, " of inputs to BatchedDot must be equal: ",
- xla::ShapeUtil::HumanString(*x_shape), " vs ",
- xla::ShapeUtil::HumanString(*y_shape));
+ xla::ShapeUtil::HumanString(x_shape), " vs ",
+ xla::ShapeUtil::HumanString(y_shape));
}
batch_dimension_numbers.push_back(i);
}
int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1);
int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2);
- if (x_shape->dimensions(x_inner_dim) != y_shape->dimensions(y_inner_dim)) {
+ if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) {
return errors::InvalidArgument(
"Dimensions ", x_inner_dim, " and ", y_inner_dim,
" of arguments to BatchedDot must be equal: ",
- xla::ShapeUtil::HumanString(*x_shape), " transpose: ", transpose_x,
- " vs. ", xla::ShapeUtil::HumanString(*y_shape),
+ xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x,
+ " vs. ", xla::ShapeUtil::HumanString(y_shape),
" transpose: ", transpose_y);
}
// Check for zero lhs/rhs dim size.
- if (xla::ShapeUtil::HasZeroElements(*x_shape) ||
- xla::ShapeUtil::HasZeroElements(*y_shape)) {
+ if (xla::ShapeUtil::HasZeroElements(x_shape) ||
+ xla::ShapeUtil::HasZeroElements(y_shape)) {
std::vector<int64> dimensions(batch_dimension_numbers.size());
for (int i = 0; i < batch_dimension_numbers.size(); ++i) {
- dimensions[i] = x_shape->dimensions(batch_dimension_numbers[i]);
+ dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]);
}
int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2);
int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1);
- dimensions.push_back(x_shape->dimensions(x_outer_dim));
- dimensions.push_back(y_shape->dimensions(y_outer_dim));
+ dimensions.push_back(x_shape.dimensions(x_outer_dim));
+ dimensions.push_back(y_shape.dimensions(y_outer_dim));
return builder->Broadcast(
- builder->ConstantLiteral(xla::Literal::Zero(x_shape->element_type())),
+ builder->ConstantLiteral(xla::Literal::Zero(x_shape.element_type())),
dimensions);
}
- if (x_shape->element_type() == xla::C64 && conjugate_x) {
+ if (x_shape.element_type() == xla::C64 && conjugate_x) {
x = builder->Conj(x);
}
- if (y_shape->element_type() == xla::C64 && conjugate_y) {
+ if (y_shape.element_type() == xla::C64 && conjugate_y) {
y = builder->Conj(y);
}
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h
index b230e885f1..1acc72033b 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.h
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h
@@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
namespace tensorflow {
@@ -43,10 +43,10 @@ namespace tensorflow {
// It is computed as:
//
// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
-xla::StatusOr<xla::ComputationDataHandle> BatchDot(
- xla::ComputationBuilder* builder, xla::ComputationDataHandle x,
- xla::ComputationDataHandle y, bool transpose_x, bool transpose_y,
- bool conjugate_x = false, bool conjugate_y = false);
+xla::StatusOr<xla::XlaOp> BatchDot(xla::XlaBuilder* builder, xla::XlaOp x,
+ xla::XlaOp y, bool transpose_x,
+ bool transpose_y, bool conjugate_x = false,
+ bool conjugate_y = false);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc
index 203365e2ab..83e7382786 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.cc
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc
@@ -47,23 +47,21 @@ namespace {
// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) /
// l[..., j, j]
// return l
-xla::StatusOr<xla::ComputationDataHandle> CholeskyUnblocked(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
- builder->GetShape(a));
- const int n_dims = xla::ShapeUtil::Rank(*a_shape);
- const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1);
- gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(a_shape->dimensions()),
+xla::StatusOr<xla::XlaOp> CholeskyUnblocked(xla::XlaBuilder* builder,
+ const xla::XlaOp& a) {
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ const int n_dims = xla::ShapeUtil::Rank(a_shape);
+ const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
+ gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(a_shape.dimensions()),
/*pos=*/0,
/*len=*/n_dims - 2);
- xla::ComputationDataHandle l = Zeros(builder, *a_shape);
+ xla::XlaOp l = Zeros(builder, a_shape);
// Construct the for loop body to iterate over rows.
- auto body_fn = [&](xla::ComputationDataHandle i,
- gtl::ArraySlice<xla::ComputationDataHandle> loop_vars,
- xla::ComputationBuilder* body_builder)
- -> xla::StatusOr<std::vector<xla::ComputationDataHandle>> {
+ auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
+ xla::XlaBuilder* body_builder)
+ -> xla::StatusOr<std::vector<xla::XlaOp>> {
xla::Shape col_shape;
xla::Shape row_shape;
for (int64 d : major_dims) {
@@ -72,12 +70,12 @@ xla::StatusOr<xla::ComputationDataHandle> CholeskyUnblocked(
}
row_shape.add_dimensions(1);
row_shape.add_dimensions(n);
- row_shape.set_element_type(a_shape->element_type());
+ row_shape.set_element_type(a_shape.element_type());
auto mask_zeros_row = Zeros(body_builder, row_shape);
col_shape.add_dimensions(n);
col_shape.add_dimensions(1);
- col_shape.set_element_type(a_shape->element_type());
+ col_shape.set_element_type(a_shape.element_type());
auto mask_zeros_col = Zeros(body_builder, col_shape);
std::vector<int32> mask_vector(n);
@@ -101,7 +99,7 @@ xla::StatusOr<xla::ComputationDataHandle> CholeskyUnblocked(
TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(body_builder, body_a,
{i, i}, {1, 1}));
// np.dot(row, np.swapaxes(row, -1, -2))
- xla::ComputationDataHandle diag_dot;
+ xla::XlaOp diag_dot;
TF_ASSIGN_OR_RETURN(diag_dot, BatchDot(body_builder, row, row,
/*transpose_x=*/false,
/*transpose_y=*/true));
@@ -109,7 +107,7 @@ xla::StatusOr<xla::ComputationDataHandle> CholeskyUnblocked(
// np.swapaxes(row, -1, -2)))
auto l_ii = body_builder->Pow(
body_builder->Sub(a_ii, diag_dot),
- FloatLiteral(body_builder, a_shape->element_type(), 0.5));
+ FloatLiteral(body_builder, a_shape.element_type(), 0.5));
// a[..., i+1:, i]
auto ip1 = body_builder->Add(i, body_builder->ConstantR0<int32>(1));
@@ -140,7 +138,7 @@ xla::StatusOr<xla::ComputationDataHandle> CholeskyUnblocked(
TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims(
body_builder, body_l, l_ii, {i, i}));
- return std::vector<xla::ComputationDataHandle>{body_a, body_l};
+ return std::vector<xla::XlaOp>{body_a, body_l};
};
TF_ASSIGN_OR_RETURN(
@@ -152,22 +150,20 @@ xla::StatusOr<xla::ComputationDataHandle> CholeskyUnblocked(
} // namespace
-xla::StatusOr<xla::ComputationDataHandle> Cholesky(
- xla::ComputationBuilder* builder, xla::ComputationDataHandle a,
- int64 block_size) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
- builder->GetShape(a));
- const int ndims = xla::ShapeUtil::Rank(*a_shape);
+xla::StatusOr<xla::XlaOp> Cholesky(xla::XlaBuilder* builder, xla::XlaOp a,
+ int64 block_size) {
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ const int ndims = xla::ShapeUtil::Rank(a_shape);
if (ndims < 2) {
return errors::InvalidArgument(
"Arguments to Cholesky must have rank >= 2: ", ndims);
}
- const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1);
- if (n != xla::ShapeUtil::GetDimension(*a_shape, -2)) {
+ const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
+ if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) {
return errors::InvalidArgument(
"Arguments to Cholesky must be square matrices: ",
- xla::ShapeUtil::HumanString(*a_shape));
+ xla::ShapeUtil::HumanString(a_shape));
}
if (block_size < 1) {
@@ -179,7 +175,7 @@ xla::StatusOr<xla::ComputationDataHandle> Cholesky(
// Algorithm 1 from
// Haidar, Azzam, et al. "High-performance Cholesky factorization for GPU-only
// execution." Proceedings of General Purpose GPUs. ACM, 2017.
- xla::ComputationDataHandle l = Zeros(builder, *a_shape);
+ xla::XlaOp l = Zeros(builder, a_shape);
for (int64 i = 0; i < n; i += block_size) {
int64 k = std::min(block_size, n - i);
if (i > 0) {
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h
index 17da8d8b22..20fca7969e 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.h
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.h
@@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
namespace tensorflow {
@@ -30,9 +30,8 @@ namespace tensorflow {
// TODO(phawkins): check for negative values on the diagonal and return an
// error, instead of silently yielding NaNs.
// TODO(znado): handle the complex Hermitian case
-xla::StatusOr<xla::ComputationDataHandle> Cholesky(
- xla::ComputationBuilder* builder, xla::ComputationDataHandle a,
- int64 block_size = 256);
+xla::StatusOr<xla::XlaOp> Cholesky(xla::XlaBuilder* builder, xla::XlaOp a,
+ int64 block_size = 256);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc
index 45699233ea..d5a27abb25 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.cc
+++ b/tensorflow/compiler/tf2xla/lib/scatter.cc
@@ -30,24 +30,19 @@ limitations under the License.
namespace tensorflow {
-xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
- const xla::ComputationDataHandle& buffer,
- const xla::ComputationDataHandle& updates,
- const xla::ComputationDataHandle& indices, bool indices_are_vectors,
- const std::function<xla::ComputationDataHandle(
- xla::ComputationDataHandle, xla::ComputationDataHandle,
- xla::ComputationBuilder*)>& combiner,
- xla::ComputationBuilder* builder) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> buffer_shape,
- builder->GetShape(buffer));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> updates_shape,
- builder->GetShape(updates));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> indices_shape,
- builder->GetShape(indices));
+xla::StatusOr<xla::XlaOp> XlaScatter(
+ const xla::XlaOp& buffer, const xla::XlaOp& updates,
+ const xla::XlaOp& indices, bool indices_are_vectors,
+ const std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp, xla::XlaBuilder*)>&
+ combiner,
+ xla::XlaBuilder* builder) {
+ TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer));
+ TF_RETURN_IF_ERROR(builder->GetShape(updates).status());
+ TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices));
gtl::ArraySlice<int64> indices_dims =
- xla::AsInt64Slice(indices_shape->dimensions());
+ xla::AsInt64Slice(indices_shape.dimensions());
gtl::ArraySlice<int64> buffer_dims =
- xla::AsInt64Slice(buffer_shape->dimensions());
+ xla::AsInt64Slice(buffer_shape.dimensions());
// If the indices are N-dimensional, the minor dimension of indices contains
// the indices to update. Otherwise the indices are all scalars.
@@ -55,12 +50,12 @@ xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
if (indices_are_vectors) {
TF_RET_CHECK(!indices_dims.empty());
num_index_dims = indices_dims.back();
- if (num_index_dims > xla::ShapeUtil::Rank(*buffer_shape)) {
+ if (num_index_dims > xla::ShapeUtil::Rank(buffer_shape)) {
return errors::InvalidArgument(
"The size of the minor dimension of the indices (shape: ",
- xla::ShapeUtil::HumanString(*indices_shape),
+ xla::ShapeUtil::HumanString(indices_shape),
") must be <= the rank of the buffer (shape: ",
- xla::ShapeUtil::HumanString(*buffer_shape), ")");
+ xla::ShapeUtil::HumanString(buffer_shape), ")");
}
indices_dims.pop_back();
}
@@ -78,10 +73,10 @@ xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
// If any of the indexed dimensions are zero in the buffer, the update cannot
// succeed since it updates a slice of size 1.
for (int64 i = 0; i < num_index_dims; ++i) {
- if (xla::ShapeUtil::GetDimension(*buffer_shape, i) == 0) {
- return errors::InvalidArgument(
- "Scatter dimension ", i, " is of size zero in tensor with shape ",
- xla::ShapeUtil::HumanString(*buffer_shape));
+ if (xla::ShapeUtil::GetDimension(buffer_shape, i) == 0) {
+ return errors::InvalidArgument("Scatter dimension ", i,
+ " is of size zero in tensor with shape ",
+ xla::ShapeUtil::HumanString(buffer_shape));
}
}
@@ -111,18 +106,17 @@ xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
// index = dynamic-slice(indices, i)
// update = dynamic-slice(updates, i)
// buffer = dynamic-update-slice(buffer, update, index)
- auto body_fn = [&](xla::ComputationDataHandle i,
- gtl::ArraySlice<xla::ComputationDataHandle> loop_vars,
- xla::ComputationBuilder* body_builder) {
+ auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
+ xla::XlaBuilder* body_builder) {
auto indices = loop_vars[0];
auto updates = loop_vars[1];
auto buffer = loop_vars[2];
auto zero_index = body_builder->ConstantLiteral(
- xla::Literal::Zero(indices_shape->element_type()));
+ xla::Literal::Zero(indices_shape.element_type()));
// Slice the i-th index from the indices array.
- xla::ComputationDataHandle index;
+ xla::XlaOp index;
auto indices_offset = body_builder->Reshape(i, {1});
if (indices_are_vectors) {
indices_offset = body_builder->Pad(indices_offset, zero_index,
@@ -180,12 +174,12 @@ xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
// Apply the update.
buffer = body_builder->DynamicUpdateSlice(buffer, update, index);
- return std::vector<xla::ComputationDataHandle>{indices, updates, buffer};
+ return std::vector<xla::XlaOp>{indices, updates, buffer};
};
- TF_ASSIGN_OR_RETURN(
- auto outputs, XlaForEachIndex(num_indices, indices_shape->element_type(),
- body_fn, init, "scatter", builder));
+ TF_ASSIGN_OR_RETURN(auto outputs,
+ XlaForEachIndex(num_indices, indices_shape.element_type(),
+ body_fn, init, "scatter", builder));
return outputs[2];
}
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h
index 41e6d3b195..87309e10ed 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.h
+++ b/tensorflow/compiler/tf2xla/lib/scatter.h
@@ -18,8 +18,8 @@ limitations under the License.
#include <functional>
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace tensorflow {
@@ -39,14 +39,12 @@ namespace tensorflow {
// If a `combiner` is provided, updates are combined with the existing values in
// the buffer using the combiner function. Otherwise, the updates replace the
// existing values. The order of updates is implementation-defined.
-xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
- const xla::ComputationDataHandle& buffer,
- const xla::ComputationDataHandle& updates,
- const xla::ComputationDataHandle& indices, bool indices_are_vectors,
- const std::function<xla::ComputationDataHandle(
- xla::ComputationDataHandle, xla::ComputationDataHandle,
- xla::ComputationBuilder*)>& combiner,
- xla::ComputationBuilder* builder);
+xla::StatusOr<xla::XlaOp> XlaScatter(
+ const xla::XlaOp& buffer, const xla::XlaOp& updates,
+ const xla::XlaOp& indices, bool indices_are_vectors,
+ const std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp, xla::XlaBuilder*)>&
+ combiner,
+ xla::XlaBuilder* builder);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
index 9bf5821b54..d0279d4412 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
@@ -29,21 +29,20 @@ limitations under the License.
namespace tensorflow {
-xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
- xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a,
- bool conjugate_a, int64 block_size) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
- builder->GetShape(a));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> b_shape,
- builder->GetShape(b));
- if (xla::ShapeUtil::Rank(*a_shape) != xla::ShapeUtil::Rank(*b_shape)) {
+xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
+ const xla::XlaOp& a, xla::XlaOp b,
+ bool left_side, bool lower,
+ bool transpose_a, bool conjugate_a,
+ int64 block_size) {
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
+ if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) {
return errors::InvalidArgument(
"Arguments to TriangularSolve have different ranks: ",
- xla::ShapeUtil::HumanString(*a_shape), " vs. ",
- xla::ShapeUtil::HumanString(*b_shape));
+ xla::ShapeUtil::HumanString(a_shape), " vs. ",
+ xla::ShapeUtil::HumanString(b_shape));
}
- const int ndims = xla::ShapeUtil::Rank(*a_shape);
+ const int ndims = xla::ShapeUtil::Rank(a_shape);
if (ndims < 2) {
return errors::InvalidArgument(
"Arguments to TriangularSolve must have rank >= 2: ", ndims);
@@ -51,30 +50,30 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
// The batch dimensions must be equal.
std::vector<int64> batch_dimensions;
for (int i = 0; i < ndims - 2; ++i) {
- int64 a_size = a_shape->dimensions(i);
- int64 b_size = b_shape->dimensions(i);
+ int64 a_size = a_shape.dimensions(i);
+ int64 b_size = b_shape.dimensions(i);
if (a_size != b_size) {
return errors::InvalidArgument(
"Batch dimensions of arguments to TriangularSolve must be equal: ",
- xla::ShapeUtil::HumanString(*a_shape), " vs ",
- xla::ShapeUtil::HumanString(*b_shape));
+ xla::ShapeUtil::HumanString(a_shape), " vs ",
+ xla::ShapeUtil::HumanString(b_shape));
}
batch_dimensions.push_back(a_size);
}
- if (xla::ShapeUtil::GetDimension(*a_shape, -1) !=
- xla::ShapeUtil::GetDimension(*a_shape, -2)) {
+ if (xla::ShapeUtil::GetDimension(a_shape, -1) !=
+ xla::ShapeUtil::GetDimension(a_shape, -2)) {
return errors::InvalidArgument(
"The 'a' arguments to TriangularSolve must be square matrices: ",
- xla::ShapeUtil::HumanString(*a_shape));
+ xla::ShapeUtil::HumanString(a_shape));
}
- const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2);
- const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1);
- if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(*a_shape, -1)) {
+ const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2);
+ const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1);
+ if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) {
return errors::InvalidArgument(
"Arguments to TriangularSolve have incompatible matrix shapes: ",
- xla::ShapeUtil::HumanString(*a_shape), " vs ",
- xla::ShapeUtil::HumanString(*b_shape));
+ xla::ShapeUtil::HumanString(a_shape), " vs ",
+ xla::ShapeUtil::HumanString(b_shape));
}
if (block_size < 1) {
@@ -85,24 +84,23 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
// Applies a complex conjugation operation if `a` is complex and `conjugate_a`
// is true, otherwise returns its argument.
- auto maybe_conj = [&](xla::ComputationBuilder* builder,
- xla::ComputationDataHandle x) {
- auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a;
+ auto maybe_conj = [&](xla::XlaBuilder* builder, xla::XlaOp x) {
+ auto perform_conj = a_shape.element_type() == xla::C64 && conjugate_a;
return perform_conj ? builder->Conj(x) : x;
};
- std::map<int, xla::Computation> base_computations;
+ std::map<int, xla::XlaComputation> base_computations;
auto get_base_triangular_solve =
- [&](int k) -> xla::StatusOr<xla::Computation*> {
- xla::Computation& computation = base_computations[k];
+ [&](int k) -> xla::StatusOr<xla::XlaComputation*> {
+ xla::XlaComputation& computation = base_computations[k];
if (computation.IsNull()) {
- std::unique_ptr<xla::ComputationBuilder> sub = builder->CreateSubBuilder(
+ std::unique_ptr<xla::XlaBuilder> sub = builder->CreateSubBuilder(
tensorflow::strings::StrCat("trsm_base_", k));
auto a_param = sub->Parameter(
0,
xla::ShapeUtil::MakeShape(
- b_shape->element_type(),
+ b_shape.element_type(),
PrependMajorDims(sub.get(), batch_dimensions, {k, k})),
"a");
@@ -115,7 +113,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
auto b_param = sub->Parameter(
1,
xla::ShapeUtil::MakeShape(
- b_shape->element_type(),
+ b_shape.element_type(),
PrependMajorDims(sub.get(), batch_dimensions, b_lastd)),
"b");
@@ -142,7 +140,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
return &computation;
};
- xla::ComputationDataHandle output = Zeros(builder, *b_shape);
+ xla::XlaOp output = Zeros(builder, b_shape);
// Right-looking blocked triangular solve.
// For an explanation of the algorithm, see the TRSM discussion in:
@@ -165,9 +163,9 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
TF_ASSIGN_OR_RETURN(auto b_slice,
SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
- xla::ComputationDataHandle update;
+ xla::XlaOp update;
if (k > 1) {
- TF_ASSIGN_OR_RETURN(xla::Computation * solve,
+ TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
get_base_triangular_solve(k));
update = builder->Call(*solve, {a_slice, b_slice});
} else {
@@ -181,7 +179,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
// a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
// b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2)
if (i + k < n) {
- xla::ComputationDataHandle a_slice_2;
+ xla::XlaOp a_slice_2;
if (lower) {
TF_ASSIGN_OR_RETURN(
a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {n, i + k}));
@@ -215,9 +213,9 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
TF_ASSIGN_OR_RETURN(auto b_slice,
SliceInMinorDims(builder, b, {i, 0}, {i + k, n}));
- xla::ComputationDataHandle update;
+ xla::XlaOp update;
if (k > 1) {
- TF_ASSIGN_OR_RETURN(xla::Computation * solve,
+ TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
get_base_triangular_solve(k));
update = builder->Call(*solve, {a_slice, b_slice});
} else {
@@ -231,7 +229,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
// a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
// b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :])
if (i + k < m) {
- xla::ComputationDataHandle a_slice_2;
+ xla::XlaOp a_slice_2;
if (lower) {
TF_ASSIGN_OR_RETURN(
a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {m, i + k}));
@@ -264,9 +262,9 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
TF_ASSIGN_OR_RETURN(auto b_slice,
SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
- xla::ComputationDataHandle update;
+ xla::XlaOp update;
if (k > 1) {
- TF_ASSIGN_OR_RETURN(xla::Computation * solve,
+ TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
get_base_triangular_solve(k));
update = builder->Call(*solve, {a_slice, b_slice});
} else {
@@ -280,7 +278,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
// a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
// b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2)
if (i - k >= 0) {
- xla::ComputationDataHandle a_slice_2;
+ xla::XlaOp a_slice_2;
if (lower) {
TF_ASSIGN_OR_RETURN(a_slice_2,
SliceInMinorDims(builder, a, {i, 0}, {i + k, i}));
@@ -314,9 +312,9 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
TF_ASSIGN_OR_RETURN(auto b_slice,
SliceInMinorDims(builder, b, {i, 0}, {i + k, n}));
- xla::ComputationDataHandle update;
+ xla::XlaOp update;
if (k > 1) {
- TF_ASSIGN_OR_RETURN(xla::Computation * solve,
+ TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
get_base_triangular_solve(k));
update = builder->Call(*solve, {a_slice, b_slice});
} else {
@@ -330,7 +328,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
// a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
// b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :])
if (i - k >= 0) {
- xla::ComputationDataHandle a_slice_2;
+ xla::XlaOp a_slice_2;
if (lower) {
TF_ASSIGN_OR_RETURN(a_slice_2,
SliceInMinorDims(builder, a, {i, 0}, {i + k, i}));
@@ -356,26 +354,25 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
return output;
}
-xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
- const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
- builder->GetShape(a));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> b_shape,
- builder->GetShape(b));
- const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2);
- const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1);
- const int64 ndims = xla::ShapeUtil::Rank(*a_shape);
+xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder,
+ const xla::XlaOp& a,
+ const xla::XlaOp& b,
+ bool transpose_a,
+ bool conjugate_a) {
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
+ const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2);
+ const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1);
+ const int64 ndims = xla::ShapeUtil::Rank(a_shape);
std::vector<int64> batch_dimensions;
for (int i = 0; i < ndims - 2; ++i) {
- int64 a_size = a_shape->dimensions(i);
+ int64 a_size = a_shape.dimensions(i);
batch_dimensions.push_back(a_size);
}
- auto maybe_conj = [&](xla::ComputationBuilder* builder,
- xla::ComputationDataHandle x) {
- auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a;
+ auto maybe_conj = [&](xla::XlaBuilder* builder, xla::XlaOp x) {
+ auto perform_conj = a_shape.element_type() == xla::C64 && conjugate_a;
return perform_conj ? builder->Conj(x) : x;
};
@@ -387,7 +384,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
// output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:]
// else:
// output[..., :1, :] = b[..., :1, :] / a[..., :1, :1]
- xla::ComputationDataHandle output = Zeros(builder, *b_shape);
+ xla::XlaOp output = Zeros(builder, b_shape);
{
auto i = transpose_a ? m - 1 : 0;
TF_ASSIGN_OR_RETURN(auto a_slice,
@@ -408,11 +405,11 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
// The loop iteration counter is a scalar, incremented each iteration.
xla::ShapeUtil::MakeShape(xla::S32, {}),
// The output has the shape of b, with one row updated each iteration.
- *b_shape,
+ b_shape,
// The coefficient matrix a is a loop invariant.
- *a_shape,
+ a_shape,
// The right-hand-side matrix b is a loop invariant.
- *b_shape};
+ b_shape};
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
auto init_i = builder->ConstantR0<int32>(transpose_a ? m - 2 : 1);
auto init = builder->Tuple({init_i, output, a, b});
@@ -421,7 +418,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
// def cond_fun(loop_carry):
// i, output, a, b = loop_carry
// return i >= 0 if transpose_a else i < m
- std::unique_ptr<xla::ComputationBuilder> condb =
+ std::unique_ptr<xla::XlaBuilder> condb =
builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond");
{
auto i = condb->GetTupleElement(
@@ -451,7 +448,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
// return (i + 1, output, a, b)
// We have to do some extra FLOPs propagating zeros in the matrix multiply
// because we can't have the size of its arguments depend on the loop counter.
- std::unique_ptr<xla::ComputationBuilder> bodyb =
+ std::unique_ptr<xla::XlaBuilder> bodyb =
builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody");
{
auto input_tuple = bodyb->Parameter(0, tuple_shape,
@@ -475,7 +472,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
// But since we can't have intermediate array sizes depend on the loop
// counter, we instead exploit the fact that we initialized the output to
// all zeros and use that as zero-padding (doing unnecessary FLOPs).
- xla::ComputationDataHandle a_row;
+ xla::XlaOp a_row;
if (transpose_a) {
TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a,
{zero, i}, {m, 1}));
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
index e32223bfdd..fd8f2489d1 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
@@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
namespace tensorflow {
@@ -57,14 +57,17 @@ namespace tensorflow {
//
// Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no
// blocking is used.
-xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
- xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a,
- bool conjugate_a, int64 block_size = 256);
+xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
+ const xla::XlaOp& a, xla::XlaOp b,
+ bool left_side, bool lower,
+ bool transpose_a, bool conjugate_a,
+ int64 block_size = 256);
-xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
- const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a);
+xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder,
+ const xla::XlaOp& a,
+ const xla::XlaOp& b,
+ bool transpose_a,
+ bool conjugate_a);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
index 6617070629..87ea4763f7 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -80,9 +80,9 @@ xla::Array2D<float> AValsFull() {
}
XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
@@ -102,9 +102,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) {
}
XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
@@ -124,9 +124,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) {
}
XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
@@ -146,9 +146,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) {
}
XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
@@ -168,9 +168,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) {
}
XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
@@ -191,9 +191,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) {
}
XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
@@ -214,9 +214,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) {
}
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
@@ -237,9 +237,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) {
}
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
@@ -260,9 +260,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) {
}
XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data =
CreateR2Parameter<complex64>(AValsLowerComplex(), 0, "a", &builder, &a);
auto b_data =
@@ -288,9 +288,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) {
}
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data =
CreateR2Parameter<complex64>(AValsUpperComplex(), 0, "a", &builder, &a);
auto b_data =
@@ -318,9 +318,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) {
}
XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
auto result = TriangularSolveLeftLooking(&builder, a, b,
@@ -340,9 +340,9 @@ XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) {
}
XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
auto result = TriangularSolveLeftLooking(&builder, a, b,
diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc
index 31d823ca33..cc7b13571c 100644
--- a/tensorflow/compiler/tf2xla/lib/util.cc
+++ b/tensorflow/compiler/tf2xla/lib/util.cc
@@ -27,15 +27,14 @@ limitations under the License.
namespace tensorflow {
-xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder,
- const xla::Shape& shape) {
+xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape) {
return builder->Broadcast(
builder->ConstantLiteral(xla::Literal::Zero(shape.element_type())),
xla::AsInt64Slice(shape.dimensions()));
}
-xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder,
- xla::PrimitiveType type, double value) {
+xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
+ double value) {
switch (type) {
case xla::F16:
return builder->ConstantR0<xla::half>(static_cast<xla::half>(value));
@@ -57,9 +56,8 @@ xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder,
}
}
-xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder,
- xla::PrimitiveType type,
- int64 value) {
+xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
+ int64 value) {
xla::Literal literal;
switch (type) {
case xla::U8:
@@ -112,17 +110,18 @@ xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder,
return builder->ConstantLiteral(literal);
}
-xla::StatusOr<xla::ComputationDataHandle> SliceInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- gtl::ArraySlice<int64> start, gtl::ArraySlice<int64> end) {
+xla::StatusOr<xla::XlaOp> SliceInMinorDims(xla::XlaBuilder* builder,
+ const xla::XlaOp& x,
+ gtl::ArraySlice<int64> start,
+ gtl::ArraySlice<int64> end) {
TF_RET_CHECK(start.size() == end.size());
int64 n_minor_dims = start.size();
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
TF_RET_CHECK(n_minor_dims <= n_dims);
- gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape->dimensions()),
+ gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()),
/*pos=*/0,
/*len=*/n_dims - n_minor_dims);
@@ -140,7 +139,7 @@ xla::StatusOr<xla::ComputationDataHandle> SliceInMinorDims(
return builder->Slice(x, padded_start, padded_end, strides);
}
-std::vector<int64> PrependMajorDims(xla::ComputationBuilder* builder,
+std::vector<int64> PrependMajorDims(xla::XlaBuilder* builder,
const gtl::ArraySlice<int64>& major_dims,
const gtl::ArraySlice<int64>& indices) {
std::vector<int64> output(indices.size() + major_dims.size());
@@ -149,16 +148,16 @@ std::vector<int64> PrependMajorDims(xla::ComputationBuilder* builder,
return output;
}
-xla::StatusOr<xla::ComputationDataHandle> DynamicSliceInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const std::vector<xla::ComputationDataHandle>& starts,
+xla::StatusOr<xla::XlaOp> DynamicSliceInMinorDims(
+ xla::XlaBuilder* builder, const xla::XlaOp& x,
+ const std::vector<xla::XlaOp>& starts,
const gtl::ArraySlice<int64>& sizes) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
int64 n_minor_dims = starts.size();
TF_RET_CHECK(n_minor_dims == sizes.size());
TF_RET_CHECK(n_minor_dims <= n_dims);
- gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape->dimensions()),
+ gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()),
/*pos=*/0,
/*len=*/n_dims - sizes.size());
TF_ASSIGN_OR_RETURN(auto padded_starts,
@@ -167,27 +166,29 @@ xla::StatusOr<xla::ComputationDataHandle> DynamicSliceInMinorDims(
return builder->DynamicSlice(x, padded_starts, padded_sizes);
}
-xla::StatusOr<xla::ComputationDataHandle> UpdateSlice(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start) {
+xla::StatusOr<xla::XlaOp> UpdateSlice(xla::XlaBuilder* builder,
+ const xla::XlaOp& x,
+ const xla::XlaOp& update,
+ gtl::ArraySlice<int64> start) {
// TODO(phawkins): make int64 work on all backends, remove the int32 cast.
std::vector<int32> start_as_int32(start.begin(), start.end());
auto start_constant = builder->ConstantR1<int32>(start_as_int32);
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(*shape);
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> start_constant_shape,
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
+ TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape,
builder->GetShape(start_constant));
const int64 start_length =
- xla::ShapeUtil::GetDimension(*start_constant_shape, -1);
+ xla::ShapeUtil::GetDimension(start_constant_shape, -1);
TF_RET_CHECK(start_length == n_dims);
return builder->DynamicUpdateSlice(x, update, start_constant);
}
-xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+xla::StatusOr<xla::XlaOp> UpdateSliceInMinorDims(xla::XlaBuilder* builder,
+ const xla::XlaOp& x,
+ const xla::XlaOp& update,
+ gtl::ArraySlice<int64> start) {
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
const int64 n_minor_dims = start.size();
TF_RET_CHECK(n_minor_dims <= n_dims);
std::vector<int64> padded_start(n_dims, 0);
@@ -196,22 +197,21 @@ xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
return UpdateSlice(builder, x, update, padded_start);
}
-xla::StatusOr<xla::ComputationDataHandle> DynamicUpdateSliceInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const xla::ComputationDataHandle& update,
- const std::vector<xla::ComputationDataHandle>& starts) {
+xla::StatusOr<xla::XlaOp> DynamicUpdateSliceInMinorDims(
+ xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update,
+ const std::vector<xla::XlaOp>& starts) {
TF_ASSIGN_OR_RETURN(auto padded_starts,
PrependZerosInMajorDims(builder, x, starts));
return builder->DynamicUpdateSlice(x, update, padded_starts);
}
-xla::StatusOr<xla::ComputationDataHandle> PrependZerosInMajorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const std::vector<xla::ComputationDataHandle>& starts) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+xla::StatusOr<xla::XlaOp> PrependZerosInMajorDims(
+ xla::XlaBuilder* builder, const xla::XlaOp& x,
+ const std::vector<xla::XlaOp>& starts) {
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
auto zero = builder->Reshape(builder->ConstantR0<int32>(0), {1});
- std::vector<xla::ComputationDataHandle> padded_starts(n_dims, zero);
+ std::vector<xla::XlaOp> padded_starts(n_dims, zero);
for (int i = 0; i < starts.size(); ++i) {
padded_starts[n_dims - starts.size() + i] =
builder->Reshape(starts[i], {1});
@@ -219,10 +219,10 @@ xla::StatusOr<xla::ComputationDataHandle> PrependZerosInMajorDims(
return builder->ConcatInDim(padded_starts, 0);
}
-xla::StatusOr<xla::ComputationDataHandle> TransposeInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+xla::StatusOr<xla::XlaOp> TransposeInMinorDims(xla::XlaBuilder* builder,
+ const xla::XlaOp& x) {
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
TF_RET_CHECK(n_dims >= 2);
std::vector<int64> permutation(n_dims);
std::iota(permutation.begin(), permutation.end(), 0);
diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h
index b684123f13..3df44ef035 100644
--- a/tensorflow/compiler/tf2xla/lib/util.h
+++ b/tensorflow/compiler/tf2xla/lib/util.h
@@ -16,75 +16,74 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
// Returns a zero-filled tensor with shape `shape`.
-xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder,
- const xla::Shape& shape);
+xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape);
// Returns a floating point scalar constant of 'type' with 'value'.
// If 'type' is complex, returns a real value with zero imaginary component.
-xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder,
- xla::PrimitiveType type, double value);
+xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
+ double value);
// Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros
// prepended until the array is length n_dims.
-xla::ComputationDataHandle PrependZerosInMajorDims(
- xla::ComputationBuilder* builder,
- gtl::ArraySlice<xla::ComputationDataHandle> starts);
+xla::XlaOp PrependZerosInMajorDims(xla::XlaBuilder* builder,
+ gtl::ArraySlice<xla::XlaOp> starts);
// Returns a integer scalar constant of 'type' with 'value'.
// If 'type' is complex, returns a real value with zero imaginary component.
-xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder,
- xla::PrimitiveType type, int64 value);
+xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
+ int64 value);
// Builds a vector of zeros of length rank(x) with the last two values being
// those in `starts`.
-xla::StatusOr<xla::ComputationDataHandle> PrependZerosInMajorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const std::vector<xla::ComputationDataHandle>& starts);
+xla::StatusOr<xla::XlaOp> PrependZerosInMajorDims(
+ xla::XlaBuilder* builder, const xla::XlaOp& x,
+ const std::vector<xla::XlaOp>& starts);
// Performs a slice in the minor dimensions of a Tensor.
-xla::StatusOr<xla::ComputationDataHandle> SliceInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- gtl::ArraySlice<int64> start, gtl::ArraySlice<int64> end);
+xla::StatusOr<xla::XlaOp> SliceInMinorDims(xla::XlaBuilder* builder,
+ const xla::XlaOp& x,
+ gtl::ArraySlice<int64> start,
+ gtl::ArraySlice<int64> end);
// Builds a 1-d vector out of a concatenation of `major_dims` and `starts`.
-std::vector<int64> PrependMajorDims(xla::ComputationBuilder* builder,
+std::vector<int64> PrependMajorDims(xla::XlaBuilder* builder,
const gtl::ArraySlice<int64>& major_dims,
const gtl::ArraySlice<int64>& indices);
// Performs a dynamic slice in the minor dimensions of a Tensor.
-xla::StatusOr<xla::ComputationDataHandle> DynamicSliceInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const std::vector<xla::ComputationDataHandle>& starts,
- const gtl::ArraySlice<int64>& sizes);
+xla::StatusOr<xla::XlaOp> DynamicSliceInMinorDims(
+ xla::XlaBuilder* builder, const xla::XlaOp& x,
+ const std::vector<xla::XlaOp>& starts, const gtl::ArraySlice<int64>& sizes);
// Updates a slice of 'x', i.e.,
// x[start[0], ..., start[n]] = update
-xla::StatusOr<xla::ComputationDataHandle> UpdateSlice(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start);
+xla::StatusOr<xla::XlaOp> UpdateSlice(xla::XlaBuilder* builder,
+ const xla::XlaOp& x,
+ const xla::XlaOp& update,
+ gtl::ArraySlice<int64> start);
// Updates a slice of 'x', where 'start' contains a list of minor dimensions:
// x[..., start[0], ..., start[n]] = update
-xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start);
+xla::StatusOr<xla::XlaOp> UpdateSliceInMinorDims(xla::XlaBuilder* builder,
+ const xla::XlaOp& x,
+ const xla::XlaOp& update,
+ gtl::ArraySlice<int64> start);
-xla::StatusOr<xla::ComputationDataHandle> DynamicUpdateSliceInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const xla::ComputationDataHandle& update,
- const std::vector<xla::ComputationDataHandle>& starts);
+xla::StatusOr<xla::XlaOp> DynamicUpdateSliceInMinorDims(
+ xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update,
+ const std::vector<xla::XlaOp>& starts);
// Transposes a stack of matrices `x` by swapping the last two dimensions.
-xla::StatusOr<xla::ComputationDataHandle> TransposeInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x);
+xla::StatusOr<xla::XlaOp> TransposeInMinorDims(xla::XlaBuilder* builder,
+ const xla::XlaOp& x);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc
index b6bd33af2e..265b39402c 100644
--- a/tensorflow/compiler/tf2xla/lib/util_test.cc
+++ b/tensorflow/compiler/tf2xla/lib/util_test.cc
@@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
#include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -65,9 +64,9 @@ xla::Array3D<float> BatchedAValsFull() {
}
XLA_TEST_F(UtilTest, Simple2dLookup) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, x, y;
+ xla::XlaOp a, x, y;
auto a_data = CreateR2Parameter<float>(BValsRight(), 0, "a", &builder, &a);
auto x_data = CreateR0Parameter<int>(2, 1, "x", &builder, &x);
auto y_data = CreateR0Parameter<int>(1, 2, "y", &builder, &y);
@@ -80,9 +79,9 @@ XLA_TEST_F(UtilTest, Simple2dLookup) {
}
XLA_TEST_F(UtilTest, Simple3dLookup) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, index;
+ xla::XlaOp a, index;
auto a_data =
CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
auto index_data = CreateR0Parameter<int>(1, 1, "index", &builder, &index);
@@ -97,9 +96,9 @@ XLA_TEST_F(UtilTest, Simple3dLookup) {
}
XLA_TEST_F(UtilTest, SimpleSliceUpdate) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b, x, y;
+ xla::XlaOp a, b, x, y;
auto a_data = CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>({{9, 1, -10}}, 1, "b", &builder, &b);
auto x_data = CreateR0Parameter<int>(2, 2, "x", &builder, &x);
@@ -117,11 +116,11 @@ XLA_TEST_F(UtilTest, SimpleSliceUpdate) {
}
XLA_TEST_F(UtilTest, RowBatchDot) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
int n = 4;
- xla::ComputationDataHandle a, row, index;
+ xla::XlaOp a, row, index;
auto a_data =
CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
auto row_data = CreateR3Parameter<float>({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1,
diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc
index 495d9c6078..09ce594930 100644
--- a/tensorflow/compiler/tf2xla/lib/while_loop.cc
+++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc
@@ -20,24 +20,24 @@ limitations under the License.
namespace tensorflow {
-xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaWhileLoop(
+xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
const LoopConditionFunction& condition_function,
const LoopBodyFunction& body_function,
- gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
- StringPiece name, xla::ComputationBuilder* builder) {
+ gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+ xla::XlaBuilder* builder) {
int arity = initial_values.size();
std::vector<xla::Shape> var_shapes;
var_shapes.reserve(arity);
- for (const xla::ComputationDataHandle& input : initial_values) {
+ for (const xla::XlaOp& input : initial_values) {
TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input));
- var_shapes.push_back(std::move(*shape));
+ var_shapes.push_back(std::move(shape));
}
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes);
// Unpacks a tuple into its component parts.
- auto unpack_tuple = [](xla::ComputationDataHandle tuple, int arity,
- xla::ComputationBuilder* builder) {
- std::vector<xla::ComputationDataHandle> elements(arity);
+ auto unpack_tuple = [](xla::XlaOp tuple, int arity,
+ xla::XlaBuilder* builder) {
+ std::vector<xla::XlaOp> elements(arity);
for (int i = 0; i < arity; ++i) {
elements[i] = builder->GetTupleElement(tuple, i);
}
@@ -45,20 +45,20 @@ xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaWhileLoop(
};
// Build the condition.
- std::unique_ptr<xla::ComputationBuilder> cond_builder =
+ std::unique_ptr<xla::XlaBuilder> cond_builder =
builder->CreateSubBuilder(strings::StrCat(name, "_condition"));
{
auto parameter = cond_builder->Parameter(0, tuple_shape, "parameter");
- TF_ASSIGN_OR_RETURN(
- auto result,
+ TF_RETURN_IF_ERROR(
condition_function(unpack_tuple(parameter, arity, cond_builder.get()),
- cond_builder.get()));
+ cond_builder.get())
+ .status());
}
TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build());
// Build the body.
- std::unique_ptr<xla::ComputationBuilder> body_builder =
+ std::unique_ptr<xla::XlaBuilder> body_builder =
builder->CreateSubBuilder(strings::StrCat(name, "_body"));
{
auto parameter = body_builder->Parameter(0, tuple_shape, "parameter");
@@ -78,38 +78,38 @@ xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaWhileLoop(
return unpack_tuple(outputs, arity, builder);
}
-xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaForEachIndex(
+xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
int64 num_iterations, xla::PrimitiveType num_iterations_type,
const ForEachIndexBodyFunction& body_function,
- gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
- StringPiece name, xla::ComputationBuilder* builder) {
- auto while_cond_fn = [&](gtl::ArraySlice<xla::ComputationDataHandle> values,
- xla::ComputationBuilder* cond_builder)
- -> xla::StatusOr<xla::ComputationDataHandle> {
+ gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+ xla::XlaBuilder* builder) {
+ auto while_cond_fn =
+ [&](gtl::ArraySlice<xla::XlaOp> values,
+ xla::XlaBuilder* cond_builder) -> xla::StatusOr<xla::XlaOp> {
return cond_builder->Lt(
values[0],
IntegerLiteral(cond_builder, num_iterations_type, num_iterations));
};
- auto while_body_fn = [&](gtl::ArraySlice<xla::ComputationDataHandle> values,
- xla::ComputationBuilder* body_builder)
- -> xla::StatusOr<std::vector<xla::ComputationDataHandle>> {
- xla::ComputationDataHandle iteration = values[0];
+ auto while_body_fn = [&](gtl::ArraySlice<xla::XlaOp> values,
+ xla::XlaBuilder* body_builder)
+ -> xla::StatusOr<std::vector<xla::XlaOp>> {
+ xla::XlaOp iteration = values[0];
- std::vector<xla::ComputationDataHandle> updated_values;
+ std::vector<xla::XlaOp> updated_values;
updated_values.reserve(values.size());
updated_values.push_back(body_builder->Add(
iteration,
body_builder->ConstantLiteral(xla::Literal::One(num_iterations_type))));
values.remove_prefix(1);
- TF_ASSIGN_OR_RETURN(std::vector<xla::ComputationDataHandle> body_outputs,
+ TF_ASSIGN_OR_RETURN(std::vector<xla::XlaOp> body_outputs,
body_function(iteration, values, body_builder));
updated_values.insert(updated_values.end(), body_outputs.begin(),
body_outputs.end());
return updated_values;
};
- std::vector<xla::ComputationDataHandle> values;
+ std::vector<xla::XlaOp> values;
values.reserve(initial_values.size() + 1);
values.push_back(
builder->ConstantLiteral(xla::Literal::Zero(num_iterations_type)));
diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/tf2xla/lib/while_loop.h
index 2e67a0c99b..5b6684c995 100644
--- a/tensorflow/compiler/tf2xla/lib/while_loop.h
+++ b/tensorflow/compiler/tf2xla/lib/while_loop.h
@@ -19,8 +19,8 @@ limitations under the License.
#include <functional>
#include <vector>
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -29,14 +29,14 @@ namespace tensorflow {
// Function that builds a loop condition. Takes as input a sequence of input
// values, and returns a boolean value representing if the condition succeeds.
-typedef std::function<xla::StatusOr<xla::ComputationDataHandle>(
- gtl::ArraySlice<xla::ComputationDataHandle>, xla::ComputationBuilder*)>
+typedef std::function<xla::StatusOr<xla::XlaOp>(gtl::ArraySlice<xla::XlaOp>,
+ xla::XlaBuilder*)>
LoopConditionFunction;
// Function that builds a loop body. Takes as input a sequence of input values
// and returns a sequence of output values.
-typedef std::function<xla::StatusOr<std::vector<xla::ComputationDataHandle>>(
- gtl::ArraySlice<xla::ComputationDataHandle>, xla::ComputationBuilder*)>
+typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
+ gtl::ArraySlice<xla::XlaOp>, xla::XlaBuilder*)>
LoopBodyFunction;
// Helper function for building an XLA while loop, where the values carried by
@@ -47,27 +47,26 @@ typedef std::function<xla::StatusOr<std::vector<xla::ComputationDataHandle>>(
// init: (a, b, c)
// )
// 'name' is a descriptive name for the loop.
-xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaWhileLoop(
+xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
const LoopConditionFunction& condition_function,
const LoopBodyFunction& body_function,
- gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
- StringPiece name, xla::ComputationBuilder* builder);
+ gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+ xla::XlaBuilder* builder);
// Builds an XLA loop that repeats a computation `num_iterations` times.
//
// The body function (ForEachIndexBodyFunction) takes as input a pair of
// (current iteration number, loop-carried values), and returns an updated
// vector of the loop-carried values.
-typedef std::function<xla::StatusOr<std::vector<xla::ComputationDataHandle>>(
- xla::ComputationDataHandle, gtl::ArraySlice<xla::ComputationDataHandle>,
- xla::ComputationBuilder*)>
+typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
+ xla::XlaOp, gtl::ArraySlice<xla::XlaOp>, xla::XlaBuilder*)>
ForEachIndexBodyFunction;
-xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaForEachIndex(
+xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
int64 num_iterations, xla::PrimitiveType num_iterations_type,
const ForEachIndexBodyFunction& body_function,
- gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
- StringPiece name, xla::ComputationBuilder* builder);
+ gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+ xla::XlaBuilder* builder);
} // namespace tensorflow