diff options
Diffstat (limited to 'tensorflow/core/util/sparse/dim_comparator.h')
-rw-r--r-- | tensorflow/core/util/sparse/dim_comparator.h | 60 |
1 files changed, 60 insertions, 0 deletions
diff --git a/tensorflow/core/util/sparse/dim_comparator.h b/tensorflow/core/util/sparse/dim_comparator.h new file mode 100644 index 0000000000..57473867cf --- /dev/null +++ b/tensorflow/core/util/sparse/dim_comparator.h @@ -0,0 +1,60 @@ +#ifndef TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_ +#define TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_ + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/logging.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace sparse { + +///////////////// +// DimComparator +///////////////// +// +// Helper class, mainly used by the IndexSortOrder. This comparator +// can be passed to e.g. std::sort, or any other sorter, to sort two +// rows of an index matrix according to the dimension(s) of interest. +// The dimensions to sort by are passed to the constructor as "order". +// +// Example: if given index matrix IX, two rows ai and bi, and order = {2,1}. +// operator() compares +// IX(ai,2) < IX(bi,2). +// If IX(ai,2) == IX(bi,2), it compares +// IX(ai,1) < IX(bi,1). +// +// This can be used to sort a vector of row indices into IX according to +// the values in IX in particular columns (dimensions) of interest. +class DimComparator { + public: + typedef typename gtl::ArraySlice<int64> VarDimArray; + + inline DimComparator(const TTypes<int64>::Matrix& ix, + const VarDimArray& order, int dims) + : ix_(ix), order_(order), dims_(dims) { + CHECK_GT(order.size(), 0) << "Must order using at least one index"; + CHECK_LE(order.size(), dims_) << "Can only sort up to dims"; + for (size_t d = 0; d < order.size(); ++d) { + CHECK_GE(order[d], 0); + CHECK_LT(order[d], dims); + } + } + + inline bool operator()(const int64 i, const int64 j) const { + for (int di = 0; di < dims_; ++di) { + const int64 d = order_[di]; + if (ix_(i, d) < ix_(j, d)) return true; + if (ix_(i, d) > ix_(j, d)) return false; + } + return false; + } + + const TTypes<int64>::Matrix ix_; + const VarDimArray order_; + const int dims_; +}; + +} // namespace sparse +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_ |