// path: root/tensorflow/core/util/sparse/group_iterator.cc
// blob: e153bcdbb4505bef30ac5ad2a02ba35429e13293
#include "tensorflow/core/util/sparse/group_iterator.h"

namespace tensorflow {
namespace sparse {

void GroupIterable::IteratorStep::UpdateEndOfGroup() {
  ++next_loc_;
  int64 N = iter_->ix_.dim_size(0);
  auto ix_t = iter_->ix_.template matrix<int64>();
  while (next_loc_ < N && iter_->GroupMatches(ix_t, loc_, next_loc_)) {
    ++next_loc_;
  }
}

// Two steps differ when their loc_ values differ. Comparing steps whose iter_
// pointers are not equal is a CHECK failure.
bool GroupIterable::IteratorStep::operator!=(const IteratorStep& rhs) const {
  CHECK_EQ(rhs.iter_, iter_) << "Can't compare steps from different iterators";
  const bool same_location = (loc_ == rhs.loc_);
  return !same_location;
}

// Prefix increment: steps to the next group and returns this step.
GroupIterable::IteratorStep& GroupIterable::IteratorStep::operator++() {
  // The end of the current group becomes the start of the next one ...
  loc_ = next_loc_;
  // ... and next_loc_ is then advanced again from that new start.
  UpdateEndOfGroup();
  return *this;
}

// Postfix increment: advances this step to the next group and returns a copy
// of the step as it was before advancing.
GroupIterable::IteratorStep GroupIterable::IteratorStep::operator++(int) {
  IteratorStep before_advance(*this);
  operator++();  // Delegate to the prefix form.
  return before_advance;
}

std::vector<int64> Group::group() const {
  std::vector<int64> g;
  auto ix_t = iter_->ix_.template matrix<int64>();
  for (const int d : iter_->group_dims_) {
    g.push_back(ix_t(loc_, d));
  }
  return g;
}

// Returns the index rows belonging to this group as a matrix with
// (next_loc_ - loc_) rows and dims_ columns. The matrix is constructed from a
// pointer to element (loc_, 0) of ix_, so it presumably aliases ix_'s storage
// rather than copying it -- confirm against the TTypes definition.
TTypes<int64>::UnalignedConstMatrix Group::indices() const {
  auto all_indices = iter_->ix_.matrix<int64>();
  const int64* group_start = &(all_indices(loc_, 0));
  const auto rows_in_group = next_loc_ - loc_;
  return TTypes<int64>::UnalignedConstMatrix(group_start, rows_in_group,
                                             iter_->dims_);
}

}  // namespace sparse
}  // namespace tensorflow