aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-09-12 22:11:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-12 22:15:49 -0700
commit845aaec5ec2191f2708247a09d9bad37f012f536 (patch)
tree9f3fec49c0d016ba747c6c3b570c73d27193b853 /tensorflow/core/util
parent725dfe9cd0eef3f4b858eaeda38728813c99a210 (diff)
[SparseTensor] Avoid calling `Tensor::matrix<int64>()` for each element of a SparseTensor when iterating over it.
PiperOrigin-RevId: 212758856
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r--tensorflow/core/util/sparse/group_iterator.cc10
-rw-r--r--tensorflow/core/util/sparse/group_iterator.h4
2 files changed, 8 insertions, 6 deletions
diff --git a/tensorflow/core/util/sparse/group_iterator.cc b/tensorflow/core/util/sparse/group_iterator.cc
index 204b933051..546b0a833c 100644
--- a/tensorflow/core/util/sparse/group_iterator.cc
+++ b/tensorflow/core/util/sparse/group_iterator.cc
@@ -21,8 +21,8 @@ namespace sparse {
void GroupIterable::IteratorStep::UpdateEndOfGroup() {
++next_loc_;
- int64 N = iter_->ix_.dim_size(0);
- auto ix_t = iter_->ix_.template matrix<int64>();
+ const auto& ix_t = iter_->ix_matrix_;
+ const int64 N = ix_t.dimension(0);
while (next_loc_ < N && iter_->GroupMatches(ix_t, loc_, next_loc_)) {
++next_loc_;
}
@@ -54,7 +54,7 @@ GroupIterable::IteratorStep GroupIterable::IteratorStep::operator++(
std::vector<int64> Group::group() const {
std::vector<int64> g;
- auto ix_t = iter_->ix_.template matrix<int64>();
+ const auto& ix_t = iter_->ix_matrix_;
for (const int d : iter_->group_dims_) {
g.push_back(ix_t(loc_, d));
}
@@ -62,8 +62,8 @@ std::vector<int64> Group::group() const {
}
TTypes<int64>::UnalignedConstMatrix Group::indices() const {
- return TTypes<int64>::UnalignedConstMatrix(
- &(iter_->ix_.matrix<int64>()(loc_, 0)), next_loc_ - loc_, iter_->dims_);
+ return TTypes<int64>::UnalignedConstMatrix(&(iter_->ix_matrix_(loc_, 0)),
+ next_loc_ - loc_, iter_->dims_);
}
} // namespace sparse
diff --git a/tensorflow/core/util/sparse/group_iterator.h b/tensorflow/core/util/sparse/group_iterator.h
index 3fa8cb6116..14610c61d9 100644
--- a/tensorflow/core/util/sparse/group_iterator.h
+++ b/tensorflow/core/util/sparse/group_iterator.h
@@ -79,6 +79,7 @@ class GroupIterable {
GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims)
: ix_(ix),
+ ix_matrix_(ix_.matrix<int64>()),
vals_(vals),
dims_(dims),
group_dims_(group_dims.begin(), group_dims.end()) {}
@@ -127,7 +128,8 @@ class GroupIterable {
private:
friend class Group;
- Tensor ix_;
+ const Tensor ix_;
+ const TTypes<int64>::ConstMatrix ix_matrix_;
Tensor vals_;
const int dims_;
const gtl::InlinedVector<int64, 8> group_dims_;