aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Michael Kuperstein <mkuper@google.com>2018-05-17 13:37:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-17 13:44:32 -0700
commitb1061347ce4c9757c1e2c420bf672dac5270034c (patch)
tree6214406fe898098597de56c93d5c77d0a4473e4d
parent7232a906caa549a108912999230ef0ec790b4dbd (diff)
[TF:XLA] Do not rely on implementation-defined semantics of DynamicSlice.
ReverseSequence relies on DynamicSlice wrapping around, which is implementation-defined behavior, and is not guaranteed. Pad the input instead. PiperOrigin-RevId: 197043307
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc48
1 files changed, 31 insertions, 17 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
index 0ed4c4707d..5d1c052684 100644
--- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
@@ -106,20 +106,40 @@ class ReverseSequenceOp : public XlaOpKernel {
seq_lens, body_builder->Reshape(i, {1}), {1});
// Indices is the offset of the batch element in the input.
- auto indices = body_builder->Broadcast(
+ auto batch_element_indices = body_builder->Broadcast(
XlaHelpers::Zero(body_builder.get(), seq_lens_type),
{input_shape.dims()});
- indices = body_builder->DynamicUpdateSlice(
- indices, body_builder->Reshape(i, {1}),
+ batch_element_indices = body_builder->DynamicUpdateSlice(
+ batch_element_indices, body_builder->Reshape(i, {1}),
body_builder->Reshape(
XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type,
batch_dim_),
{1}));
- // slice_indices is the offset of the start of the reversed sequence in
- // the input.
- auto slice_indices = body_builder->DynamicUpdateSlice(
- indices,
+ // Slice out the current batch element and pad it out in the sequence
+ // dimension.
+ TensorShape slice_shape = input_shape;
+ slice_shape.set_dim(batch_dim_, 1);
+ slice_shape.set_dim(seq_dim_, max_seq_len);
+ auto slice = body_builder->DynamicSlice(output, batch_element_indices,
+ slice_shape.dim_sizes());
+ auto padding_config = xla::MakeNoPaddingConfig(slice_shape.dims());
+ padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high(
+ slice_shape.dim_size(seq_dim_));
+ slice = body_builder->Pad(
+ slice, XlaHelpers::Zero(body_builder.get(), input_type),
+ padding_config);
+
+ // Now slice out the reversed sequence from its actual start.
+ // sequence_start_indices is the offset of the start of the reversed
+ // sequence in the input. The slice will go into the padding, however, we
+ // will mask off these elements and replace them with elements from the
+ // original input so their values do not matter.
+ auto sequence_start_indices = body_builder->Broadcast(
+ XlaHelpers::Zero(body_builder.get(), seq_lens_type),
+ {slice_shape.dims()});
+ sequence_start_indices = body_builder->DynamicUpdateSlice(
+ sequence_start_indices,
body_builder->Sub(XlaHelpers::IntegerLiteral(
body_builder.get(), seq_lens_type, max_seq_len),
seq_len),
@@ -127,18 +147,12 @@ class ReverseSequenceOp : public XlaOpKernel {
XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type,
seq_dim_),
{1}));
-
- // Slice out the reversed sequence. The slice will overflow the end of the
- // sequence, and the contents of the overflow are implementation-defined.
- // However, we will mask off these elements and replace them with elements
- // from the original input so their values do not matter.
- TensorShape slice_shape = input_shape;
- slice_shape.set_dim(batch_dim_, 1);
- auto slice = body_builder->DynamicSlice(output, slice_indices,
- slice_shape.dim_sizes());
+ slice = body_builder->DynamicSlice(slice, sequence_start_indices,
+ slice_shape.dim_sizes());
// Shift the reversed sequence to the left.
- output = body_builder->DynamicUpdateSlice(output, slice, indices);
+ output = body_builder->DynamicUpdateSlice(output, slice,
+ batch_element_indices);
body_builder->Tuple(
{body_builder->Add(