aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/sparse/dim_comparator.h
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2016-09-27 09:45:54 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-27 11:02:26 -0700
commit1a295f88ecc387c03b9de28b4c948c1d2a2c331e (patch)
treec5c2060354b5fdec36806af430c21b13fb6696a1 /tensorflow/core/util/sparse/dim_comparator.h
parent86226d50154b004eca3451cd4839e782f6bf2664 (diff)
Add benchmark for util/SparseTensor's SparseReorder, and optimize SparseTensor::Reorder.
Change: 134425452
Diffstat (limited to 'tensorflow/core/util/sparse/dim_comparator.h')
-rw-r--r--tensorflow/core/util/sparse/dim_comparator.h34
1 files changed, 29 insertions, 5 deletions
diff --git a/tensorflow/core/util/sparse/dim_comparator.h b/tensorflow/core/util/sparse/dim_comparator.h
index 8912186157..0c7412d2db 100644
--- a/tensorflow/core/util/sparse/dim_comparator.h
+++ b/tensorflow/core/util/sparse/dim_comparator.h
@@ -46,14 +46,14 @@ class DimComparator {
public:
typedef typename gtl::ArraySlice<int64> VarDimArray;
- inline DimComparator(const TTypes<int64>::Matrix& ix,
- const VarDimArray& order, int dims)
- : ix_(ix), order_(order), dims_(dims) {
+ DimComparator(const TTypes<int64>::Matrix& ix, const VarDimArray& order,
+ const TensorShape& shape)
+ : ix_(ix), order_(order), dims_(shape.dims()) {
CHECK_GT(order.size(), size_t{0}) << "Must order using at least one index";
- CHECK_LE(order.size(), dims_) << "Can only sort up to dims";
+ CHECK_LE(order.size(), shape.dims()) << "Can only sort up to dims";
for (size_t d = 0; d < order.size(); ++d) {
CHECK_GE(order[d], 0);
- CHECK_LT(order[d], dims);
+ CHECK_LT(order[d], shape.dims());
}
}
@@ -84,9 +84,33 @@ class DimComparator {
return 0;
}
+ protected:
const TTypes<int64>::Matrix ix_;
const VarDimArray order_;
const int dims_;
+ const std::vector<int64>* ix_order_;
+};
+
+template <int ORDER_DIM>
+class FixedDimComparator : DimComparator {
+ public:
+ FixedDimComparator(const TTypes<int64>::Matrix& ix, const VarDimArray& order,
+ const TensorShape& shape)
+ : DimComparator(ix, order, shape) {
+ CHECK_EQ(order.size(), ORDER_DIM);
+ }
+ inline bool operator()(const int64 i, const int64 j) const {
+ bool value = false;
+ for (int di = 0; di < ORDER_DIM; ++di) {
+ const int64 d = order_[di];
+ if (ix_(i, d) < ix_(j, d)) {
+ value = true;
+ break;
+ }
+ if (ix_(i, d) > ix_(j, d)) break;
+ }
+ return value;
+ }
};
} // namespace sparse