blob: e153bcdbb4505bef30ac5ad2a02ba35429e13293 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
|
#include "tensorflow/core/util/sparse/group_iterator.h"
namespace tensorflow {
namespace sparse {
void GroupIterable::IteratorStep::UpdateEndOfGroup() {
++next_loc_;
int64 N = iter_->ix_.dim_size(0);
auto ix_t = iter_->ix_.template matrix<int64>();
while (next_loc_ < N && iter_->GroupMatches(ix_t, loc_, next_loc_)) {
++next_loc_;
}
}
// Inequality test used when iterating over groups.
//
// Returns true when the two steps have different loc_ values. Both steps
// must refer to the same iter_; otherwise the CHECK below fails.
bool GroupIterable::IteratorStep::operator!=(const IteratorStep& rhs) const {
  CHECK_EQ(rhs.iter_, iter_) << "Can't compare steps from different iterators";
  const bool positions_differ = rhs.loc_ != loc_;
  return positions_differ;
}
// Prefix increment: moves this step forward and returns a reference to it.
GroupIterable::IteratorStep& GroupIterable::IteratorStep::operator++() {
  // The new position is wherever next_loc_ currently points.
  this->loc_ = this->next_loc_;
  // Recompute next_loc_ relative to the new loc_.
  this->UpdateEndOfGroup();
  return *this;
}
// Postfix increment: advances this step and returns a copy of the state it
// had before advancing.
GroupIterable::IteratorStep GroupIterable::IteratorStep::operator++(int) {
  IteratorStep previous(*this);  // Snapshot taken before advancing.
  operator++();                  // Delegate the actual advance to prefix ++.
  return previous;
}
std::vector<int64> Group::group() const {
std::vector<int64> g;
auto ix_t = iter_->ix_.template matrix<int64>();
for (const int d : iter_->group_dims_) {
g.push_back(ix_t(loc_, d));
}
return g;
}
// Returns the index rows belonging to this group.
//
// The result is constructed from a pointer to element (loc_, 0) of
// iter_->ix_ viewed as an int64 matrix, with next_loc_ - loc_ rows and
// iter_->dims_ columns.
TTypes<int64>::UnalignedConstMatrix Group::indices() const {
  auto index_matrix = iter_->ix_.matrix<int64>();
  const int64* group_start = &index_matrix(loc_, 0);
  const auto rows_in_group = next_loc_ - loc_;
  return TTypes<int64>::UnalignedConstMatrix(group_start, rows_in_group,
                                             iter_->dims_);
}
} // namespace sparse
} // namespace tensorflow
|