diff options

-rw-r--r--  tensorflow/compiler/tests/BUILD                            |  12
-rw-r--r--  tensorflow/compiler/tests/reverse_sequence_op_test.py      |  93
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/BUILD                   |   1
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc  | 182

4 files changed, 288 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 314f5506b1..a3a82df9ad 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -438,6 +438,18 @@ tf_xla_py_test(
 )
 
 tf_xla_py_test(
+    name = "reverse_sequence_op_test",
+    size = "small",
+    srcs = ["reverse_sequence_op_test.py"],
+    deps = [
+        ":xla_test",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+tf_xla_py_test(
     name = "rmsprop_test",
     size = "small",
     srcs = ["rmsprop_test.py"],
diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py
new file mode 100644
index 0000000000..1a5d05094e
--- /dev/null
+++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py
@@ -0,0 +1,93 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.reverse_sequence_op."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class ReverseSequenceTest(XLATestCase):
  """Tests XLA compilation of array_ops.reverse_sequence."""

  def _testReverseSequence(self,
                           x,
                           batch_axis,
                           seq_axis,
                           seq_lengths,
                           truth,
                           expected_err_re=None):
    """Runs reverse_sequence under the XLA test scope and checks the result.

    Args:
      x: numpy array fed as the op's input.
      batch_axis: int, axis indexing batch elements.
      seq_axis: int, axis along which each sequence is reversed.
      seq_lengths: 1-D numpy array of per-batch-element sequence lengths.
      truth: numpy array with the expected output.
      expected_err_re: if not None, a regex that the raised op error must
        match instead of comparing against `truth`.
    """
    with self.test_session():
      # Placeholders typed from the numpy inputs so each dtype combination
      # in the callers exercises a distinct compiled graph.
      p = array_ops.placeholder(dtypes.as_dtype(x.dtype))
      lengths = array_ops.placeholder(dtypes.as_dtype(seq_lengths.dtype))
      with self.test_scope():
        ans = array_ops.reverse_sequence(
            p, batch_axis=batch_axis, seq_axis=seq_axis, seq_lengths=lengths)
      if expected_err_re is None:
        tf_ans = ans.eval(feed_dict={p: x, lengths: seq_lengths})
        self.assertAllClose(tf_ans, truth, atol=1e-10)
      else:
        with self.assertRaisesOpError(expected_err_re):
          ans.eval(feed_dict={p: x, lengths: seq_lengths})

  def testSimple(self):
    """Small int32 case: rows reversed up to lengths [1, 3, 2]."""
    x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
    expected = np.array([[1, 2, 3], [6, 5, 4], [8, 7, 9]], dtype=np.int32)
    self._testReverseSequence(
        x,
        batch_axis=0,
        seq_axis=1,
        seq_lengths=np.array([1, 3, 2], np.int32),
        truth=expected)

  def _testBasic(self, dtype, len_dtype):
    """Rank-5 case with permuted batch/seq axes for one (dtype, len_dtype)."""
    x = np.asarray(
        [[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], [13, 14, 15, 16]],
         [[17, 18, 19, 20], [21, 22, 23, 24]]],
        dtype=dtype)
    x = x.reshape(3, 2, 4, 1, 1)
    x = x.transpose([2, 1, 0, 3, 4])  # permute axes 0 <=> 2

    # reverse dim 2 up to (0:3, none, 0:4) along dim=0
    seq_lengths = np.asarray([3, 0, 4], dtype=len_dtype)

    truth_orig = np.asarray(
        [
            [[3, 2, 1, 4], [7, 6, 5, 8]],  # reverse 0:3
            [[9, 10, 11, 12], [13, 14, 15, 16]],  # reverse none
            [[20, 19, 18, 17], [24, 23, 22, 21]]
        ],  # reverse 0:4 (all)
        dtype=dtype)
    truth_orig = truth_orig.reshape(3, 2, 4, 1, 1)
    truth = truth_orig.transpose([2, 1, 0, 3, 4])  # permute axes 0 <=> 2

    seq_axis = 0  # permute seq_axis and batch_axis (originally 2 and 0, resp.)
    batch_axis = 2
    self._testReverseSequence(x, batch_axis, seq_axis, seq_lengths, truth)

  def testSeqLength(self):
    """Exercises _testBasic over every supported value/length dtype pair."""
    for dtype in self.all_types:
      for seq_dtype in self.int_types:
        self._testBasic(dtype, seq_dtype)


if __name__ == "__main__":
  test.main()
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 5e1b01878b..a7e00cb12f 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -58,6 +58,7 @@ tf_kernel_library(
         "reshape_op.cc",
         "retval_op.cc",
         "reverse_op.cc",
+        "reverse_sequence_op.cc",
         "scan_ops.cc",
         "segment_reduction_ops.cc",
         "select_op.cc",
diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
new file mode 100644
index 0000000000..6bc5d3adb0
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
@@ -0,0 +1,182 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { +namespace { + +class ReverseSequenceOp : public XlaOpKernel { + public: + explicit ReverseSequenceOp(OpKernelConstruction* context) + : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("batch_dim", &batch_dim_)); + OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_)); + } + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + const TensorShape seq_lens_shape = context->InputShape(1); + + OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lens_shape), + errors::InvalidArgument("seq_lens input must be 1-dim, not ", + seq_lens_shape.dims())); + OP_REQUIRES(context, batch_dim_ != seq_dim_, + errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim_)); + OP_REQUIRES( + context, seq_dim_ < input_shape.dims(), + errors::InvalidArgument("seq_dim must be < input.dims()", "( ", + seq_dim_, " vs. ", input_shape.dims(), ")")); + OP_REQUIRES( + context, batch_dim_ < input_shape.dims(), + errors::InvalidArgument("batch_dim must be < input.dims()", "( ", + batch_dim_, " vs. ", input_shape.dims(), ")")); + OP_REQUIRES( + context, + seq_lens_shape.num_elements() == input_shape.dim_size(batch_dim_), + errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim_, + "), ", "(", seq_lens_shape.num_elements(), + " vs. 
", input_shape.dim_size(batch_dim_))); + + xla::ComputationBuilder* builder = context->builder(); + const auto input = context->Input(0); + const auto seq_lens = context->Input(1); + + const int64 batch_size = input_shape.dim_size(batch_dim_); + + const DataType input_type = context->input_type(0); + const DataType seq_lens_type = context->input_type(1); + const int64 max_seq_len = input_shape.dim_size(seq_dim_); + + xla::Shape input_xla_shape; + OP_REQUIRES_OK(context, TensorShapeToXLAShape(input_type, input_shape, + &input_xla_shape)); + xla::Shape seq_lens_xla_shape; + OP_REQUIRES_OK(context, TensorShapeToXLAShape(seq_lens_type, seq_lens_shape, + &seq_lens_xla_shape)); + + const auto tuple_shape = xla::ShapeUtil::MakeTupleShape({ + xla::ShapeUtil::MakeShape(seq_lens_xla_shape.element_type(), {}), + seq_lens_xla_shape, + input_xla_shape, + }); + + // For each entry in the batch, reverse the sequence. + // TODO(b/65689298): generalize the Map() operator to non-scalar cases and + // use it here, instead of a While loop. 
+ + // Condition: lambda (i, _, _): i < batch_size + auto condition_builder = + builder->CreateSubBuilder("reverse_sequence_condition"); + { + auto param = condition_builder->Parameter(0, tuple_shape, "param"); + auto i = condition_builder->GetTupleElement(param, 0); + condition_builder->Lt( + i, XlaHelpers::IntegerLiteral(condition_builder.get(), seq_lens_type, + batch_size)); + } + auto condition = condition_builder->Build(); + OP_REQUIRES_OK(context, condition.status()); + + auto body_builder = builder->CreateSubBuilder("reverse_sequence_body"); + { + auto param = body_builder->Parameter(0, tuple_shape, "param"); + auto i = body_builder->GetTupleElement(param, 0); + auto seq_lens = body_builder->GetTupleElement(param, 1); + auto output = body_builder->GetTupleElement(param, 2); + + // seq_len is the sequence length of the current batch element (rank 1) + auto seq_len = body_builder->DynamicSlice( + seq_lens, body_builder->Reshape(i, {1}), {1}); + + // Indices is the offset of the batch element in the input. + auto indices = body_builder->Broadcast( + XlaHelpers::Zero(body_builder.get(), seq_lens_type), + {input_shape.dims()}); + indices = body_builder->DynamicUpdateSlice( + indices, body_builder->Reshape(i, {1}), + body_builder->Reshape( + XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, + batch_dim_), + {1})); + + // slice_indices is the offset of the start of the reversed sequence in + // the input. + auto slice_indices = body_builder->DynamicUpdateSlice( + indices, + body_builder->Sub(XlaHelpers::IntegerLiteral( + body_builder.get(), seq_lens_type, max_seq_len), + seq_len), + body_builder->Reshape( + XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, + seq_dim_), + {1})); + + // Slice out the reversed sequence. The slice will overflow the end of the + // sequence, and the contents of the overflow are implementation-defined. 
+ // However, we will mask off these elements and replace them with elements + // from the original input so their values do not matter. + TensorShape slice_shape = input_shape; + slice_shape.set_dim(batch_dim_, 1); + auto slice = body_builder->DynamicSlice(output, slice_indices, + slice_shape.dim_sizes()); + + // Shift the reversed sequence to the left. + output = body_builder->DynamicUpdateSlice(output, slice, indices); + + body_builder->Tuple( + {body_builder->Add( + i, XlaHelpers::One(body_builder.get(), seq_lens_type)), + seq_lens, output}); + } + auto body = body_builder->Build(); + OP_REQUIRES_OK(context, body.status()); + + auto loop_output = builder->While( + condition.ValueOrDie(), body.ValueOrDie(), + builder->Tuple({XlaHelpers::Zero(builder, seq_lens_type), seq_lens, + builder->Rev(input, {seq_dim_})})); + auto output = builder->GetTupleElement(loop_output, 2); + + // Mask out elements after the sequence length. + xla::ComputationDataHandle iota; + OP_REQUIRES_OK( + context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota)); + std::vector<int64> dims(input_shape.dims(), 1); + dims[batch_dim_] = batch_size; + auto mask = builder->Lt(iota, builder->Reshape(seq_lens, dims), {seq_dim_}); + + // Broadcast the mask up to the input shape. + mask = + builder->Or(mask, builder->Broadcast(builder->ConstantR0<bool>(false), + input_shape.dim_sizes())); + + output = builder->Select(mask, output, input); + context->SetOutput(0, output); + } + + private: + int32 batch_dim_; + int32 seq_dim_; +}; + +REGISTER_XLA_OP(Name("ReverseSequence"), ReverseSequenceOp); + +} // namespace +} // namespace tensorflow |