aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/lib/cholesky.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/lib/cholesky.cc')
-rw-r--r--tensorflow/compiler/tf2xla/lib/cholesky.cc9
1 files changed, 5 insertions, 4 deletions
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc
index a90178c7d9..cc840de393 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.cc
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -58,7 +59,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) {
/*pos=*/0,
/*len=*/n_dims - 2);
- xla::XlaOp l = Zeros(builder, a_shape);
+ xla::XlaOp l = xla::ZerosLike(a);
// Construct the for loop body to iterate over rows.
auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
@@ -73,12 +74,12 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) {
row_shape.add_dimensions(1);
row_shape.add_dimensions(n);
row_shape.set_element_type(a_shape.element_type());
- auto mask_zeros_row = Zeros(body_builder, row_shape);
+ auto mask_zeros_row = xla::Zeros(body_builder, row_shape);
col_shape.add_dimensions(n);
col_shape.add_dimensions(1);
col_shape.set_element_type(a_shape.element_type());
- auto mask_zeros_col = Zeros(body_builder, col_shape);
+ auto mask_zeros_col = xla::Zeros(body_builder, col_shape);
std::vector<int32> mask_vector(n);
std::iota(mask_vector.begin(), mask_vector.end(), 0);
@@ -170,7 +171,7 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size) {
// Algorithm 1 from
// Haidar, Azzam, et al. "High-performance Cholesky factorization for
// GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017.
- xla::XlaOp l = Zeros(builder, a_shape);
+ xla::XlaOp l = xla::ZerosLike(a);
for (int64 i = 0; i < n; i += block_size) {
int64 k = std::min(block_size, n - i);
if (i > 0) {