From b1061347ce4c9757c1e2c420bf672dac5270034c Mon Sep 17 00:00:00 2001 From: Michael Kuperstein Date: Thu, 17 May 2018 13:37:57 -0700 Subject: [TF:XLA] Do not rely on implementation-defined semantics of DynamicSlice. ReverseSequence relies on DynamicSlice wrapping around, which is implementation-defined behavior, and is not guaranteed. Pad the input instead. PiperOrigin-RevId: 197043307 --- .../compiler/tf2xla/kernels/reverse_sequence_op.cc | 48 ++++++++++++++-------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 0ed4c4707d..5d1c052684 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -106,20 +106,40 @@ class ReverseSequenceOp : public XlaOpKernel { seq_lens, body_builder->Reshape(i, {1}), {1}); // Indices is the offset of the batch element in the input. - auto indices = body_builder->Broadcast( + auto batch_element_indices = body_builder->Broadcast( XlaHelpers::Zero(body_builder.get(), seq_lens_type), {input_shape.dims()}); - indices = body_builder->DynamicUpdateSlice( - indices, body_builder->Reshape(i, {1}), + batch_element_indices = body_builder->DynamicUpdateSlice( + batch_element_indices, body_builder->Reshape(i, {1}), body_builder->Reshape( XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, batch_dim_), {1})); - // slice_indices is the offset of the start of the reversed sequence in - // the input. - auto slice_indices = body_builder->DynamicUpdateSlice( - indices, + // Slice out the current batch element and pad it out in the sequence + // dimension. + TensorShape slice_shape = input_shape; + slice_shape.set_dim(batch_dim_, 1); + slice_shape.set_dim(seq_dim_, max_seq_len); + auto slice = body_builder->DynamicSlice(output, batch_element_indices, + slice_shape.dim_sizes()); + auto padding_config = xla::MakeNoPaddingConfig(slice_shape.dims()); + padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high( + slice_shape.dim_size(seq_dim_)); + slice = body_builder->Pad( + slice, XlaHelpers::Zero(body_builder.get(), input_type), + padding_config); + + // Now slice out the reversed sequence from its actual start. + // sequence_start_indices is the offset of the start of the reversed + // sequence in the input. The slice will go into the padding, however, we + // will mask off these elements and replace them with elements from the + // original input so their values do not matter. + auto sequence_start_indices = body_builder->Broadcast( + XlaHelpers::Zero(body_builder.get(), seq_lens_type), + {slice_shape.dims()}); + sequence_start_indices = body_builder->DynamicUpdateSlice( + sequence_start_indices, body_builder->Sub(XlaHelpers::IntegerLiteral( body_builder.get(), seq_lens_type, max_seq_len), seq_len), @@ -127,18 +147,12 @@ class ReverseSequenceOp : public XlaOpKernel { XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, seq_dim_), {1})); - - // Slice out the reversed sequence. The slice will overflow the end of the - // sequence, and the contents of the overflow are implementation-defined. - // However, we will mask off these elements and replace them with elements - // from the original input so their values do not matter. - TensorShape slice_shape = input_shape; - slice_shape.set_dim(batch_dim_, 1); - auto slice = body_builder->DynamicSlice(output, slice_indices, - slice_shape.dim_sizes()); + slice = body_builder->DynamicSlice(slice, sequence_start_indices, + slice_shape.dim_sizes()); // Shift the reversed sequence to the left. - output = body_builder->DynamicUpdateSlice(output, slice, indices); + output = body_builder->DynamicUpdateSlice(output, slice, + batch_element_indices); body_builder->Tuple( {body_builder->Add( -- cgit v1.2.3