blob: bba25f70e8c2e038bad3e90a256a729ae4bbb2b5 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
|
#ifndef TENSORFLOW_KERNELS_REVERSE_OP_H_
#define TENSORFLOW_KERNELS_REVERSE_OP_H_
#include "tensorflow/core/framework/tensor_types.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
namespace tensorflow {
namespace functor {
// Functor used by MirrorOp to do the computations.
template <typename Device, typename T, int Dims>
struct Reverse {
void operator()(const Device& d, typename TTypes<T, Dims>::ConstTensor input,
typename TTypes<bool, 1>::ConstTensor dims,
typename TTypes<T, Dims>::Tensor output) {
// mirror is in host memory
Eigen::array<bool, Dims> reverse_dims;
for (int i = 0; i < Dims; ++i) {
reverse_dims[i] = dims(i);
}
output.device(d) = input.reverse(reverse_dims);
}
};
} // namespace functor
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_MIRROR_OP_H_
|