aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/where_op.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/where_op.h')
-rw-r--r--tensorflow/core/kernels/where_op.h65
1 files changed, 65 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/where_op.h b/tensorflow/core/kernels/where_op.h
new file mode 100644
index 0000000000..c7b835d02f
--- /dev/null
+++ b/tensorflow/core/kernels/where_op.h
@@ -0,0 +1,65 @@
+#ifndef TENSORFLOW_KERNELS_WHERE_OP_H_
+#define TENSORFLOW_KERNELS_WHERE_OP_H_
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+namespace functor {
+
+template <typename Device>
+struct NumTrue {
+ EIGEN_ALWAYS_INLINE static void Compute(
+ const Device& d, typename TTypes<bool>::ConstFlat input,
+ TTypes<int64>::Scalar num_true) {
+ num_true.device(d) = input.template cast<int64>().sum();
+ }
+};
+
+template <typename Device, int NDIM>
+struct Where {
+ EIGEN_ALWAYS_INLINE static void Compute(
+ const Device& d, typename TTypes<bool, NDIM>::ConstTensor input,
+ typename TTypes<int64>::Matrix output) {
+ Eigen::DenseIndex true_n = 0;
+ Eigen::DSizes<Eigen::DenseIndex, NDIM> dims = input.dimensions();
+ Eigen::DSizes<Eigen::DenseIndex, NDIM> strides;
+
+ // Calculate strides for RowMajor order.
+ EIGEN_STATIC_ASSERT((static_cast<int>(decltype(input)::Layout) ==
+ static_cast<int>(Eigen::RowMajor)),
+ INTERNAL_ERROR_INPUT_SHOULD_BE_ROWMAJOR);
+
+ strides[NDIM - 1] = 1;
+ for (int i = NDIM - 2; i >= 0; --i) {
+ strides[i] = strides[i + 1] * dims[i + 1];
+ }
+
+ // Note, no bounds checking is done on true_n. It is assumed that
+ // the output was correctly sized via output of NumTrue::Compute.
+ for (Eigen::DenseIndex n = 0; n < input.size(); ++n) {
+ if (input.data()[n]) {
+ WriteIndexRowMajor(output, strides, true_n, n);
+ ++true_n;
+ }
+ }
+ }
+
+ EIGEN_ALWAYS_INLINE static void WriteIndexRowMajor(
+ typename TTypes<int64>::Matrix output,
+ const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides,
+ Eigen::DenseIndex true_n, Eigen::DenseIndex index) {
+ for (int i = 0; i < NDIM; ++i) {
+ output(true_n, i) = index / strides[i];
+ index %= strides[i];
+ }
+ }
+};
+
+} // namespace functor
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_WHERE_OP_H_