aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/reverse_sequence_op.cc
blob: 6673a700efb50903d65b85f0f9706d514043cbfe (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/reverse_sequence_op.h"

#include <memory>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/tensor_shape.h"
#include "tensorflow/core/public/tensor.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace tensorflow {

// Short device aliases used as the `Device` template argument of the kernels
// and helpers in this file.
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

// Validates the inputs of ReverseSequence (generic device path).
//
// Reads input(0) (the tensor to reverse) and input(1) (seq_lens) from
// `context`. The first failing check records an InvalidArgument error on
// `context` through OP_REQUIRES and returns from this function only; callers
// must inspect context->status() to learn whether validation succeeded.
//
// Checks, in order:
//   * seq_dim is non-negative and not 0 (dimension 0 is the one seq_lens is
//     sized against),
//   * seq_dim < input.dims(),
//   * seq_lens has exactly input.dim_size(0) elements,
//   * every seq_lens value lies in [0, input.dim_size(seq_dim)].
//
// seq_lens is expected to already be a vector; Compute() verifies that before
// calling this helper.
template <typename Device>
void CheckErrors(OpKernelContext* context, int seq_dim) {
  const Tensor& input = context->input(0);
  const Tensor& seq_lens = context->input(1);

  // A negative seq_dim would pass the two checks below and then be used as an
  // index into the input shape, so reject it explicitly first.
  OP_REQUIRES(context, seq_dim >= 0,
              errors::InvalidArgument("seq_dim must be >= 0, got ", seq_dim));
  OP_REQUIRES(context, 0 != seq_dim, errors::InvalidArgument("0 == seq_dim"));
  OP_REQUIRES(context, seq_dim < input.dims(),
              errors::InvalidArgument("seq_dim must be < input.dims()", "( ",
                                      seq_dim, " vs. ", input.dims(), ")"));

  // The message reports the size of dimension 0, which is what is compared.
  OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(0),
              errors::InvalidArgument("len(seq_lens) != input.dims(", 0, "), ",
                                      "(", seq_lens.NumElements(), " vs. ",
                                      input.dim_size(0), ")"));

  // Only pay for the copy once the cheap shape checks have passed.
  auto seq_lens_t = seq_lens.vec<int64>();
  std::vector<int64> seq_lens_vec(seq_lens_t.size());

  // Copy seq_len info down for validity checks
  context->eigen_device<Device>().memcpyDeviceToHost(
      seq_lens_vec.data(), seq_lens_t.data(),
      sizeof(int64) * seq_lens_t.size());

  const int64 max_len = input.dim_size(seq_dim);
  const int num_lens = static_cast<int>(seq_lens_vec.size());
  for (int d = 0; d < num_lens; ++d) {
    OP_REQUIRES(context, seq_lens_vec[d] >= 0,
                errors::InvalidArgument("seq_lens(", d, ") < 0"));
    OP_REQUIRES(context, seq_lens_vec[d] <= max_len,
                errors::InvalidArgument("seq_lens(", d, ") > input.dims(",
                                        seq_dim, ")"));
  }
}

// GPU specialization of the input validation for ReverseSequence.
//
// Performs only the shape-level checks on input(0) and input(1); unlike the
// generic version it does not copy seq_lens to the host, so the individual
// sequence lengths are not range-checked here.
//
// The first failing check records an InvalidArgument error on `context`
// through OP_REQUIRES and returns from this function only; callers must
// inspect context->status() to learn whether validation succeeded.
template <>
void CheckErrors<GPUDevice>(OpKernelContext* context, int seq_dim) {
  const Tensor& input = context->input(0);
  const Tensor& seq_lens = context->input(1);

  // A negative seq_dim would pass the two checks below and then be used as an
  // index into the input shape, so reject it explicitly first.
  OP_REQUIRES(context, seq_dim >= 0,
              errors::InvalidArgument("seq_dim must be >= 0, got ", seq_dim));
  OP_REQUIRES(context, 0 != seq_dim, errors::InvalidArgument("0 == seq_dim"));
  OP_REQUIRES(context, seq_dim < input.dims(),
              errors::InvalidArgument("seq_dim must be < input.dims()", "( ",
                                      seq_dim, " vs. ", input.dims(), ")"));

  // The message reports the size of dimension 0, which is what is compared.
  OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(0),
              errors::InvalidArgument("len(seq_lens) != input.dims(", 0, "), ",
                                      "(", seq_lens.NumElements(), " vs. ",
                                      input.dim_size(0), ")"));
}

template <typename Device, typename T>
class ReverseSequenceOp : public OpKernel {
 public:
  explicit ReverseSequenceOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& seq_lens = context->input(1);

    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lens.shape()),
                errors::InvalidArgument("seq_lens input must be 1-dim, not ",
                                        seq_lens.dims()));

    auto seq_lens_t = seq_lens.vec<int64>();

    CheckErrors<Device>(context, seq_dim_);

    const int input_dims = input.dims();

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

#define HANDLE_DIM(NDIM)                                                    \
  case NDIM:                                                                \
    functor::ReverseSequence<Device, T, NDIM>::Compute(                     \
        context->eigen_device<Device>(), input.tensor<T, NDIM>(), seq_dim_, \
        seq_lens_t, output->tensor<T, NDIM>());                             \
    break;

    switch (input_dims) {
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);

      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "ReverseSequenceOp : Unhandled input dimensions: ",
                        input_dims));
    }
  }

 private:
  int32 seq_dim_;

  TF_DISALLOW_COPY_AND_ASSIGN(ReverseSequenceOp);
};

// Registers the CPU kernel for the "ReverseSequence" op, constrained on the
// "T" attribute, once for every type that TF_CALL_NUMBER_TYPES expands to.
#define REGISTER_REVERSE_SEQUENCE(type)                                     \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("ReverseSequence").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      ReverseSequenceOp<CPUDevice, type>);

TF_CALL_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE);

// The helper macro is not needed past this point (mirrors the GPU section).
#undef REGISTER_REVERSE_SEQUENCE

#if GOOGLE_CUDA

// Forward declarations of the functor specializations for GPU.
// Nothing is defined here: each Compute specialization is only declared, so
// its definition must come from another translation unit (presumably the CUDA
// source for this op -- confirm).
namespace functor {
// Declares the GPU specialization of ReverseSequence<GPUDevice, T, Dims>::
// Compute and then an explicit instantiation declaration (`extern template`)
// for the struct, which suppresses implicit instantiation in this file.
#define DECLARE_GPU_SPEC(T, Dims)                                      \
  template <>                                                          \
  void ReverseSequence<GPUDevice, T, Dims>::Compute(                   \
      const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \
      int32 seq_dim, TTypes<int64>::ConstVec seq_lens,                 \
      typename TTypes<T, Dims>::Tensor output);                        \
  extern template struct ReverseSequence<GPUDevice, T, Dims>;

// Emits the declarations for ranks 2 through 5, the same ranks handled by the
// switch in ReverseSequenceOp::Compute.
#define DECLARE_GPU_SPECS(T) \
  DECLARE_GPU_SPEC(T, 2);    \
  DECLARE_GPU_SPEC(T, 3);    \
  DECLARE_GPU_SPEC(T, 4);    \
  DECLARE_GPU_SPEC(T, 5);

// One set of declarations per type that TF_CALL_GPU_NUMBER_TYPES expands to.
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);

}  // namespace functor

// Registration of the GPU implementations.
// Registers ReverseSequenceOp<GPUDevice, type> for the "ReverseSequence" op,
// constrained on the "T" attribute, once for every type that
// TF_CALL_GPU_NUMBER_TYPES expands to.
#define REGISTER_REVERSE_SEQUENCE_GPU(type)                                 \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("ReverseSequence").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      ReverseSequenceOp<GPUDevice, type>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE_GPU);

// The helper macro is only needed for the expansion above.
#undef REGISTER_REVERSE_SEQUENCE_GPU

#endif  // GOOGLE_CUDA

}  // namespace tensorflow