#ifndef TENSORFLOW_KERNELS_REVERSE_SEQUENCE_OP_H_ #define TENSORFLOW_KERNELS_REVERSE_SEQUENCE_OP_H_ // Generator definition for ReverseSequenceOp, must be compilable by nvcc. #include "tensorflow/core/platform/port.h" #include "tensorflow/core/framework/tensor_types.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { namespace generator { template class ReverseGenerator { public: EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ReverseGenerator(typename TTypes::ConstTensor input, int32 seq_dim, TTypes::ConstVec seq_lengths) : input_(input), seq_dim_(seq_dim), seq_lengths_(seq_lengths) {} EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T operator()(const Eigen::array& coords) const { Eigen::array new_coords = coords; if (coords[seq_dim_] < seq_lengths_(coords[0])) { new_coords[seq_dim_] = seq_lengths_(coords[0]) - coords[seq_dim_] - 1; } return input_(new_coords); } private: typename TTypes::ConstTensor input_; int32 seq_dim_; TTypes::ConstVec seq_lengths_; }; } // namespace generator namespace functor { template struct ReverseSequence { EIGEN_ALWAYS_INLINE static void Compute( const Device& d, typename TTypes::ConstTensor input, int32 seq_dim, TTypes::ConstVec seq_lengths, typename TTypes::Tensor output) { generator::ReverseGenerator generator(input, seq_dim, seq_lengths); output.device(d) = input.generate(generator); } }; } // namespace functor } // namespace tensorflow #endif // TENSORFLOW_KERNELS_REVERSE_SEQUENCE_OP_H_