diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2016-09-27 09:45:54 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-09-27 11:02:26 -0700 |
commit | 1a295f88ecc387c03b9de28b4c948c1d2a2c331e (patch) | |
tree | c5c2060354b5fdec36806af430c21b13fb6696a1 /tensorflow/core/util/sparse/dim_comparator.h | |
parent | 86226d50154b004eca3451cd4839e782f6bf2664 (diff) |
Add benchmark for util/SparseTensor's SparseReorder, and optimize SparseTensor::Reorder.
Change: 134425452
Diffstat (limited to 'tensorflow/core/util/sparse/dim_comparator.h')
-rw-r--r-- | tensorflow/core/util/sparse/dim_comparator.h | 34 |
1 files changed, 29 insertions, 5 deletions
diff --git a/tensorflow/core/util/sparse/dim_comparator.h b/tensorflow/core/util/sparse/dim_comparator.h index 8912186157..0c7412d2db 100644 --- a/tensorflow/core/util/sparse/dim_comparator.h +++ b/tensorflow/core/util/sparse/dim_comparator.h @@ -46,14 +46,14 @@ class DimComparator { public: typedef typename gtl::ArraySlice<int64> VarDimArray; - inline DimComparator(const TTypes<int64>::Matrix& ix, - const VarDimArray& order, int dims) - : ix_(ix), order_(order), dims_(dims) { + DimComparator(const TTypes<int64>::Matrix& ix, const VarDimArray& order, + const TensorShape& shape) + : ix_(ix), order_(order), dims_(shape.dims()) { CHECK_GT(order.size(), size_t{0}) << "Must order using at least one index"; - CHECK_LE(order.size(), dims_) << "Can only sort up to dims"; + CHECK_LE(order.size(), shape.dims()) << "Can only sort up to dims"; for (size_t d = 0; d < order.size(); ++d) { CHECK_GE(order[d], 0); - CHECK_LT(order[d], dims); + CHECK_LT(order[d], shape.dims()); } } @@ -84,9 +84,33 @@ class DimComparator { return 0; } + protected: const TTypes<int64>::Matrix ix_; const VarDimArray order_; const int dims_; + const std::vector<int64>* ix_order_; +}; + +template <int ORDER_DIM> +class FixedDimComparator : DimComparator { + public: + FixedDimComparator(const TTypes<int64>::Matrix& ix, const VarDimArray& order, + const TensorShape& shape) + : DimComparator(ix, order, shape) { + CHECK_EQ(order.size(), ORDER_DIM); + } + inline bool operator()(const int64 i, const int64 j) const { + bool value = false; + for (int di = 0; di < ORDER_DIM; ++di) { + const int64 d = order_[di]; + if (ix_(i, d) < ix_(j, d)) { + value = true; + break; + } + if (ix_(i, d) > ix_(j, d)) break; + } + return value; + } }; } // namespace sparse |