aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/sparse/group_iterator.h
blob: 3fa8cb6116f76839e640746ad2c7f097dd672781 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_
#define TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_

#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace sparse {

class GroupIterable;  // Predeclare GroupIterable for Group.

// This class is returned when dereferencing a GroupIterable iterator.
// It provides the methods group(), indices(), and values(), which
// provide access into the underlying SparseTensor.
class Group {
 public:
  // Builds a group view.
  //   iter:     the GroupIterable this group refers to.  Only the raw pointer
  //             is stored; presumably the iterable must outlive the Group --
  //             TODO confirm against callers.
  //   loc:      first location of the group (used as the start offset by
  //             values<T>()).
  //   next_loc: location one past the group; values<T>() exposes
  //             next_loc - loc elements.
  Group(GroupIterable* iter, int64 loc, int64 next_loc)
      : iter_(iter), loc_(loc), next_loc_(next_loc) {}

  // Declared here, defined out of line (not visible in this header).
  std::vector<int64> group() const;
  // Declared here, defined out of line (not visible in this header).
  TTypes<int64>::UnalignedConstMatrix indices() const;
  // Returns an unaligned vector view of length next_loc - loc over the
  // iterable's values tensor, starting at loc.  Defined at the bottom of this
  // header because it is a template.
  template <typename T>
  typename TTypes<T>::UnalignedVec values() const;

 private:
  GroupIterable* iter_;  // Source iterable; accessed via friendship.
  int64 loc_;            // Start location of this group.
  int64 next_loc_;       // One past the last location of this group.
};

/////////////////
// GroupIterable
/////////////////
//
// Returned when calling sparse_tensor.group({dim0, dim1, ...}).
//
// Please note: the sparse_tensor should already be ordered according
// to {dim0, dim1, ...}.  Otherwise this iteration will return invalid groups.
//
// Allows grouping and iteration of the SparseTensor according to the
// subset of dimensions provided to the group call.
//
// The actual grouping dimensions are stored in the
// internal vector group_dims_.  Dereferencing an iterator of the iterable
// yields a Group, which provides three methods:
//
// *  group(): returns a vector with the current group dimension values.
// *  indices(): returns a matrix view of the indices in
//    this group.
// *  values<T>(): returns a vector view of the values in
//    this group.
//
// To iterate across GroupIterable, see examples in README.md.
//

// GroupIterable holds the index and value tensors together with the grouping
// dimensions, and hands out IteratorStep objects positioned at a location
// along dimension 0 of the index tensor.
class GroupIterable {
 public:
  typedef gtl::ArraySlice<int64> VarDimArray;

  // ix:         index tensor; its dimension 0 defines the valid locations
  //             [0, ix.dim_size(0)] accepted by at() and used by end().
  // vals:       values tensor, exposed per group through Group::values<T>().
  // dims:       stored as-is in dims_ (not used inside this header).
  // group_dims: dimensions to group by; copied into group_dims_.
  GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims)
      : ix_(ix),
        vals_(vals),
        dims_(dims),
        group_dims_(group_dims.begin(), group_dims.end()) {}

  class IteratorStep;

  // Iterator positioned at location 0.
  IteratorStep begin() { return IteratorStep(this, 0); }

  // Iterator positioned at `loc`.  CHECK-fails unless
  // 0 <= loc <= ix_.dim_size(0); loc == ix_.dim_size(0) is the end position.
  IteratorStep at(int64 loc) {
    CHECK(loc >= 0 && loc <= ix_.dim_size(0))
        << "loc provided must lie between 0 and " << ix_.dim_size(0);
    return IteratorStep(this, loc);
  }

  // Iterator positioned one past the last location.
  IteratorStep end() { return IteratorStep(this, ix_.dim_size(0)); }

  // Returns true iff rows loc_a and loc_b of `ix` agree on every dimension
  // listed in group_dims_.  `ix` only needs to support ix(row, dim) element
  // access.  Stops at the first mismatching dimension instead of scanning the
  // remaining ones; the result is unchanged.
  template <typename TIX>
  inline bool GroupMatches(const TIX& ix, int64 loc_a, int64 loc_b) const {
    for (const int d : group_dims_) {
      if (ix(loc_a, d) != ix(loc_b, d)) {
        return false;
      }
    }
    return true;
  }

  // A single position of the iteration.  Dereferencing yields the Group
  // spanning [loc_, next_loc_).
  class IteratorStep {
   public:
    // Initializes both loc_ and next_loc_ from the constructor argument (not
    // from the member loc_, so correctness does not hinge on member
    // declaration order), then lets UpdateEndOfGroup() -- defined out of
    // line -- adjust the end of the group.
    IteratorStep(GroupIterable* iter, int64 loc)
        : iter_(iter), loc_(loc), next_loc_(loc) {
      UpdateEndOfGroup();
    }

    void UpdateEndOfGroup();
    bool operator!=(const IteratorStep& rhs) const;
    bool operator==(const IteratorStep& rhs) const;
    IteratorStep& operator++();    // prefix ++
    IteratorStep operator++(int);  // postfix ++
    Group operator*() const { return Group(iter_, loc_, next_loc_); }
    int64 loc() const { return loc_; }

   private:
    GroupIterable* iter_;  // Parent iterable (raw pointer, stored as given).
    int64 loc_;            // Start location of the current group.
    int64 next_loc_;       // One past the current group.
  };

 private:
  friend class Group;
  Tensor ix_;
  Tensor vals_;
  const int dims_;
  const gtl::InlinedVector<int64, 8> group_dims_;
};

// Implementation of Group::values<T>()
template <typename T>
typename TTypes<T>::UnalignedVec Group::values() const {
  // Flat view over all values held by the parent iterable; only used to take
  // the address of this group's first element.
  auto all_values = iter_->vals_.vec<T>();
  // Number of entries in the half-open range [loc_, next_loc_).
  const int64 group_size = next_loc_ - loc_;
  T* const first_value = &(all_values(loc_));
  return typename TTypes<T>::UnalignedVec(first_value, group_size);
}

}  // namespace sparse
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_