Diffstat (limited to 'tensorflow/core/kernels/cwise_op_select.cc')
-rw-r--r--  tensorflow/core/kernels/cwise_op_select.cc | 42 +-
1 file changed, 41 insertions(+), 1 deletion(-)
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc
index e259daaba4..98df0844ea 100644
--- a/tensorflow/core/kernels/cwise_op_select.cc
+++ b/tensorflow/core/kernels/cwise_op_select.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/cwise_ops_common.h"
+#include "tensorflow/core/platform/prefetch.h"
namespace tensorflow {
@@ -254,9 +255,48 @@ struct BatchSelectFunctorBase {
}
};
+// A fast implementation on CPU, using a loop to get rid of broadcasting.
template <typename T>
-struct BatchSelectFunctor<CPUDevice, T> : BatchSelectFunctorBase<CPUDevice, T> {
+struct BatchSelectFunctor<CPUDevice, T> {
+ void operator()(const CPUDevice& d,
+ typename TTypes<T>::Matrix output_flat_outer_dims,
+ TTypes<bool>::ConstVec cond_vec,
+ typename TTypes<T>::ConstMatrix then_flat_outer_dims,
+ typename TTypes<T>::ConstMatrix else_flat_outer_dims) {
+ const size_t batch = cond_vec.size();
+ const size_t batch_size = then_flat_outer_dims.size() / batch;
+ T* output = output_flat_outer_dims.data();
+ const bool* c = cond_vec.data();
+ const T* t = then_flat_outer_dims.data();
+ const T* e = else_flat_outer_dims.data();
+
+ auto work = [batch_size, output, c, t, e](int64 start, int64 end) {
+ for (size_t i = start; i < end; ++i) {
+ size_t offset = i * batch_size;
+ port::prefetch<port::PREFETCH_HINT_NTA>(
+ reinterpret_cast<const void*>(&t[offset + batch_size]));
+ port::prefetch<port::PREFETCH_HINT_NTA>(
+ reinterpret_cast<const void*>(&e[offset + batch_size]));
+ port::prefetch<port::PREFETCH_HINT_NTA>(
+ reinterpret_cast<const void*>(&c[i + 1]));
+ if (c[i]) {
+ for (size_t j = 0; j < batch_size; ++j) {
+ output[offset + j] = t[offset + j];
+ }
+ } else {
+ for (size_t j = 0; j < batch_size; ++j) {
+ output[offset + j] = e[offset + j];
+ }
+ }
+ }
+ };
+ auto cost = Eigen::TensorOpCost(sizeof(T) * batch_size * 2, // ld bytes
+ sizeof(T) * batch_size, // st bytes
+ batch_size); // compute cycles
+ d.parallelFor(batch, cost, work);
+ }
};
+
#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct BatchSelectFunctor<SYCLDevice, T>
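
Note (not part of the commit): below is a minimal standalone sketch of the pattern the new CPU specialization uses, a row-wise copy from the then/else buffer driven by Eigen's ThreadPoolDevice::parallelFor with a TensorOpCost estimate. It assumes an Eigen build with EIGEN_USE_THREADS; the RowwiseSelect function, thread count, and buffer shapes are illustrative and are not TensorFlow APIs, and the kernel's non-temporal prefetching (port::prefetch<port::PREFETCH_HINT_NTA>) is omitted for brevity.

// Standalone illustration, not the TensorFlow kernel itself.
#define EIGEN_USE_THREADS
#include <cstdint>
#include <vector>
#include <unsupported/Eigen/CXX11/ThreadPool>
#include <unsupported/Eigen/CXX11/Tensor>

// Copy row i of `t` or `e` into `out`, depending on cond[i].
// `batch` rows of `batch_size` elements each; rows are processed in
// parallel shards chosen by Eigen from the supplied cost estimate.
void RowwiseSelect(const Eigen::ThreadPoolDevice& d, const bool* cond,
                   const float* t, const float* e, float* out,
                   int64_t batch, int64_t batch_size) {
  auto work = [=](Eigen::Index start, Eigen::Index end) {
    for (Eigen::Index i = start; i < end; ++i) {
      const int64_t offset = i * batch_size;
      const float* src = cond[i] ? t : e;
      for (int64_t j = 0; j < batch_size; ++j) {
        out[offset + j] = src[offset + j];
      }
    }
  };
  // Same shape of cost model as in the kernel: bytes loaded, bytes stored,
  // and compute cycles per unit of work (here, one row).
  Eigen::TensorOpCost cost(sizeof(float) * batch_size * 2,
                           sizeof(float) * batch_size, batch_size);
  d.parallelFor(batch, cost, work);
}

int main() {
  Eigen::ThreadPool pool(4);                    // 4 worker threads (arbitrary)
  Eigen::ThreadPoolDevice device(&pool, 4);

  const int64_t batch = 3, batch_size = 4;
  std::vector<float> t(batch * batch_size, 1.0f);
  std::vector<float> e(batch * batch_size, 2.0f);
  std::vector<float> out(batch * batch_size);
  const bool cond[] = {true, false, true};

  RowwiseSelect(device, cond, t.data(), e.data(), out.data(), batch,
                batch_size);
  // out now holds rows {1,1,1,1}, {2,2,2,2}, {1,1,1,1}.
  return 0;
}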