diff options
Diffstat (limited to 'tensorflow/core/kernels/cwise_op_select.cc')
-rw-r--r-- | tensorflow/core/kernels/cwise_op_select.cc | 52 |
1 files changed, 51 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc
index e259daaba4..98df0844ea 100644
--- a/tensorflow/core/kernels/cwise_op_select.cc
+++ b/tensorflow/core/kernels/cwise_op_select.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/kernels/bounds_check.h"
 #include "tensorflow/core/kernels/cwise_ops_common.h"
+#include "tensorflow/core/platform/prefetch.h"
 
 namespace tensorflow {
 
@@ -254,9 +255,58 @@ struct BatchSelectFunctorBase {
   }
 };
 
+// CPU specialization of the batch select. For every index i of cond_vec it
+// copies elements [i * batch_size, (i + 1) * batch_size) of the flat `then`
+// buffer when cond_vec[i] is true, and of the flat `else` buffer otherwise.
+// An explicit loop over rows is used to get rid of broadcasting.
 template <typename T>
-struct BatchSelectFunctor<CPUDevice, T> : BatchSelectFunctorBase<CPUDevice, T> {
+struct BatchSelectFunctor<CPUDevice, T> {
+  void operator()(const CPUDevice& d,
+                  typename TTypes<T>::Matrix output_flat_outer_dims,
+                  TTypes<bool>::ConstVec cond_vec,
+                  typename TTypes<T>::ConstMatrix then_flat_outer_dims,
+                  typename TTypes<T>::ConstMatrix else_flat_outer_dims) {
+    const size_t batch = cond_vec.size();
+    // An empty condition vector means there is nothing to select, and the
+    // division below would otherwise divide by zero.
+    if (batch == 0) return;
+    // Number of elements governed by a single cond_vec entry.
+    const size_t batch_size = then_flat_outer_dims.size() / batch;
+    T* output = output_flat_outer_dims.data();
+    const bool* c = cond_vec.data();
+    const T* t = then_flat_outer_dims.data();
+    const T* e = else_flat_outer_dims.data();
+
+    // Handles rows [start, end); handed to d.parallelFor below.
+    auto work = [batch_size, output, c, t, e](int64 start, int64 end) {
+      // Convert the bounds once so the loop never compares signed/unsigned.
+      const size_t first = static_cast<size_t>(start);
+      const size_t last = static_cast<size_t>(end);
+      for (size_t i = first; i < last; ++i) {
+        const size_t offset = i * batch_size;
+        // Prefetch the start of the next row of each input and the next
+        // condition; the addresses are only hints and are never read here.
+        port::prefetch<port::PREFETCH_HINT_NTA>(
+            reinterpret_cast<const void*>(&t[offset + batch_size]));
+        port::prefetch<port::PREFETCH_HINT_NTA>(
+            reinterpret_cast<const void*>(&e[offset + batch_size]));
+        port::prefetch<port::PREFETCH_HINT_NTA>(
+            reinterpret_cast<const void*>(&c[i + 1]));
+        // Choose the source row once instead of duplicating the copy loop.
+        const T* src = (c[i] ? t : e) + offset;
+        T* dst = output + offset;
+        for (size_t j = 0; j < batch_size; ++j) {
+          dst[j] = src[j];
+        }
+      }
+    };
+    auto cost = Eigen::TensorOpCost(sizeof(T) * batch_size * 2,  // ld bytes
+                                    sizeof(T) * batch_size,      // st bytes
+                                    batch_size);  // compute cycles
+    d.parallelFor(batch, cost, work);
+  }
 };
+
 #ifdef TENSORFLOW_USE_SYCL
 template <typename T>
 struct BatchSelectFunctor<SYCLDevice, T>