aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/sparse/group_iterator.cc
diff options
context:
space:
mode:
authorGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-06 16:27:58 -0800
committerGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-06 16:27:58 -0800
commitf41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch)
treeef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/core/util/sparse/group_iterator.cc
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation using data flow graphs. Base CL: 107276108
Diffstat (limited to 'tensorflow/core/util/sparse/group_iterator.cc')
-rw-r--r--tensorflow/core/util/sparse/group_iterator.cc49
1 files changed, 49 insertions, 0 deletions
diff --git a/tensorflow/core/util/sparse/group_iterator.cc b/tensorflow/core/util/sparse/group_iterator.cc
new file mode 100644
index 0000000000..e153bcdbb4
--- /dev/null
+++ b/tensorflow/core/util/sparse/group_iterator.cc
@@ -0,0 +1,49 @@
+#include "tensorflow/core/util/sparse/group_iterator.h"
+
+namespace tensorflow {
+namespace sparse {
+
+void GroupIterable::IteratorStep::UpdateEndOfGroup() {
+ ++next_loc_;
+ int64 N = iter_->ix_.dim_size(0);
+ auto ix_t = iter_->ix_.template matrix<int64>();
+ while (next_loc_ < N && iter_->GroupMatches(ix_t, loc_, next_loc_)) {
+ ++next_loc_;
+ }
+}
+
+bool GroupIterable::IteratorStep::operator!=(const IteratorStep& rhs) const {
+ CHECK_EQ(rhs.iter_, iter_) << "Can't compare steps from different iterators";
+ return (rhs.loc_ != loc_);
+}
+
+GroupIterable::IteratorStep& GroupIterable::IteratorStep::
+operator++() { // prefix ++
+ loc_ = next_loc_;
+ UpdateEndOfGroup();
+ return *this;
+}
+
+GroupIterable::IteratorStep GroupIterable::IteratorStep::operator++(
+ int) { // postfix ++
+ IteratorStep lhs(*this);
+ ++(*this);
+ return lhs;
+}
+
+std::vector<int64> Group::group() const {
+ std::vector<int64> g;
+ auto ix_t = iter_->ix_.template matrix<int64>();
+ for (const int d : iter_->group_dims_) {
+ g.push_back(ix_t(loc_, d));
+ }
+ return g;
+}
+
+TTypes<int64>::UnalignedConstMatrix Group::indices() const {
+ return TTypes<int64>::UnalignedConstMatrix(
+ &(iter_->ix_.matrix<int64>()(loc_, 0)), next_loc_ - loc_, iter_->dims_);
+}
+
+} // namespace sparse
+} // namespace tensorflow