diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/lib/cholesky.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/lib/cholesky.cc | 9 |
1 file changed, 5 insertions, 4 deletions
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc
index a90178c7d9..cc840de393 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.cc
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/lib/triangular_solve.h"
 #include "tensorflow/compiler/tf2xla/lib/util.h"
 #include "tensorflow/compiler/tf2xla/lib/while_loop.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -58,7 +59,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) {
       /*pos=*/0,
       /*len=*/n_dims - 2);
 
-  xla::XlaOp l = Zeros(builder, a_shape);
+  xla::XlaOp l = xla::ZerosLike(a);
 
   // Construct the for loop body to iterate over rows.
   auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
@@ -73,12 +74,12 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) {
     row_shape.add_dimensions(1);
     row_shape.add_dimensions(n);
     row_shape.set_element_type(a_shape.element_type());
-    auto mask_zeros_row = Zeros(body_builder, row_shape);
+    auto mask_zeros_row = xla::Zeros(body_builder, row_shape);
 
     col_shape.add_dimensions(n);
     col_shape.add_dimensions(1);
     col_shape.set_element_type(a_shape.element_type());
-    auto mask_zeros_col = Zeros(body_builder, col_shape);
+    auto mask_zeros_col = xla::Zeros(body_builder, col_shape);
 
     std::vector<int32> mask_vector(n);
     std::iota(mask_vector.begin(), mask_vector.end(), 0);
@@ -170,7 +171,7 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size) {
   // Algorithm 1 from
   // Haidar, Azzam, et al. "High-performance Cholesky factorization for
   // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017.
-  xla::XlaOp l = Zeros(builder, a_shape);
+  xla::XlaOp l = xla::ZerosLike(a);
   for (int64 i = 0; i < n; i += block_size) {
     int64 k = std::min(block_size, n - i);
     if (i > 0) {