diff options
Diffstat (limited to 'tensorflow/core/kernels/reverse_sequence_op.h')
-rw-r--r-- | tensorflow/core/kernels/reverse_sequence_op.h | 56 |
1 files changed, 56 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/reverse_sequence_op.h b/tensorflow/core/kernels/reverse_sequence_op.h new file mode 100644 index 0000000000..d1dd572dcb --- /dev/null +++ b/tensorflow/core/kernels/reverse_sequence_op.h @@ -0,0 +1,56 @@ +#ifndef TENSORFLOW_KERNELS_REVERSE_SEQUENCE_OP_H_ +#define TENSORFLOW_KERNELS_REVERSE_SEQUENCE_OP_H_ +// Generator definition for ReverseSequenceOp, must be compilable by nvcc. + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +namespace generator { + +template <typename T, size_t Dims> +class ReverseGenerator { + public: + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE + ReverseGenerator(typename TTypes<T, Dims>::ConstTensor input, int32 seq_dim, + TTypes<int64>::ConstVec seq_lengths) + : input_(input), seq_dim_(seq_dim), seq_lengths_(seq_lengths) {} + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T + operator()(const Eigen::array<Eigen::DenseIndex, Dims>& coords) const { + Eigen::array<Eigen::DenseIndex, Dims> new_coords = coords; + if (coords[seq_dim_] < seq_lengths_(coords[0])) { + new_coords[seq_dim_] = seq_lengths_(coords[0]) - coords[seq_dim_] - 1; + } + + return input_(new_coords); + } + + private: + typename TTypes<T, Dims>::ConstTensor input_; + int32 seq_dim_; + TTypes<int64>::ConstVec seq_lengths_; +}; + +} // namespace generator + +namespace functor { + +template <typename Device, typename T, size_t Dims> +struct ReverseSequence { + EIGEN_ALWAYS_INLINE static void Compute( + const Device& d, typename TTypes<T, Dims>::ConstTensor input, + int32 seq_dim, TTypes<int64>::ConstVec seq_lengths, + typename TTypes<T, Dims>::Tensor output) { + generator::ReverseGenerator<T, Dims> generator(input, seq_dim, seq_lengths); + output.device(d) = input.generate(generator); + } +}; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_REVERSE_SEQUENCE_OP_H_ |