aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
blob: 81eead11d1aef456ddb6415f4279269fbe0f2f76 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;

template <typename T>
class SparseFillEmptyRowsOp : public OpKernel {
 public:
  explicit SparseFillEmptyRowsOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* indices_t;
    const Tensor* values_t;
    const Tensor* dense_shape_t;
    const Tensor* default_value_t;
    OP_REQUIRES_OK(context, context->input("indices", &indices_t));
    OP_REQUIRES_OK(context, context->input("values", &values_t));
    OP_REQUIRES_OK(context, context->input("dense_shape", &dense_shape_t));
    OP_REQUIRES_OK(context, context->input("default_value", &default_value_t));

    const CPUDevice& d = context->eigen_device<CPUDevice>();

    OP_REQUIRES(context, TensorShapeUtils::IsVector(dense_shape_t->shape()),
                errors::InvalidArgument("dense_shape must be a vector, saw: ",
                                        dense_shape_t->shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices_t->shape()),
                errors::InvalidArgument("indices must be a matrix, saw: ",
                                        indices_t->shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(values_t->shape()),
                errors::InvalidArgument("values must be a vector, saw: ",
                                        values_t->shape().DebugString()));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(default_value_t->shape()),
        errors::InvalidArgument("default_value must be a scalar, saw: ",
                                default_value_t->shape().DebugString()));
    // TODO(ebrevdo): add shape checks between values, indices,
    // dense_shape.  Also add check that dense rank > 0.

    const T& default_value = default_value_t->scalar<T>()();
    const auto indices = indices_t->matrix<int64>();
    const auto values = values_t->vec<T>();
    const auto dense_shape = dense_shape_t->vec<int64>();

    const int64 N = indices_t->shape().dim_size(0);
    const int64 dense_rows = dense_shape(0);

    Tensor* empty_row_indicator_t;
    OP_REQUIRES_OK(context, context->allocate_output("empty_row_indicator",
                                                     TensorShape({dense_rows}),
                                                     &empty_row_indicator_t));
    auto empty_row_indicator = empty_row_indicator_t->vec<bool>();
    Tensor* reverse_index_map_t;
    OP_REQUIRES_OK(
        context, context->allocate_output("reverse_index_map", TensorShape({N}),
                                          &reverse_index_map_t));
    auto reverse_index_map = reverse_index_map_t->vec<int64>();

    int rank = indices_t->shape().dim_size(1);

    if (dense_rows == 0) {
      OP_REQUIRES(
          context, N == 0,
          errors::InvalidArgument("Received SparseTensor with dense_shape[0] = "
                                  "0 but indices.shape[0] = ",
                                  N));
      Tensor* output_indices_t;
      TensorShape output_indices_shape({0, rank});
      OP_REQUIRES_OK(context, context->allocate_output("output_indices",
                                                       output_indices_shape,
                                                       &output_indices_t));
      Tensor* output_values_t;
      OP_REQUIRES_OK(context,
                     context->allocate_output("output_values", TensorShape({0}),
                                              &output_values_t));

      // Exit early, nothing more to do.
      return;
    }

    Tensor scratch_t;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DT_INT64, TensorShape({dense_rows}),
                                          &scratch_t));
    auto scratch = scratch_t.vec<int64>();
    scratch.device(d) = scratch.constant(0);
    int64 prev_row = -1;
    for (int i = 0; i < N; ++i) {
      const int64 row = indices(i, 0);
      OP_REQUIRES(context, indices(i, 0) >= 0 && indices(i, 0) < dense_rows,
                  errors::InvalidArgument("indices(", i, ", 0) is invalid: ",
                                          indices(i, 0), " >= ", dense_rows));
      prev_row = row;
      ++scratch(indices(i, 0));
    }
    for (int row = 0; row < dense_rows; ++row) {
      // Scratch here describes the number of elements in this dense row
      empty_row_indicator(row) = (scratch(row) == 0);
      // In filled version, each row has at least one element.
      scratch(row) = std::max(scratch(row), 1LL);
      // Update scratch to represent the number of elements up to and
      // including dense_row + 1:
      //  scratch(0) == #{elements of row 0}
      //  scratch(1) == #{elements of row 1} + #{elements of row 0}
      //  ..
      //  scratch(i) == starting index for elements in row i + 1.
      if (row > 0) {
        scratch(row) += scratch(row - 1);
      }
    }
    Tensor* output_indices_t;
    const int64 N_full = scratch(dense_rows - 1);
    TensorShape output_indices_shape({N_full, rank});
    OP_REQUIRES_OK(context, context->allocate_output("output_indices",
                                                     output_indices_shape,
                                                     &output_indices_t));
    auto output_indices = output_indices_t->matrix<int64>();
    output_indices.device(d) = output_indices.constant(0);

    Tensor* output_values_t;
    OP_REQUIRES_OK(
        context, context->allocate_output(
                     "output_values", TensorShape({N_full}), &output_values_t));
    auto output_values = output_values_t->vec<T>();
    output_values.device(d) = output_values.constant(default_value);

    Tensor filled_count_t;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DT_INT64, TensorShape({dense_rows}),
                                          &filled_count_t));
    auto filled_count = filled_count_t.vec<int64>();
    filled_count.device(d) = filled_count.constant(0);

    // Fill in values for rows that are not missing
    for (int64 i = 0; i < N; ++i) {
      const int64 row = indices(i, 0);
      int64& offset = filled_count(row);
      const int64 output_i = ((row == 0) ? 0 : scratch(row - 1)) + offset;
      offset++;  // Increment the filled count for this row.
      std::copy_n(&indices(i, 0), rank, &output_indices(output_i, 0));
      output_values(output_i) = values(i);
      // We'll need this reverse index map to backprop correctly.
      reverse_index_map(i) = output_i;
    }

    // Fill in values for rows that are missing
    for (int64 row = 0; row < dense_rows; ++row) {
      const int64 row_count = filled_count(row);
      if (row_count == 0) {  // We haven't filled this row
        const int64 starting_index = (row == 0) ? 0 : scratch(row - 1);
        // Remaining index values were set to zero already.
        // The value at this index was set to default_value already.
        // Just need to set the row index in the right location.
        output_indices(starting_index, 0) = row;
      }
    }
  }
};

// Registers the CPU SparseFillEmptyRows kernel once per type in
// TF_CALL_ALL_TYPES, binding each instantiation to the "T" type attr.
#define REGISTER_KERNELS(TYPE)                                    \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("SparseFillEmptyRows")                                 \
          .Device(DEVICE_CPU)                                     \
          .TypeConstraint<TYPE>("T"),                             \
      SparseFillEmptyRowsOp<TYPE>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

template <typename T>
// CPU kernel for SparseFillEmptyRowsGrad.
//
// Routes the gradient of the forward op's output_values back to its inputs:
//   d_values[i]     = grad_values[reverse_index_map[i]]
//   d_default_value = sum of grad_values over positions that no input entry
//                     maps to (the slots the forward op filled with the
//                     default value).
class SparseFillEmptyRowsGradOp : public OpKernel {
 public:
  explicit SparseFillEmptyRowsGradOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* reverse_index_map_t;
    const Tensor* grad_values_t;
    OP_REQUIRES_OK(context,
                   context->input("reverse_index_map", &reverse_index_map_t));
    OP_REQUIRES_OK(context, context->input("grad_values", &grad_values_t));

    const CPUDevice& d = context->eigen_device<CPUDevice>();

    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(reverse_index_map_t->shape()),
        errors::InvalidArgument("reverse_index_map must be a vector, saw: ",
                                reverse_index_map_t->shape().DebugString()));
    // grad_values is accessed as a vector below; validate its rank up front
    // just like reverse_index_map.
    OP_REQUIRES(context, TensorShapeUtils::IsVector(grad_values_t->shape()),
                errors::InvalidArgument("grad_values must be a vector, saw: ",
                                        grad_values_t->shape().DebugString()));

    const auto reverse_index_map = reverse_index_map_t->vec<int64>();
    const auto grad_values = grad_values_t->vec<T>();

    const int64 N = reverse_index_map_t->shape().dim_size(0);
    const int64 N_full = grad_values_t->shape().dim_size(0);

    Tensor* d_values_t;
    OP_REQUIRES_OK(context, context->allocate_output(
                                "d_values", TensorShape({N}), &d_values_t));
    auto d_values = d_values_t->vec<T>();
    Tensor* d_default_value_t;
    OP_REQUIRES_OK(context,
                   context->allocate_output("d_default_value", TensorShape({}),
                                            &d_default_value_t));
    T& d_default_value = d_default_value_t->scalar<T>()();
    d_default_value = T();

    Tensor visited_t;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DT_BOOL, TensorShape({N_full}), &visited_t));
    auto visited = visited_t.vec<bool>();
    visited.device(d) = visited.constant(false);

    for (int64 i = 0; i < N; ++i) {
      // Locate the index of the output of the forward prop associated
      // with this location in the input of the forward prop.  Copy
      // the gradient into it.  Mark it as visited.
      const int64 reverse_index = reverse_index_map(i);
      // reverse_index_map is a user-supplied tensor: bound it before using it
      // to index grad_values and visited.
      OP_REQUIRES(
          context, 0 <= reverse_index && reverse_index < N_full,
          errors::InvalidArgument("Elements in reverse index must be in [0, ",
                                  N_full, ") but got ", reverse_index));
      d_values(i) = grad_values(reverse_index);
      visited(reverse_index) = true;
    }
    for (int64 j = 0; j < N_full; ++j) {
      // The default value gradient gets the accumulated remainder of
      // the backprop values (since the default value was used to fill
      // in these slots in the forward calculation).
      if (!visited(j)) {
        d_default_value += grad_values(j);
      }
    }
  }
};

// Registers the CPU SparseFillEmptyRowsGrad kernel once per type in
// TF_CALL_NUMBER_TYPES, binding each instantiation to the "T" type attr.
#define REGISTER_KERNELS(TYPE)                                    \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("SparseFillEmptyRowsGrad")                             \
          .Device(DEVICE_CPU)                                     \
          .TypeConstraint<TYPE>("T"),                             \
      SparseFillEmptyRowsGradOp<TYPE>)

TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
}  // namespace tensorflow