aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc9
1 files changed, 4 insertions, 5 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
index 9d3575e331..e06c87db7a 100644
--- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -51,6 +52,7 @@ class MatrixBandPartOp : public XlaOpKernel {
xla::XlaOp num_upper = context->Input(2);
DataType input_type = context->input_type(0);
DataType index_type = context->input_type(1);
+ xla::PrimitiveType index_xla_type = context->input_xla_type(1);
TensorShape batch_shape = input_shape;
batch_shape.RemoveLastDims(2);
@@ -59,11 +61,8 @@ class MatrixBandPartOp : public XlaOpKernel {
// Compute 'offset', which is how many diagonals we are above/below the
// diagonal.
- xla::XlaOp iota_m;
- OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, m, &iota_m));
-
- xla::XlaOp iota_n;
- OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n));
+ xla::XlaOp iota_m = xla::Iota(builder, index_xla_type, m);
+ xla::XlaOp iota_n = xla::Iota(builder, index_xla_type, n);
auto offset = xla::Sub(xla::Broadcast(iota_n, {m}), iota_m,
/*broadcast_dimensions=*/{0});