diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 16491002b4..c810456f94 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -165,9 +166,8 @@ class ReverseSequenceOp : public XlaOpKernel { auto output = xla::GetTupleElement(loop_output, 2); // Mask out elements after the sequence length. - xla::XlaOp iota; - OP_REQUIRES_OK( - context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota)); + xla::XlaOp iota = + xla::Iota(builder, seq_lens_xla_shape.element_type(), max_seq_len); std::vector<int64> dims(input_shape.dims(), 1); dims[batch_dim_] = batch_size; auto mask = xla::Lt(iota, xla::Reshape(seq_lens, dims), {seq_dim_}); |