aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/diag_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/diag_op.cc')
-rw-r--r--tensorflow/core/kernels/diag_op.cc93
1 files changed, 93 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/diag_op.cc b/tensorflow/core/kernels/diag_op.cc
new file mode 100644
index 0000000000..83e39d33a9
--- /dev/null
+++ b/tensorflow/core/kernels/diag_op.cc
@@ -0,0 +1,93 @@
+// See docs in ../ops/array_ops.cc
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace {
+template <typename T, size_t NumDims, size_t DoubleNumDims>
+class DiagonalGenerator {
+ public:
+ explicit DiagonalGenerator(const Tensor& diagonal) : diagonal_(diagonal) {
+ static_assert(DoubleNumDims == 2 * NumDims,
+ "The second size must be the double of the first size.");
+ CHECK_EQ(diagonal.dims(), NumDims);
+ }
+ T operator()(
+ const Eigen::array<Eigen::DenseIndex, DoubleNumDims>& coordinates) const {
+ Eigen::array<Eigen::DenseIndex, NumDims> index;
+ for (int i = 0; i < NumDims; ++i) {
+ if (coordinates[i] != coordinates[NumDims + i]) {
+ return T(0);
+ }
+ index[i] = coordinates[i];
+ }
+ return diagonal_.tensor<T, NumDims>()(index);
+ }
+
+ private:
+ Tensor diagonal_;
+};
+} // namespace
+
+// Generate the diagonal tensor with the diagonal set to the input tensor.
+// It only allows up to rank 3 input tensor, so the output tensor is up to
+// rank 6.
+template <typename T>
+class DiagOp : public OpKernel {
+ public:
+ explicit DiagOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& diagonal = context->input(0);
+ const int num_dims = diagonal.dims();
+ OP_REQUIRES(context, 1 <= num_dims,
+ errors::InvalidArgument(
+ "The rank of the diagonal should be between 1 and 3."));
+ OP_REQUIRES(context, 3 >= num_dims,
+ errors::InvalidArgument(
+ "The rank of the diagonal should be between 1 and 3."));
+ TensorShape out_shape;
+ for (int i = 0; i < num_dims; ++i) {
+ out_shape.AddDim(diagonal.dim_size(i));
+ }
+ for (int i = 0; i < num_dims; ++i) {
+ out_shape.AddDim(diagonal.dim_size(i));
+ }
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, out_shape, &output_tensor));
+ switch (num_dims) {
+ case 1:
+ output_tensor->tensor<T, 2>() = output_tensor->tensor<T, 2>().generate(
+ DiagonalGenerator<T, 1, 2>(diagonal));
+ break;
+ case 2:
+ output_tensor->tensor<T, 4>() = output_tensor->tensor<T, 4>().generate(
+ DiagonalGenerator<T, 2, 4>(diagonal));
+ break;
+ case 3:
+ output_tensor->tensor<T, 6>() = output_tensor->tensor<T, 6>().generate(
+ DiagonalGenerator<T, 3, 6>(diagonal));
+ break;
+ default:
+ context->SetStatus(errors::Unimplemented(
+ "Diagonal of rank ", num_dims, " tensor is not supported yet."));
+ return;
+ }
+ }
+};
+
+#define REGISTER_DIAGOP(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Diag").Device(DEVICE_CPU).TypeConstraint<T>("T"), DiagOp<T>)
+
+REGISTER_DIAGOP(double);
+REGISTER_DIAGOP(float);
+REGISTER_DIAGOP(int32);
+REGISTER_DIAGOP(int64);
+
+#undef REGISTER_DIAGOP
+} // namespace tensorflow