diff options
Diffstat (limited to 'tensorflow/core/kernels/reverse_op.cc')
-rw-r--r-- | tensorflow/core/kernels/reverse_op.cc | 139 |
1 files changed, 139 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/reverse_op.cc b/tensorflow/core/kernels/reverse_op.cc new file mode 100644 index 0000000000..c63dfc1e70 --- /dev/null +++ b/tensorflow/core/kernels/reverse_op.cc @@ -0,0 +1,139 @@ +// See docs in ../ops/array_ops.cc +#define EIGEN_USE_THREADS + +#include <memory> +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/reverse_op.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template <typename Device, typename T> +class ReverseOp : public OpKernel { + public: + explicit ReverseOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& dims = context->input(1); + + if (TensorShapeUtils::IsScalar(input.shape())) { + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &output)); + output->scalar<T>() = input.scalar<T>(); + + } else { + const int input_dims = input.dims(); + OP_REQUIRES(context, TensorShapeUtils::IsVector(dims.shape()), + errors::InvalidArgument("'dims' must be 1-dimension, not ", + dims.dims())); + + OP_REQUIRES(context, input_dims == dims.dim_size(0), + errors::InvalidArgument( + "'dims' must have the same number of values as 'input' has " + "dimensions. 'input' has ", input_dims, "'dims' has ", + dims.dim_size(0), " values")); + OP_REQUIRES(context, input_dims <= 8, errors::Unimplemented( + "reverse is not implemented for tensors of rank > 8.")); + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &output)); + +#define HANDLE_REVERSE(NDIMS) \ + case NDIMS: \ + functor::Reverse<Device, T, NDIMS>()( \ + context->eigen_device<Device>(), input.tensor<T, NDIMS>(), \ + dims.vec<bool>(), output->tensor<T, NDIMS>()); \ + return; + + switch (input_dims) { + HANDLE_REVERSE(0); + HANDLE_REVERSE(1); + HANDLE_REVERSE(2); + HANDLE_REVERSE(3); + HANDLE_REVERSE(4); + HANDLE_REVERSE(5); + HANDLE_REVERSE(6); + HANDLE_REVERSE(7); + HANDLE_REVERSE(8); + } +#undef HANDLE_REVERSE + } + } +}; + +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("Reverse") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T") \ + .HostMemory("dims"), \ + ReverseOp<CPUDevice, T>) + +REGISTER_KERNEL(uint8); +REGISTER_KERNEL(int8); +REGISTER_KERNEL(int32); +REGISTER_KERNEL(bool); +REGISTER_KERNEL(float); +REGISTER_KERNEL(double); +#undef REGISTER_KERNEL + +#if GOOGLE_CUDA + +// Forward declarations of the function specializations for GPU (to prevent +// building the GPU versions here, they will be built compiling _gpu.cu.cc). +namespace functor { +#define DECLARE_GPU_SPEC_DIM(T, DIM) \ + template <> \ + void Reverse<GPUDevice, T, DIM>::operator()( \ + const GPUDevice& d, typename TTypes<T, DIM>::ConstTensor input, \ + typename TTypes<bool, 1>::ConstTensor dims, \ + typename TTypes<T, DIM>::Tensor output); \ + extern template struct Reverse<GPUDevice, T, DIM>; +#define DECLARE_GPU_SPEC(T) \ + DECLARE_GPU_SPEC_DIM(T, 0) \ + DECLARE_GPU_SPEC_DIM(T, 1) \ + DECLARE_GPU_SPEC_DIM(T, 2) \ + DECLARE_GPU_SPEC_DIM(T, 3) \ + DECLARE_GPU_SPEC_DIM(T, 4) \ + DECLARE_GPU_SPEC_DIM(T, 5) \ + DECLARE_GPU_SPEC_DIM(T, 6) \ + DECLARE_GPU_SPEC_DIM(T, 7) \ + DECLARE_GPU_SPEC_DIM(T, 8) + +DECLARE_GPU_SPEC(uint8); +DECLARE_GPU_SPEC(int8); +DECLARE_GPU_SPEC(int32); +DECLARE_GPU_SPEC(bool); +DECLARE_GPU_SPEC(float); +DECLARE_GPU_SPEC(double); +#undef DECLARE_GPU_SPEC +#undef DECLARE_GPU_SPEC_DIM +} // namespace functor + +// Registration of the GPU implementations. +#define REGISTER_GPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("Reverse") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<T>("T") \ + .HostMemory("dims"), \ + ReverseOp<GPUDevice, T>) +REGISTER_GPU_KERNEL(uint8); +REGISTER_GPU_KERNEL(int8); +REGISTER_GPU_KERNEL(float); +REGISTER_GPU_KERNEL(double); +#undef REGISTER_GPU_KERNEL + +#endif // GOOGLE_CUDA + +} // namespace tensorflow |