aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/sparse/group_iterator.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/util/sparse/group_iterator.cc')
-rw-r--r--tensorflow/core/util/sparse/group_iterator.cc49
1 files changed, 49 insertions, 0 deletions
diff --git a/tensorflow/core/util/sparse/group_iterator.cc b/tensorflow/core/util/sparse/group_iterator.cc
new file mode 100644
index 0000000000..e153bcdbb4
--- /dev/null
+++ b/tensorflow/core/util/sparse/group_iterator.cc
@@ -0,0 +1,49 @@
+#include "tensorflow/core/util/sparse/group_iterator.h"
+
+namespace tensorflow {
+namespace sparse {
+
+void GroupIterable::IteratorStep::UpdateEndOfGroup() {
+ ++next_loc_;
+ int64 N = iter_->ix_.dim_size(0);
+ auto ix_t = iter_->ix_.template matrix<int64>();
+ while (next_loc_ < N && iter_->GroupMatches(ix_t, loc_, next_loc_)) {
+ ++next_loc_;
+ }
+}
+
+bool GroupIterable::IteratorStep::operator!=(const IteratorStep& rhs) const {
+ CHECK_EQ(rhs.iter_, iter_) << "Can't compare steps from different iterators";
+ return (rhs.loc_ != loc_);
+}
+
+GroupIterable::IteratorStep& GroupIterable::IteratorStep::
+operator++() { // prefix ++
+ loc_ = next_loc_;
+ UpdateEndOfGroup();
+ return *this;
+}
+
+GroupIterable::IteratorStep GroupIterable::IteratorStep::operator++(
+ int) { // postfix ++
+ IteratorStep lhs(*this);
+ ++(*this);
+ return lhs;
+}
+
+std::vector<int64> Group::group() const {
+ std::vector<int64> g;
+ auto ix_t = iter_->ix_.template matrix<int64>();
+ for (const int d : iter_->group_dims_) {
+ g.push_back(ix_t(loc_, d));
+ }
+ return g;
+}
+
+TTypes<int64>::UnalignedConstMatrix Group::indices() const {
+ return TTypes<int64>::UnalignedConstMatrix(
+ &(iter_->ix_.matrix<int64>()(loc_, 0)), next_loc_ - loc_, iter_->dims_);
+}
+
+} // namespace sparse
+} // namespace tensorflow