1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
|
#ifndef TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_
#define TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/platform/logging.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
namespace tensorflow {
namespace sparse {
/////////////////
// DimComparator
/////////////////
//
// Helper class, mainly used by the IndexSortOrder. This comparator
// can be passed to e.g. std::sort, or any other sorter, to sort two
// rows of an index matrix according to the dimension(s) of interest.
// The dimensions to sort by are passed to the constructor as "order".
//
// Example: if given index matrix IX, two rows ai and bi, and order = {2,1}.
// operator() compares
// IX(ai,2) < IX(bi,2).
// If IX(ai,2) == IX(bi,2), it compares
// IX(ai,1) < IX(bi,1).
//
// This can be used to sort a vector of row indices into IX according to
// the values in IX in particular columns (dimensions) of interest.
class DimComparator {
public:
typedef typename gtl::ArraySlice<int64> VarDimArray;
inline DimComparator(const TTypes<int64>::Matrix& ix,
const VarDimArray& order, int dims)
: ix_(ix), order_(order), dims_(dims) {
CHECK_GT(order.size(), 0) << "Must order using at least one index";
CHECK_LE(order.size(), dims_) << "Can only sort up to dims";
for (size_t d = 0; d < order.size(); ++d) {
CHECK_GE(order[d], 0);
CHECK_LT(order[d], dims);
}
}
inline bool operator()(const int64 i, const int64 j) const {
for (int di = 0; di < dims_; ++di) {
const int64 d = order_[di];
if (ix_(i, d) < ix_(j, d)) return true;
if (ix_(i, d) > ix_(j, d)) return false;
}
return false;
}
const TTypes<int64>::Matrix ix_;
const VarDimArray order_;
const int dims_;
};
} // namespace sparse
} // namespace tensorflow
#endif // TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_
|