aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/sparse_tensors_map_ops.cc
blob: 74fa3a15f06fdb267dc9776ee8a0903f8f6626de (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include <algorithm>
#include <functional>
#include <numeric>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

// Alias for Eigen's thread-pool device type.
using CPUDevice = Eigen::ThreadPoolDevice;

// Lets the rest of this file refer to sparse::SparseTensor unqualified.
using sparse::SparseTensor;

// A resource that stores SparseTensors keyed by an int64 handle.
//
// Handles are taken from `counter_`, which is read and incremented while
// holding `mu_`; the table itself is also only touched under `mu_`.
class SparseTensorsMap : public ResourceBase {
 public:
  explicit SparseTensorsMap(const string& name) : name_(name), counter_(0) {}

  string DebugString() override { return "A SparseTensorsMap"; }

  // The components of one stored SparseTensor.
  typedef struct {
    PersistentTensor indices;
    PersistentTensor values;
    gtl::InlinedVector<int64, 8> shape;
  } PersistentSparseTensor;

  // Stores `sp` in the map and writes the handle assigned to it to `*handle`.
  //
  // Returns a non-OK status if allocating either persistent tensor fails; in
  // that case nothing is inserted and `*handle` is left untouched.
  Status AddSparseTensor(OpKernelContext* ctx, const SparseTensor& sp,
                         int64* handle) {
    PersistentTensor persistent_ix;
    Tensor* ix;
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(
        sp.indices().dtype(), sp.indices().shape(), &persistent_ix, &ix));
    *ix = sp.indices();

    // Allocate with the dtype and shape of the values tensor itself, not those
    // of the indices matrix.
    PersistentTensor persistent_values;
    Tensor* values;
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(sp.values().dtype(),
                                                sp.values().shape(),
                                                &persistent_values, &values));
    *values = sp.values();
    {
      mutex_lock l(mu_);
      int64 unique_st_handle = counter_++;  // increment is guarded on purpose
      sp_tensors_[unique_st_handle] = PersistentSparseTensor{
          persistent_ix, persistent_values,
          gtl::InlinedVector<int64, 8>(sp.shape().begin(), sp.shape().end())};
      *handle = unique_st_handle;
    }
    return Status::OK();
  }

  // Looks up each entry of `handles` in order, appends the matching
  // SparseTensor to `*sparse_tensors` and erases it from the map.
  // `*sparse_tensors` is cleared before anything is appended.
  //
  // Returns InvalidArgument at the first handle that is not in the map, or the
  // error from SparseTensor::Create if rebuilding a tensor fails. Entries for
  // handles processed before the failure have already been erased.
  Status RetrieveAndClearSparseTensors(
      OpKernelContext* ctx, const TTypes<int64>::ConstVec& handles,
      std::vector<SparseTensor>* sparse_tensors) {
    sparse_tensors->clear();
    sparse_tensors->reserve(handles.size());
    {
      mutex_lock l(mu_);
      // Use a signed index: the Eigen vector reports a signed size.
      const int64 num_handles = handles.size();
      for (int64 i = 0; i < num_handles; ++i) {
        const int64 handle = handles(i);
        auto sp_iter = sp_tensors_.find(handle);
        if (sp_iter == sp_tensors_.end()) {
          return errors::InvalidArgument(
              "Unable to find SparseTensor: ", handle, " in map: ", name_);
        }
        const Tensor* ix = sp_iter->second.indices.AccessTensor(ctx);
        const Tensor* values = sp_iter->second.values.AccessTensor(ctx);
        const auto& shape = sp_iter->second.shape;
        SparseTensor tensor;
        TF_RETURN_IF_ERROR(SparseTensor::Create(*ix, *values, shape, &tensor));
        sparse_tensors->push_back(std::move(tensor));
        sp_tensors_.erase(sp_iter);
      }
    }

    return Status::OK();
  }

 protected:
  ~SparseTensorsMap() override {}

 private:
  string name_;

  mutex mu_;
  int64 counter_ GUARDED_BY(mu_);
  std::unordered_map<int64, PersistentSparseTensor> sp_tensors_ GUARDED_BY(mu_);
};

// Base class for kernels that operate on a SparseTensorsMap resource. The map
// is looked up (or created) on the first GetMap() call and the pointer is
// cached in the kernel for subsequent calls.
class SparseTensorAccessingOp : public OpKernel {
 public:
  using CreatorCallback = std::function<Status(SparseTensorsMap**)>;

  explicit SparseTensorAccessingOp(OpKernelConstruction* context)
      : OpKernel(context), sparse_tensors_map_(nullptr) {}

 protected:
  // Calls Unref() on the cached map if GetMap() ever obtained one.
  ~SparseTensorAccessingOp() override {
    if (sparse_tensors_map_ != nullptr) {
      sparse_tensors_map_->Unref();
    }
  }

  // Writes the kernel's SparseTensorsMap to `*sparse_tensors_map`.
  //
  // On the first call the container info is initialised from the node def
  // (`is_writing` is forwarded as the use_node_name_as_default flag) and the
  // map is fetched from, or created in, the resource manager. Later calls
  // return the cached pointer. Any failure status is returned unchanged.
  Status GetMap(OpKernelContext* ctx, bool is_writing,
                SparseTensorsMap** sparse_tensors_map) {
    mutex_lock lock(mu_);

    if (sparse_tensors_map_ == nullptr) {
      TF_RETURN_IF_ERROR(
          cinfo_.Init(ctx->resource_manager(), def(),
                      is_writing /* use_node_name_as_default */));

      // Factory used by LookupOrCreate when no map exists under this name.
      CreatorCallback make_map = [this](SparseTensorsMap** created) {
        *created = new SparseTensorsMap(cinfo_.name());
        return Status::OK();
      };

      TF_RETURN_IF_ERROR(
          cinfo_.resource_manager()->LookupOrCreate<SparseTensorsMap>(
              cinfo_.container(), cinfo_.name(), &sparse_tensors_map_,
              make_map));
    }

    *sparse_tensors_map = sparse_tensors_map_;
    return Status::OK();
  }

 private:
  ContainerInfo cinfo_;

  mutex mu_;
  SparseTensorsMap* sparse_tensors_map_ PT_GUARDED_BY(mu_);
};

class AddSparseToTensorsMapOp : public SparseTensorAccessingOp {
 public:
  explicit AddSparseToTensorsMapOp(OpKernelConstruction* context)
      : SparseTensorAccessingOp(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* input_indices;
    const Tensor* input_values;
    const Tensor* input_shape;
    SparseTensorsMap* map;

    OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
    OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
    OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
    OP_REQUIRES_OK(context, GetMap(context, true /* is_writing */, &map));

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
                errors::InvalidArgument(
                    "Input indices should be a matrix but received shape ",
                    input_indices->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
                errors::InvalidArgument(
                    "Input values should be a vector but received shape ",
                    input_values->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
                errors::InvalidArgument(
                    "Input shape should be a vector but received shape ",
                    input_shape->shape().DebugString()));

    TensorShape input_shape_object;
    OP_REQUIRES_OK(context,
                   TensorShapeUtils::MakeShape(input_shape->vec<int64>().data(),
                                               input_shape->NumElements(),
                                               &input_shape_object));
    SparseTensor st;
    OP_REQUIRES_OK(context, SparseTensor::Create(*input_indices, *input_values,
                                                 input_shape_object, &st));
    int64 handle;
    OP_REQUIRES_OK(context, map->AddSparseTensor(context, st, &handle));

    Tensor sparse_handle(DT_INT64, TensorShape({}));
    auto sparse_handle_t = sparse_handle.scalar<int64>();

    sparse_handle_t() = handle;

    context->set_output(0, sparse_handle);
  }
};

// Registers AddSparseToTensorsMapOp as the CPU kernel for the
// "AddSparseToTensorsMap" op.
REGISTER_KERNEL_BUILDER(Name("AddSparseToTensorsMap").Device(DEVICE_CPU),
                        AddSparseToTensorsMapOp);

template <typename T>
class AddManySparseToTensorsMapOp : public SparseTensorAccessingOp {
 public:
  explicit AddManySparseToTensorsMapOp(OpKernelConstruction* context)
      : SparseTensorAccessingOp(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* input_indices;
    const Tensor* input_values;
    const Tensor* input_shape;
    SparseTensorsMap* map;

    OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
    OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
    OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
    OP_REQUIRES_OK(context, GetMap(context, true /* is_writing */, &map));

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
                errors::InvalidArgument(
                    "Input indices should be a matrix but received shape ",
                    input_indices->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
                errors::InvalidArgument(
                    "Input values should be a vector but received shape ",
                    input_values->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
                errors::InvalidArgument(
                    "Input shape should be a vector but received shape ",
                    input_shape->shape().DebugString()));

    int rank = input_shape->NumElements();

    OP_REQUIRES(
        context, rank > 1,
        errors::InvalidArgument(
            "Rank of input SparseTensor should be > 1, but saw rank: ", rank));

    TensorShape tensor_input_shape(input_shape->vec<int64>());
    gtl::InlinedVector<int64, 8> std_order(rank);
    std::iota(std_order.begin(), std_order.end(), 0);
    SparseTensor input_st;
    OP_REQUIRES_OK(context, SparseTensor::Create(*input_indices, *input_values,
                                                 tensor_input_shape, std_order,
                                                 &input_st));

    auto input_shape_t = input_shape->vec<int64>();
    const int64 N = input_shape_t(0);

    Tensor sparse_handles(DT_INT64, TensorShape({N}));
    auto sparse_handles_t = sparse_handles.vec<int64>();

    OP_REQUIRES_OK(context, input_st.IndicesValid());

    // We can generate the output shape proto string now, for all
    // minibatch entries.
    TensorShape output_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                input_shape_t.data() + 1,
                                input_shape->NumElements() - 1, &output_shape));

    // Get groups by minibatch dimension
    std::unordered_set<int64> visited;
    sparse::GroupIterable minibatch = input_st.group({0});
    for (const auto& subset : minibatch) {
      const int64 b = subset.group()[0];
      visited.insert(b);
      OP_REQUIRES(
          context, b > -1 && b < N,
          errors::InvalidArgument(
              "Received unexpected column 0 value in input SparseTensor: ", b,
              " < 0 or >= N (= ", N, ")"));

      const auto indices = subset.indices();
      const auto values = subset.values<T>();
      const int64 num_entries = values.size();

      Tensor output_indices = Tensor(DT_INT64, {num_entries, rank - 1});
      Tensor output_values = Tensor(DataTypeToEnum<T>::value, {num_entries});

      auto output_indices_t = output_indices.matrix<int64>();
      auto output_values_t = output_values.vec<T>();

      for (int i = 0; i < num_entries; ++i) {
        for (int d = 1; d < rank; ++d) {
          output_indices_t(i, d - 1) = indices(i, d);
        }
        output_values_t(i) = values(i);
      }

      SparseTensor st_i;
      OP_REQUIRES_OK(context,
                     SparseTensor::Create(output_indices, output_values,
                                          output_shape, &st_i));
      int64 handle;
      OP_REQUIRES_OK(context, map->AddSparseTensor(context, st_i, &handle));
      sparse_handles_t(b) = handle;
    }

    // Fill in any gaps; we must provide an empty ST for batch entries
    // the grouper didn't find.
    if (visited.size() < N) {
      Tensor empty_indices(DT_INT64, {0, rank - 1});
      Tensor empty_values(DataTypeToEnum<T>::value, {0});
      SparseTensor empty_st;
      OP_REQUIRES_OK(context, SparseTensor::Create(empty_indices, empty_values,
                                                   output_shape, &empty_st));

      for (int64 b = 0; b < N; ++b) {
        // We skipped this batch entry.
        if (visited.find(b) == visited.end()) {
          int64 handle;
          OP_REQUIRES_OK(context,
                         map->AddSparseTensor(context, empty_st, &handle));
          sparse_handles_t(b) = handle;
        }
      }
    }

    context->set_output(0, sparse_handles);
  }
};

// Registers AddManySparseToTensorsMapOp<type> as the CPU kernel for the
// "AddManySparseToTensorsMap" op constrained to attr "T" == type. The macro is
// handed to TF_CALL_ALL_TYPES and undefined right after use.
#define REGISTER_KERNELS(type)                              \
  REGISTER_KERNEL_BUILDER(Name("AddManySparseToTensorsMap") \
                              .Device(DEVICE_CPU)           \
                              .TypeConstraint<type>("T"),   \
                          AddManySparseToTensorsMapOp<type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

template <typename T>
class TakeManySparseFromTensorsMapOp : public SparseTensorAccessingOp {
 public:
  explicit TakeManySparseFromTensorsMapOp(OpKernelConstruction* context)
      : SparseTensorAccessingOp(context) {}

  void Compute(OpKernelContext* context) override {
    SparseTensorsMap* map = nullptr;
    OP_REQUIRES_OK(context, GetMap(context, false /* is_writing */, &map));

    const Tensor& sparse_handles = context->input(0);

    OP_REQUIRES(context, TensorShapeUtils::IsVector(sparse_handles.shape()),
                errors::InvalidArgument(
                    "sparse_handles should be a vector but received shape ",
                    sparse_handles.shape().DebugString()));

    int64 N = sparse_handles.shape().dim_size(0);

    OP_REQUIRES(
        context, N > 0,
        errors::InvalidArgument("Must have at least 1 serialized SparseTensor, "
                                "but input matrix has 0 rows"));

    std::vector<Tensor> indices_to_concat;
    std::vector<Tensor> values_to_concat;
    std::vector<TensorShape> shapes_to_concat;

    const auto& sparse_handles_t = sparse_handles.vec<int64>();

    std::vector<SparseTensor> sparse_tensors;

    OP_REQUIRES_OK(context, map->RetrieveAndClearSparseTensors(
                                context, sparse_handles_t, &sparse_tensors));

    for (int64 i = 0; i < N; ++i) {
      const SparseTensor& st = sparse_tensors[i];
      const Tensor& output_indices = st.indices();
      const Tensor& output_values = st.values();
      const auto output_shape = st.shape();

      OP_REQUIRES(context, TensorShapeUtils::IsMatrix(output_indices.shape()),
                  errors::InvalidArgument(
                      "Expected sparse_handles[", i,
                      "] to represent an index matrix but received shape ",
                      output_indices.shape().DebugString()));
      OP_REQUIRES(context, TensorShapeUtils::IsVector(output_values.shape()),
                  errors::InvalidArgument(
                      "Expected sparse_handles[", i,
                      "] to represent a values vector but received shape ",
                      output_values.shape().DebugString()));
      OP_REQUIRES(
          context, DataTypeToEnum<T>::value == output_values.dtype(),
          errors::InvalidArgument(
              "Requested SparseTensor of type ",
              DataTypeString(DataTypeToEnum<T>::value), " but SparseTensor[", i,
              "].values.dtype() == ", DataTypeString(output_values.dtype())));

      int64 num_entries = output_indices.dim_size(0);
      OP_REQUIRES(context, num_entries == output_values.dim_size(0),
                  errors::InvalidArgument(
                      "Expected row counts of SparseTensor[", i,
                      "].indices and SparseTensor[", i,
                      "].values to match but they do not: ", num_entries,
                      " vs. ", output_values.dim_size(0)));
      int rank = output_indices.dim_size(1);
      OP_REQUIRES(
          context, rank == output_shape.size(),
          errors::InvalidArgument("Expected column counts of SparseTensor[", i,
                                  "].indices to match size of SparseTensor[", i,
                                  "].shape "
                                  "but they do not: ",
                                  rank, " vs. ", output_shape.size()));

      // Now we expand each SparseTensors' indices and shape by
      // prefixing a dimension
      Tensor expanded_indices(
          DT_INT64, TensorShape({num_entries, 1 + output_indices.dim_size(1)}));
      Tensor expanded_shape(DT_INT64, TensorShape({1 + rank}));
      const auto& output_indices_t = output_indices.matrix<int64>();
      auto expanded_indices_t = expanded_indices.matrix<int64>();
      auto expanded_shape_t = expanded_shape.vec<int64>();
      expanded_indices_t.chip<1>(0).setZero();
      Eigen::DSizes<Eigen::DenseIndex, 2> indices_start(0, 1);
      Eigen::DSizes<Eigen::DenseIndex, 2> indices_sizes(num_entries, rank);
      expanded_indices_t.slice(indices_start, indices_sizes) = output_indices_t;
      expanded_shape_t(0) = 1;
      // TODO: copy shape from TensorShape to &expanded_shape_t(1)
      // std::copy_n(&output_shape_t(0), rank, &expanded_shape_t(1));
      for (int i = 0; i < rank; ++i) {
        expanded_shape_t(i + 1) = output_shape[i];
      }
      TensorShape expanded_tensor_shape(expanded_shape_t);

      indices_to_concat.push_back(std::move(expanded_indices));
      values_to_concat.push_back(output_values);
      shapes_to_concat.push_back(std::move(expanded_tensor_shape));
    }

    int rank = -1;
    for (int i = 0; i < N; ++i) {
      if (rank < 0) rank = shapes_to_concat[i].dims();
      OP_REQUIRES(context, rank == shapes_to_concat[i].dims(),
                  errors::InvalidArgument(
                      "Inconsistent rank across SparseTensors: rank prior to "
                      "SparseTensor[",
                      i, "] was: ", rank, " but rank of SparseTensor[", i,
                      "] is: ", shapes_to_concat[i].dims()));
    }

    // SparseTensor::Concat requires consistent shape for all but the
    // primary order dimension (dimension 0 in this case).  So we get
    // the maximum value across all the input SparseTensors for each
    // dimension and use that.
    TensorShape preconcat_shape(shapes_to_concat[0]);
    for (int i = 0; i < N; ++i) {
      for (int d = 0; d < rank; ++d) {
        preconcat_shape.set_dim(d, std::max(preconcat_shape.dim_size(d),
                                            shapes_to_concat[i].dim_size(d)));
      }
    }

    // Dimension 0 is the primary dimension.
    gtl::InlinedVector<int64, 8> std_order(rank);
    std::iota(std_order.begin(), std_order.end(), 0);

    std::vector<SparseTensor> tensors_to_concat;
    tensors_to_concat.reserve(N);
    for (int i = 0; i < N; ++i) {
      SparseTensor tensor;
      OP_REQUIRES_OK(context,
                     SparseTensor::Create(std::move(indices_to_concat[i]),
                                          std::move(values_to_concat[i]),
                                          preconcat_shape, std_order, &tensor));
      tensors_to_concat.push_back(std::move(tensor));
    }

    auto output = SparseTensor::Concat<T>(tensors_to_concat);
    Tensor final_output_shape(DT_INT64, TensorShape({output.dims()}));

    std::copy_n(output.shape().data(), output.dims(),
                final_output_shape.vec<int64>().data());

    context->set_output(0, output.indices());
    context->set_output(1, output.values());
    context->set_output(2, final_output_shape);
  }
};

// Registers TakeManySparseFromTensorsMapOp<type> as the CPU kernel for the
// "TakeManySparseFromTensorsMap" op constrained to attr "dtype" == type. The
// macro is handed to TF_CALL_ALL_TYPES and undefined right after use.
#define REGISTER_KERNELS(type)                                 \
  REGISTER_KERNEL_BUILDER(Name("TakeManySparseFromTensorsMap") \
                              .Device(DEVICE_CPU)              \
                              .TypeConstraint<type>("dtype"),  \
                          TakeManySparseFromTensorsMapOp<type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

}  // namespace tensorflow