// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/split_op.h"
#include "tensorflow/core/public/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/public/tensor.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
class UnpackOp : public OpKernel {
 public:
  explicit UnpackOp(OpKernelConstruction* c) : OpKernel(c) {}

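  // Unpacks the [num, d_1, ..., d_n] input along its first dimension
  // into `num` outputs, each of shape [d_1, ..., d_n]. For example, a
  // [3, 4] input yields three outputs of shape [4].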
  void Compute(OpKernelContext* context) override {
    const int32 num = num_outputs();
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();

    OP_REQUIRES(
        context, input_shape.dims() > 0 && input_shape.dim_size(0) == num,
        errors::InvalidArgument("Input shape must start with ", num, ", got ",
                                input_shape.ShortDebugString()));

    auto output_shape = input_shape;
    output_shape.RemoveDim(0);
    const int64 output_size = output_shape.num_elements();

    // Special case: the output slices are aligned, so we can share the
    // underlying buffer instead of copying.
    //
    // Apply this optimization conservatively: if the input is aligned,
    // the resulting tensors must be aligned as well. This is
    // conservative because if the immediate consumers of the resulting
    // tensors are not using Eigen for computation, it is perfectly
    // fine to skip the copy even for unaligned slices.
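    //
    // Each Slice(i, i + 1) aliases the input buffer starting at byte
    // offset i * output_size * sizeof(T), so sharing is safe whenever
    // that per-slice byte size is a multiple of Eigen's required
    // alignment, which is what IsInnerDimsSizeAligned<T> checks.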
    if (output_size == 0 || IsInnerDimsSizeAligned<T>(input_shape)) {
      for (int i = 0; i < num; ++i) {
        Tensor output;
        CHECK(output.CopyFrom(input.Slice(i, i + 1), output_shape));
        context->set_output(i, output);
      }
      return;
    }

    // Except for shape, unpack is a special case of split, so we reuse the
    // same computational kernels.
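    // Viewing the input as a [1, num, output_size] cuboid lets the
    // Split functor carve out one [1, 1, output_size] block per
    // output.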
    auto input_reshaped = input.shaped<T, 3>({1, num, output_size});

    for (int i = 0; i < num; ++i) {
      Tensor* output;
      OP_REQUIRES_OK(context,
                     context->allocate_output(i, output_shape, &output));
      auto output_shaped = output->shaped<T, 3>({1, 1, output_size});

      Eigen::DSizes<ptrdiff_t, 3> indices{0, i, 0};
      Eigen::DSizes<ptrdiff_t, 3> sizes{1, 1, output_size};
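      // Copy the [1, 1, output_size] block at offset {0, i, 0} of the
      // reshaped input into output i.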
      functor::Split<Device, T>()(context->eigen_device<Device>(),
                                  output_shaped, input_reshaped, indices,
                                  sizes);
    }
  }
};

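// Registers an Unpack CPU kernel for each supported element type.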
#define REGISTER_UNPACK(type)                                      \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("Unpack").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      UnpackOp<CPUDevice, type>)

TF_CALL_ALL_TYPES(REGISTER_UNPACK);

#undef REGISTER_UNPACK

#if GOOGLE_CUDA

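// Registers Unpack GPU kernels for the GPU-supported numeric types.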
#define REGISTER_GPU(type)                                         \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("Unpack").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      UnpackOp<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
#undef REGISTER_GPU

#endif  // GOOGLE_CUDA

}  // namespace tensorflow