aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc7
1 files changed, 3 insertions, 4 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
index 7bf1894ea0..e2ab4b83cf 100644
--- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
@@ -62,10 +63,8 @@ class MatrixSetDiagOp : public XlaOpKernel {
auto zero = XlaHelpers::Zero(builder, context->input_type(0));
// Create an indicator tensor that is true only on the diagonal.
- xla::XlaOp iota_m;
- OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, m, &iota_m));
- xla::XlaOp iota_n;
- OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, n, &iota_n));
+ xla::XlaOp iota_m = xla::Iota(builder, xla::S32, m);
+ xla::XlaOp iota_n = xla::Iota(builder, xla::S32, n);
auto indicator = xla::Eq(iota_m, xla::Broadcast(iota_n, {m}),
/*broadcast_dimensions=*/{0});
indicator = xla::Broadcast(indicator, batch_shape.dim_sizes());