1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
|
#ifndef TENSORFLOW_KERNELS_WHERE_OP_H_
#define TENSORFLOW_KERNELS_WHERE_OP_H_
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
namespace tensorflow {
namespace functor {
template <typename Device>
struct NumTrue {
EIGEN_ALWAYS_INLINE static void Compute(
const Device& d, typename TTypes<bool>::ConstFlat input,
TTypes<int64>::Scalar num_true) {
num_true.device(d) = input.template cast<int64>().sum();
}
};
template <typename Device, int NDIM>
struct Where {
EIGEN_ALWAYS_INLINE static void Compute(
const Device& d, typename TTypes<bool, NDIM>::ConstTensor input,
typename TTypes<int64>::Matrix output) {
Eigen::DenseIndex true_n = 0;
Eigen::DSizes<Eigen::DenseIndex, NDIM> dims = input.dimensions();
Eigen::DSizes<Eigen::DenseIndex, NDIM> strides;
// Calculate strides for RowMajor order.
EIGEN_STATIC_ASSERT((static_cast<int>(decltype(input)::Layout) ==
static_cast<int>(Eigen::RowMajor)),
INTERNAL_ERROR_INPUT_SHOULD_BE_ROWMAJOR);
strides[NDIM - 1] = 1;
for (int i = NDIM - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * dims[i + 1];
}
// Note, no bounds checking is done on true_n. It is assumed that
// the output was correctly sized via output of NumTrue::Compute.
for (Eigen::DenseIndex n = 0; n < input.size(); ++n) {
if (input.data()[n]) {
WriteIndexRowMajor(output, strides, true_n, n);
++true_n;
}
}
}
EIGEN_ALWAYS_INLINE static void WriteIndexRowMajor(
typename TTypes<int64>::Matrix output,
const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides,
Eigen::DenseIndex true_n, Eigen::DenseIndex index) {
for (int i = 0; i < NDIM; ++i) {
output(true_n, i) = index / strides[i];
index %= strides[i];
}
}
};
} // namespace functor
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_WHERE_OP_H_
|