diff options
author | Derek Murray <mrry@google.com> | 2018-09-12 22:11:34 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-12 22:15:49 -0700 |
commit | 845aaec5ec2191f2708247a09d9bad37f012f536 (patch) | |
tree | 9f3fec49c0d016ba747c6c3b570c73d27193b853 /tensorflow/core/util | |
parent | 725dfe9cd0eef3f4b858eaeda38728813c99a210 (diff) |
[SparseTensor] Avoid calling `Tensor::matrix<int64>()` for each element of a SparseTensor when iterating over it.
PiperOrigin-RevId: 212758856
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r-- | tensorflow/core/util/sparse/group_iterator.cc | 10 | ||||
-rw-r--r-- | tensorflow/core/util/sparse/group_iterator.h | 4 |
2 files changed, 8 insertions, 6 deletions
diff --git a/tensorflow/core/util/sparse/group_iterator.cc b/tensorflow/core/util/sparse/group_iterator.cc index 204b933051..546b0a833c 100644 --- a/tensorflow/core/util/sparse/group_iterator.cc +++ b/tensorflow/core/util/sparse/group_iterator.cc @@ -21,8 +21,8 @@ namespace sparse { void GroupIterable::IteratorStep::UpdateEndOfGroup() { ++next_loc_; - int64 N = iter_->ix_.dim_size(0); - auto ix_t = iter_->ix_.template matrix<int64>(); + const auto& ix_t = iter_->ix_matrix_; + const int64 N = ix_t.dimension(0); while (next_loc_ < N && iter_->GroupMatches(ix_t, loc_, next_loc_)) { ++next_loc_; } @@ -54,7 +54,7 @@ GroupIterable::IteratorStep GroupIterable::IteratorStep::operator++( std::vector<int64> Group::group() const { std::vector<int64> g; - auto ix_t = iter_->ix_.template matrix<int64>(); + const auto& ix_t = iter_->ix_matrix_; for (const int d : iter_->group_dims_) { g.push_back(ix_t(loc_, d)); } @@ -62,8 +62,8 @@ std::vector<int64> Group::group() const { } TTypes<int64>::UnalignedConstMatrix Group::indices() const { - return TTypes<int64>::UnalignedConstMatrix( - &(iter_->ix_.matrix<int64>()(loc_, 0)), next_loc_ - loc_, iter_->dims_); + return TTypes<int64>::UnalignedConstMatrix(&(iter_->ix_matrix_(loc_, 0)), + next_loc_ - loc_, iter_->dims_); } } // namespace sparse diff --git a/tensorflow/core/util/sparse/group_iterator.h b/tensorflow/core/util/sparse/group_iterator.h index 3fa8cb6116..14610c61d9 100644 --- a/tensorflow/core/util/sparse/group_iterator.h +++ b/tensorflow/core/util/sparse/group_iterator.h @@ -79,6 +79,7 @@ class GroupIterable { GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims) : ix_(ix), + ix_matrix_(ix_.matrix<int64>()), vals_(vals), dims_(dims), group_dims_(group_dims.begin(), group_dims.end()) {} @@ -127,7 +128,8 @@ class GroupIterable { private: friend class Group; - Tensor ix_; + const Tensor ix_; + const TTypes<int64>::ConstMatrix ix_matrix_; Tensor vals_; const int dims_; const gtl::InlinedVector<int64, 8> group_dims_; |