aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/deserialize_sparse_variant_op.cc
blob: fce3029e4e2457331fe73f3b4751aadbe273baf6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"

namespace tensorflow {

namespace {

class DeserializeSparseOp : public OpKernel {
 public:
  explicit DeserializeSparseOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);

    OP_REQUIRES(
        context, input.dims() > 0,
        errors::InvalidArgument("Serialized sparse should have non-zero rank ",
                                input.shape().DebugString()));
    OP_REQUIRES(context, input.shape().dim_size(input.dims() - 1) == 3,
                errors::InvalidArgument(
                    "Serialized sparse should have 3 as the last dimension ",
                    input.shape().DebugString()));

    // `input_dims_to_stack` is the number of dimensions that will be added to
    // each of the elements before they are concatenated into the output.
    const int64 input_dims_to_stack = input.dims() - 1;
    int num_sparse_tensors = 1;
    for (int i = 0; i < input_dims_to_stack; ++i) {
      num_sparse_tensors *= input.shape().dim_size(i);
    }

    if (num_sparse_tensors == 1 && input_dims_to_stack == 0) {
      // Special case with a single sparse tensor, and no dimensions to add
      // to the output indices. We can return the boxed tensors directly (after
      // validating them).
      const Tensor* output_indices;
      const Tensor* output_values;
      const Tensor* output_shape;
      const auto& input_as_vec = input.vec<Variant>();
      int64 total_non_zeros;
      OP_REQUIRES_OK(context, GetAndValidateSparseTensorShape(
                                  input_as_vec(1), input_as_vec(2), 0,
                                  &output_shape, &total_non_zeros));
      OP_REQUIRES_OK(context, GetAndValidateSparseTensorIndicesAndValues(
                                  input_as_vec(0), input_as_vec(1), 0,
                                  output_shape->NumElements(), &output_indices,
                                  &output_values));
      context->set_output(0, *output_indices);
      context->set_output(1, *output_values);
      context->set_output(2, *output_shape);
      return;
    }

    OP_REQUIRES(
        context, num_sparse_tensors > 0,
        errors::InvalidArgument(
            "Serialized sparse should have at least 1 serialized tensor, "
            "but has a zero dimension ",
            input.shape().DebugString()));

    const auto& input_as_matrix = input.flat_inner_dims<Variant, 2>();

    // Compute the output "dense shape" of and number of non-zero elements in
    // the stacked sparse tensors. Given an input of shape (S_0, ...,
    // S_{input_dims_to_stack-1}, 3), and an element of dense shape (E_0, ...
    // E_n), the output dense shape will be (S_0, ...,
    // S_{input_dims_to_stack-1}, E_0, ..., E_n).
    Tensor* output_shape;
    int64 total_non_zeros = 0;

    // Allocate and build the initial output shape based on the element shape of
    // the 0th sparse tensor in the input.
    //
    // NOTE(mrry): We define `element_shape` as a `const Tensor*` rather than a
    // `Tensor` to avoid the overhead of allocating and deallocating a `Tensor`
    // on the stack. While the per-`Tensor` cost is small, this op can unbox a
    // large number of tensors (3 per batch element) and these fixed overheads
    // dominate when the number of non-zeros per element is small.
    const Tensor* element_shape;
    OP_REQUIRES_OK(context, GetAndValidateSparseTensorShape(
                                input_as_matrix(0, 1), input_as_matrix(0, 2), 0,
                                &element_shape, &total_non_zeros));
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       2, {input_dims_to_stack + element_shape->NumElements()},
                       &output_shape));
    const auto element_shape_vec = element_shape->vec<int64>();
    auto output_shape_vec = output_shape->vec<int64>();
    output_shape_vec(0) = num_sparse_tensors;
    for (int64 j = 0; j < input_dims_to_stack; ++j) {
      output_shape_vec(j) = input.dim_size(j);
    }
    for (int64 j = 0; j < element_shape->NumElements(); ++j) {
      output_shape_vec(j + input_dims_to_stack) = element_shape_vec(j);
    }

    // Accumulate the number of non-zero elements from the remaining sparse
    // tensors, and validate that they have compatible dense shapes.
    //
    // NOTE(mrry): For compatibility with the implementations of
    // DeserializeManySparse, and many ops that generate SparseTensors to batch
    // that do not have a fixed dense_shape (e.g. `tf.parse_single_example()`),
    // we compute the maximum in each dimension to find the smallest dense_shape
    // that bounds all of the input SparseTensors.
    for (int i = 1; i < num_sparse_tensors; ++i) {
      int64 num_non_zeros;
      OP_REQUIRES_OK(context, GetAndValidateSparseTensorShape(
                                  input_as_matrix(i, 1), input_as_matrix(i, 2),
                                  i, &element_shape, &num_non_zeros));
      total_non_zeros += num_non_zeros;
      OP_REQUIRES(
          context,
          output_shape->NumElements() - input_dims_to_stack ==
              element_shape->NumElements(),
          errors::InvalidArgument(
              "Inconsistent shape across SparseTensors: rank prior to "
              "SparseTensor[",
              i, "] was: ", output_shape->NumElements() - input_dims_to_stack,
              " but rank of SparseTensor[", i,
              "] is: ", element_shape->NumElements()));
      const auto element_shape_vec = element_shape->vec<int64>();
      for (int j = 0; j < element_shape->NumElements(); ++j) {
        output_shape_vec(j + input_dims_to_stack) = std::max(
            output_shape_vec(j + input_dims_to_stack), element_shape_vec(j));
      }
    }

    // Compute the output "indices" matrix and "values" vector.
    Tensor* output_indices;
    Tensor* output_values;

    const int output_rank = output_shape->NumElements();
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       0, {static_cast<int64>(total_non_zeros), output_rank},
                       &output_indices));
    OP_REQUIRES_OK(
        context, context->allocate_output(
                     1, {static_cast<int64>(total_non_zeros)}, &output_values));

    // The bulk of the work in this method involves building the output indices
    // in a tight loop. For cache friendliness, we generate the indices in the
    // order that they will be laid out in memory. We use raw pointers instead
    // of Eigen element/slice indexing methods, to access the underlying index
    // buffer to minimize the amount of work in that tight loop.
    int64* output_indices_data = output_indices->matrix<int64>().data();
    size_t current_row = 0;

    for (int i = 0; i < num_sparse_tensors; ++i) {
      const Tensor* element_indices;
      const Tensor* element_values;
      OP_REQUIRES_OK(context, this->GetAndValidateSparseTensorIndicesAndValues(
                                  input_as_matrix(i, 0), input_as_matrix(i, 1),
                                  i, output_rank - input_dims_to_stack,
                                  &element_indices, &element_values));

      const size_t num_index_rows = element_values->NumElements();

      // An empty sparse tensor in the input will generate no data
      // in the output. We short-circuit the rest of the iteration to avoid
      // triggering assertions in the Eigen when manipulating empty tensors (or
      // slices of tensors).
      if (num_index_rows == 0) continue;

      const size_t start_row = current_row;
      const size_t next_start_row = current_row + num_index_rows;

      // NOTE(mrry): If the element is a scalar SparseTensor,
      // `element_indices` will be an empty tensor, and this pointer will not
      // be valid. However, we will not dereference the pointer in that case,
      // because `input_dims_to_stack == output_rank`.
      const int64* element_indices_data =
          element_indices->matrix<int64>().data();

      // Build the submatrix of `output_indices` for the i^th sparse tensor
      // in the input.
      //
      // Each row of `output_indices` comprises `input_dims_to_stack` indices
      // based on the position of the i^th sparse tensor in the input tensor,
      // followed by the indices from the corresponding row in
      // `element_indices`.
      if (input_dims_to_stack == 1 && output_rank == 2) {
        // We specialize this case because the compiler can generate
        // more efficient code when the number of indices for each element is
        // known statically. Since the most common use of this op is to
        // serialize batches of SparseTensors, and the most common source of
        // SparseTensors is the `tf.parse_single_example()` op, which generates
        // 1-D SparseTensors, we statically unroll the loop for the rank 2
        // output case.
        for (; current_row < next_start_row; ++current_row) {
          *output_indices_data++ = i;
          *output_indices_data++ = *element_indices_data++;
        }
      } else {
        // `sparse_tensor_index` is the tuple of indices that correspond to
        // mapping the flat element index (`i`) back onto the stacked
        // coordinates implied by the position of the i^th sparse tensor in the
        // input tensor.
        //
        // We build `sparse_tensor_index` in reverse (innermost/minor dimension
        // to outermost/major dimension). The `cumulative_product` represents
        // the size of the inner subtensor for which `sparse_tensor_index` has
        // already been built.
        gtl::InlinedVector<int64, 4> sparse_tensor_index(input_dims_to_stack);
        int cumulative_product = 1;
        for (size_t j = 0; j < sparse_tensor_index.size(); ++j) {
          size_t reverse_index = sparse_tensor_index.size() - j - 1;
          sparse_tensor_index[reverse_index] =
              (i / cumulative_product) % input.dim_size(reverse_index);
          cumulative_product *= input.dim_size(reverse_index);
        }
        for (; current_row < next_start_row; ++current_row) {
          for (int64 sparse_tensor_index_component : sparse_tensor_index) {
            *output_indices_data++ = sparse_tensor_index_component;
          }
          for (size_t k = input_dims_to_stack; k < output_rank; ++k) {
            *output_indices_data++ = *element_indices_data++;
          }
        }
      }

      // Build the subvector of `output_values` for the i^th sparse tensor
      // in the input.
      //
      // NOTE(mrry): There is a potential optimization here where we use a T*
      // to represent the current position in `output_values`, but it would
      // require some rejigging of the template parameters.
      // NOTE(mrry): Another potential optimization: if we know that this
      // operation consumes its input, we could std::move non-primitive elements
      // into the output and avoid a copy.
      Eigen::DSizes<Eigen::DenseIndex, 1> values_start(start_row);
      Eigen::DSizes<Eigen::DenseIndex, 1> values_sizes(num_index_rows);

#define HANDLE_TYPE(T)                                          \
  case DataTypeToEnum<T>::value: {                              \
    output_values->vec<T>().slice(values_start, values_sizes) = \
        element_values->vec<T>();                               \
    break;                                                      \
  }
      switch (dtype_) {
        TF_CALL_ALL_TYPES(HANDLE_TYPE);
        TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
        default:
          OP_REQUIRES_OK(
              context, errors::Unimplemented(
                           "DeserializeSparse Unhandled data type: ", dtype_));
      }
    }
  }

 private:
  Status GetAndValidateSparseTensorShape(const Variant& serialized_values,
                                         const Variant& serialized_shape,
                                         int index, const Tensor** output_shape,
                                         int64* output_num_non_zeros) {
    // Deserialize and validate the shape.
    *output_shape = serialized_shape.get<Tensor>();
    if (*output_shape == nullptr) {
      return errors::InvalidArgument(
          "Could not get a tensor from serialized_sparse[", index, ", 2]");
    }
    if ((*output_shape)->dtype() != DT_INT64) {
      return errors::InvalidArgument(
          "Expected serialized_sparse[", index,
          ", 2] to be a vector of DT_INT64 but received dtype ",
          DataTypeString((*output_shape)->dtype()));
    }
    if (!TensorShapeUtils::IsVector((*output_shape)->shape())) {
      return errors::InvalidArgument(
          "Expected serialized_sparse[", index,
          ", 2] to be a shape vector but its shape is ",
          (*output_shape)->shape().DebugString());
    }
    *output_num_non_zeros = serialized_values.get<Tensor>()->NumElements();
    return Status::OK();
  }

  Status GetAndValidateSparseTensorIndicesAndValues(
      const Variant& serialized_indices, const Variant& serialized_values,
      int index, int expected_rank, const Tensor** output_indices,
      const Tensor** output_values) {
    // Deserialize and validate the indices.
    *output_indices = serialized_indices.get<Tensor>();
    if (*output_indices == nullptr) {
      return errors::InvalidArgument(
          "Could not get a tensor from serialized_sparse[", index, ", 0]");
    }
    if ((*output_indices)->dtype() != DT_INT64) {
      return errors::InvalidArgument(
          "Expected serialized_sparse[", index,
          ", 0] to be a matrix of DT_INT64 but received dtype ",
          DataTypeString((*output_indices)->dtype()));
    }
    if (!TensorShapeUtils::IsMatrix((*output_indices)->shape())) {
      return errors::InvalidArgument(
          "Expected serialized_sparse[", index,
          ", 0] to represent an index matrix but received shape ",
          (*output_indices)->shape().DebugString());
    }
    int64 num_entries = (*output_indices)->dim_size(0);
    int rank = (*output_indices)->dim_size(1);
    if (rank != expected_rank) {
      return errors::InvalidArgument(
          "Expected column counts of SparseTensor[", index,
          "].indices to match size of SparseTensor[", index,
          "].shape but they do not: ", rank, " vs. ", expected_rank);
    }

    // Deserialize and validate the values.
    *output_values = serialized_values.get<Tensor>();
    if (*output_values == nullptr) {
      return errors::InvalidArgument(
          "Could not get a tensor from serialized_sparse[", index, ", 1]");
    }
    if (!TensorShapeUtils::IsVector((*output_values)->shape())) {
      return errors::InvalidArgument(
          "Expected serialized_sparse[", index,
          ", 1] to represent a values vector but received shape ",
          (*output_values)->shape().DebugString());
    }
    if (dtype_ != (*output_values)->dtype()) {
      return errors::InvalidArgument(
          "Requested SparseTensor of type ", DataTypeString(dtype_),
          " but SparseTensor[", index,
          "].values.dtype() == ", DataTypeString((*output_values)->dtype()));
    }
    if (num_entries != (*output_values)->dim_size(0)) {
      return errors::InvalidArgument(
          "Expected row counts of SparseTensor[", index,
          "].indices and SparseTensor[", index,
          "].values to match but they do not: ", num_entries, " vs. ",
          (*output_values)->dim_size(0));
    }

    return Status::OK();
  }

  DataType dtype_;
};

// Registers DeserializeSparseOp as the CPU kernel for the "DeserializeSparse"
// op when its "Tserialized" type attribute is `Variant`.
REGISTER_KERNEL_BUILDER(Name("DeserializeSparse")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<Variant>("Tserialized"),
                        DeserializeSparseOp)

}  // namespace

}  // namespace tensorflow