1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
|
#ifndef TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_
#define TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_
#include <vector>

#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/status.h"
#include "tensorflow/core/public/tensor.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
namespace tensorflow {
namespace sparse {
// Group keeps a pointer back to the GroupIterable that produced it;
// GroupIterable itself is defined further down in this header.
class GroupIterable;

// One group produced by dereferencing a GroupIterable iterator.
//
// A Group spans the half-open range [loc_, next_loc_) of entries in the
// underlying SparseTensor and gives access to it through:
//   * group():   the values of the grouping dimensions for this group.
//   * indices(): a map over the index rows that belong to this group.
//   * values():  a map over the values that belong to this group.
//
// A Group owns nothing: it refers to data held by the GroupIterable it
// points at, so that GroupIterable must outlive the Group.
class Group {
 public:
  Group(GroupIterable* iter, int64 loc, int64 next_loc)
      : iter_{iter}, loc_{loc}, next_loc_{next_loc} {}

  std::vector<int64> group() const;
  TTypes<int64>::UnalignedConstMatrix indices() const;

  template <typename T>
  typename TTypes<T>::UnalignedVec values() const;

 private:
  GroupIterable* iter_;  // Not owned.
  int64 loc_;            // First entry of this group.
  int64 next_loc_;       // One past the last entry of this group.
};
/////////////////
// GroupIterable
/////////////////
//
// Returned when calling sparse_tensor.group({dim0, dim1, ...}).
//
// Please note: the sparse_tensor should already be ordered according
// to {dim0, dim1, ...}. Otherwise this iteration will return invalid groups.
//
// Allows grouping and iteration of the SparseTensor according to the
// subset of dimensions provided to the group call.
//
// The actual grouping dimensions are stored in the
// internal vector group_dims_. Iterators inside the iterable provide
// the three methods:
//
// * group(): returns a vector with the current group dimension values.
// * indices(): a map of index, providing the indices in
// this group.
// * values(): a map of values, providing the values in
// this group.
//
// To iterate across GroupIterable, see examples in README.md.
//
// Forward declaration of SparseTensor
class GroupIterable {
public:
typedef gtl::ArraySlice<int64> VarDimArray;
GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims)
: ix_(ix), vals_(vals), dims_(dims), group_dims_(group_dims) {}
class IteratorStep;
IteratorStep begin() { return IteratorStep(this, 0); }
IteratorStep end() { return IteratorStep(this, ix_.dim_size(0)); }
template <typename TIX>
inline bool GroupMatches(const TIX& ix, int64 loc_a, int64 loc_b) const {
bool matches = true;
for (int d : group_dims_) {
if (ix(loc_a, d) != ix(loc_b, d)) {
matches = false;
}
}
return matches;
}
class IteratorStep {
public:
IteratorStep(GroupIterable* iter, int64 loc)
: iter_(iter), loc_(loc), next_loc_(loc_) {
UpdateEndOfGroup();
}
void UpdateEndOfGroup();
bool operator!=(const IteratorStep& rhs) const;
IteratorStep& operator++(); // prefix ++
IteratorStep operator++(int); // postfix ++
Group operator*() const { return Group(iter_, loc_, next_loc_); }
private:
GroupIterable* iter_;
int64 loc_;
int64 next_loc_;
};
private:
friend class Group;
Tensor ix_;
Tensor vals_;
const int dims_;
const VarDimArray group_dims_;
};
// Implementation of Group::values<T>()
template <typename T>
typename TTypes<T>::UnalignedVec Group::values() const {
return typename TTypes<T>::UnalignedVec(&(iter_->vals_.vec<T>()(loc_)),
next_loc_ - loc_);
}
} // namespace sparse
} // namespace tensorflow
#endif // TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_
|