#include "tensorflow/core/util/sparse/group_iterator.h"

#include <vector>

namespace tensorflow {
namespace sparse {

// Advances next_loc_ to one past the last row of the group that starts at
// loc_. Rows are scanned in order beginning at loc_ + 1; the scan stops at the
// first row for which iter_->GroupMatches(ix_t, loc_, row) is false, or when
// it reaches the number of rows of the index tensor (dim 0 of iter_->ix_).
void GroupIterable::IteratorStep::UpdateEndOfGroup() {
  ++next_loc_;
  const int64 N = iter_->ix_.dim_size(0);
  // View the index tensor as an int64 matrix indexed by (row, column).
  auto ix_t = iter_->ix_.matrix<int64>();
  while (next_loc_ < N && iter_->GroupMatches(ix_t, loc_, next_loc_)) {
    ++next_loc_;
  }
}

// Two steps are unequal when they start at different rows (loc_). Comparing
// steps that belong to different GroupIterables is a programming error and
// CHECK-fails.
bool GroupIterable::IteratorStep::operator!=(const IteratorStep& rhs) const {
  CHECK_EQ(rhs.iter_, iter_) << "Can't compare steps from different iterators";
  return rhs.loc_ != loc_;
}

// Prefix increment: moves the start of the step to the end of the current
// group, then recomputes where that new group ends.
GroupIterable::IteratorStep& GroupIterable::IteratorStep::operator++() {
  loc_ = next_loc_;
  UpdateEndOfGroup();
  return *this;
}

// Postfix increment: advances this step and returns a copy of the state it
// had before advancing.
GroupIterable::IteratorStep GroupIterable::IteratorStep::operator++(int) {
  IteratorStep previous(*this);
  ++(*this);
  return previous;
}

// Returns the index values of the group's first row (loc_) at each grouping
// dimension, in the order the dimensions appear in iter_->group_dims_.
std::vector<int64> Group::group() const {
  std::vector<int64> g;
  g.reserve(iter_->group_dims_.size());
  auto ix_t = iter_->ix_.matrix<int64>();
  for (const int d : iter_->group_dims_) {
    g.push_back(ix_t(loc_, d));
  }
  return g;
}

// Returns a read-only view of rows [loc_, next_loc_) of the index matrix, with
// iter_->dims_ columns per row. Nothing is copied: the view is built from a
// pointer to element (loc_, 0), so it must not outlive iter_->ix_.
// NOTE(review): assumes the matrix rows are stored contiguously (row-major)
// so that consecutive rows follow element (loc_, 0) in memory — confirm.
TTypes<int64>::UnalignedConstMatrix Group::indices() const {
  return TTypes<int64>::UnalignedConstMatrix(
      &(iter_->ix_.matrix<int64>()(loc_, 0)), next_loc_ - loc_, iter_->dims_);
}

}  // namespace sparse
}  // namespace tensorflow