aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/diag_op.cc
blob: 83e39d33a9f281be76c6bbdc1c4a41f02be617cd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
// See docs in ../ops/array_ops.cc
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/tensor.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace tensorflow {
namespace {
template <typename T, size_t NumDims, size_t DoubleNumDims>
class DiagonalGenerator {
 public:
  explicit DiagonalGenerator(const Tensor& diagonal) : diagonal_(diagonal) {
    static_assert(DoubleNumDims == 2 * NumDims,
                  "The second size must be the double of the first size.");
    CHECK_EQ(diagonal.dims(), NumDims);
  }
  T operator()(
      const Eigen::array<Eigen::DenseIndex, DoubleNumDims>& coordinates) const {
    Eigen::array<Eigen::DenseIndex, NumDims> index;
    for (int i = 0; i < NumDims; ++i) {
      if (coordinates[i] != coordinates[NumDims + i]) {
        return T(0);
      }
      index[i] = coordinates[i];
    }
    return diagonal_.tensor<T, NumDims>()(index);
  }

 private:
  Tensor diagonal_;
};
}  // namespace

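// Generates a tensor whose diagonal is the input tensor. For an input of rank
// k with shape [D1, ..., Dk], the output has rank 2k and shape
// [D1, ..., Dk, D1, ..., Dk], with output[i1, ..., ik, i1, ..., ik] =
// input[i1, ..., ik] and zeros everywhere else. Only inputs of rank 1 to 3 are
// supported, so the output has rank at most 6.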
template <typename T>
class DiagOp : public OpKernel {
 public:
  explicit DiagOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& diagonal = context->input(0);
    const int num_dims = diagonal.dims();
    OP_REQUIRES(context, 1 <= num_dims,
                errors::InvalidArgument(
                    "The rank of the diagonal should be between 1 and 3."));
    OP_REQUIRES(context, 3 >= num_dims,
                errors::InvalidArgument(
                    "The rank of the diagonal  should be between 1 and 3."));
    TensorShape out_shape;
    for (int i = 0; i < num_dims; ++i) {
      out_shape.AddDim(diagonal.dim_size(i));
    }
    for (int i = 0; i < num_dims; ++i) {
      out_shape.AddDim(diagonal.dim_size(i));
    }
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, out_shape, &output_tensor));
    switch (num_dims) {
      case 1:
        output_tensor->tensor<T, 2>() = output_tensor->tensor<T, 2>().generate(
            DiagonalGenerator<T, 1, 2>(diagonal));
        break;
      case 2:
        output_tensor->tensor<T, 4>() = output_tensor->tensor<T, 4>().generate(
            DiagonalGenerator<T, 2, 4>(diagonal));
        break;
      case 3:
        output_tensor->tensor<T, 6>() = output_tensor->tensor<T, 6>().generate(
            DiagonalGenerator<T, 3, 6>(diagonal));
        break;
      default:
        context->SetStatus(errors::Unimplemented(
            "Diagonal of rank ", num_dims, " tensor is not supported yet."));
        return;
    }
  }
};

// Registers the CPU implementation of the "Diag" op for one element type.
#define REGISTER_DIAGOP(type)                              \
  REGISTER_KERNEL_BUILDER(Name("Diag")                     \
                              .Device(DEVICE_CPU)          \
                              .TypeConstraint<type>("T"),  \
                          DiagOp<type>)

// Element types supported on CPU.
REGISTER_DIAGOP(double);
REGISTER_DIAGOP(float);
REGISTER_DIAGOP(int32);
REGISTER_DIAGOP(int64);

#undef REGISTER_DIAGOP
}  // namespace tensorflow