aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc6
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
index 16491002b4..c810456f94 100644
--- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -165,9 +166,8 @@ class ReverseSequenceOp : public XlaOpKernel {
auto output = xla::GetTupleElement(loop_output, 2);
// Mask out elements after the sequence length.
- xla::XlaOp iota;
- OP_REQUIRES_OK(
- context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota));
+ xla::XlaOp iota =
+ xla::Iota(builder, seq_lens_xla_shape.element_type(), max_seq_len);
std::vector<int64> dims(input_shape.dims(), 1);
dims[batch_dim_] = batch_size;
auto mask = xla::Lt(iota, xla::Reshape(seq_lens, dims), {seq_dim_});