aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/sparse_to_dense_op.cc
blob: 47e91c134d5a88d2ab968e72a59c8c06140673dd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
// See core/ops/sparse_ops.cc for documentation.
//
// NOTE: the operations in this file are only suitable for execution
// on CPUs.

#define EIGEN_USE_THREADS

#include <numeric>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/tensor.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/public/status.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

// Operator to convert sparse representations to dense.
template <typename T, typename Index>
class SparseToDense : public OpKernel {
 public:
  explicit SparseToDense(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* c) override {
    // sparse_indices
    const Tensor& indices = c->input(0);
    OP_REQUIRES(c, indices.dims() <= 2,
                errors::InvalidArgument(
                    "sparse_indices should be a scalar, vector, or matrix, "
                    "got shape ",
                    indices.shape().ShortDebugString()));
    const int64 num_elems = indices.dims() > 0 ? indices.dim_size(0) : 1;
    const int64 num_dims = indices.dims() > 1 ? indices.dim_size(1) : 1;

    // output_shape
    const Tensor& output_shape = c->input(1);
    OP_REQUIRES(
        c, TensorShapeUtils::IsLegacyVector(output_shape.shape()),
        errors::InvalidArgument("output_shape should be a vector, got shape ",
                                output_shape.shape().ShortDebugString()));
    OP_REQUIRES(c, output_shape.NumElements() == num_dims,
                errors::InvalidArgument(
                    "output_shape has incorrect number of elements: ",
                    output_shape.NumElements(), " should be: ", num_dims));

    // sparse_values
    const Tensor& sparse_values = c->input(2);
    const int64 num_values = sparse_values.NumElements();
    OP_REQUIRES(
        c, sparse_values.dims() == 0 ||
               (sparse_values.dims() == 1 && num_values == num_elems),
        errors::InvalidArgument("sparse_values has incorrect shape ",
                                sparse_values.shape().ShortDebugString(),
                                ", should be [] or [", num_elems, "]"));

    // default_value
    const Tensor& default_value = c->input(3);
    OP_REQUIRES(c, TensorShapeUtils::IsScalar(default_value.shape()),
                errors::InvalidArgument("default_value should be a scalar."));

    auto output_shape_vec = output_shape.flat<Index>();
    Tensor* output = nullptr;
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShapeUtils::MakeShape(
                                                output_shape_vec.data(),
                                                output_shape_vec.size()),
                                         &output));

    TensorShape ix_shape({num_elems, num_dims});
    Tensor indices_shaped(DT_INT64, ix_shape);
    if (indices.dtype() == DT_INT64) {
      CHECK(indices_shaped.CopyFrom(indices, ix_shape));
    } else {
      indices_shaped.matrix<int64>() =
          indices.shaped<Index, 2>(ix_shape.dim_sizes()).template cast<int64>();
    }

    // If we received a scalar, we'll need to create a new
    // tensor with copies of the values as a vec.
    // TODO(ebrevdo): find a way to avoid this temp allocation.
    Tensor sparse_values_b;

    if (TensorShapeUtils::IsScalar(sparse_values.shape())) {
      OP_REQUIRES_OK(
          c, c->allocate_temp(DataTypeToEnum<T>::value,
                              TensorShape({num_elems}), &sparse_values_b));
      sparse_values_b.vec<T>().setConstant(sparse_values.scalar<T>()());
    } else {
      sparse_values_b = sparse_values;
    }

    gtl::InlinedVector<int64, 8> order(output->shape().dims());
    std::iota(order.begin(), order.end(), 0);  // Assume order is correct
    sparse::SparseTensor st(indices_shaped, sparse_values_b, output->shape(),
                            order);

    output->flat<T>().setConstant(default_value.scalar<T>()());
    OP_REQUIRES(c, st.template ToDense<T>(output, false /* initialize */),
                errors::InvalidArgument(
                    "Indices are not valid (out of bounds).  Shape: ",
                    output->shape().DebugString()));
  }
};

// Registers the CPU SparseToDense kernel for a single (value type, index
// type) pair, matched against the op's "T" and "Tindices" attributes.
#define SPARSE_TO_DENSE_REGISTER(value_t, idx_t)                  \
  REGISTER_KERNEL_BUILDER(Name("SparseToDense")                   \
                              .Device(DEVICE_CPU)                 \
                              .TypeConstraint<value_t>("T")       \
                              .TypeConstraint<idx_t>("Tindices"), \
                          SparseToDense<value_t, idx_t>);

// Pairs one value type with each supported index type (int32 and int64).
#define SPARSE_TO_DENSE_REGISTER_BOTH_INDEX_TYPES(value_t) \
  SPARSE_TO_DENSE_REGISTER(value_t, int32);                \
  SPARSE_TO_DENSE_REGISTER(value_t, int64);

// Every type in TF_CALL_REAL_NUMBER_TYPES, plus bool and string.
TF_CALL_REAL_NUMBER_TYPES(SPARSE_TO_DENSE_REGISTER_BOTH_INDEX_TYPES);
SPARSE_TO_DENSE_REGISTER_BOTH_INDEX_TYPES(bool);
SPARSE_TO_DENSE_REGISTER_BOTH_INDEX_TYPES(string);

// The helpers are file-local; drop them so they cannot leak further.
#undef SPARSE_TO_DENSE_REGISTER_BOTH_INDEX_TYPES
#undef SPARSE_TO_DENSE_REGISTER

}  // namespace tensorflow