diff options
author | 2018-07-26 11:27:40 +0800 | |
---|---|---|
committer | 2018-07-26 12:20:47 +0800 | |
commit | 23f826271a5956982df17980bca3ac7513ec4ee4 (patch) | |
tree | 8aabd4443c311164af9431aae66a29c70a4ce2d5 /tensorflow/core/kernels/cwise_op_select.cc | |
parent | 15b155e929f2eb3e30c1194fa9afc1ea40e330a4 (diff) |
A faster BatchSelectFunctor for tf.where on CPU.
Op 'tf.where(c, t, e)' supports that 't' and 'e' are N-D tensors
while 'c' is a 1D tensor, which would call BatchSelectFunctor to
get the result. But its basic implementation broadcasts 'c' to the
same dimension with 't' and 'e', which would get bad efficiency on
CPU for large tensors. Here a loop-based implementation would be
adopted to make this operation faster on CPU.
Diffstat (limited to 'tensorflow/core/kernels/cwise_op_select.cc')
-rw-r--r-- | tensorflow/core/kernels/cwise_op_select.cc | 42 |
1 files changed, 41 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc index e259daaba4..0d6d83fc3a 100644 --- a/tensorflow/core/kernels/cwise_op_select.cc +++ b/tensorflow/core/kernels/cwise_op_select.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/cwise_ops_common.h" +#include "tensorflow/core/platform/prefetch.h" namespace tensorflow { @@ -254,9 +255,48 @@ struct BatchSelectFunctorBase { } }; +// A fast implementation on CPU, using loop to get rid of broadcasting. template <typename T> -struct BatchSelectFunctor<CPUDevice, T> : BatchSelectFunctorBase<CPUDevice, T> { +struct BatchSelectFunctor<CPUDevice, T> { + void operator()(const CPUDevice& d, + typename TTypes<T>::Matrix output_flat_outer_dims, + TTypes<bool>::ConstVec cond_vec, + typename TTypes<T>::ConstMatrix then_flat_outer_dims, + typename TTypes<T>::ConstMatrix else_flat_outer_dims) { + const size_t batch = cond_vec.size(); + const size_t batch_size = then_flat_outer_dims.size() / batch; + T* output = output_flat_outer_dims.data(); + const bool* c = cond_vec.data(); + const T* t = then_flat_outer_dims.data(); + const T* e = else_flat_outer_dims.data(); + + auto work = [batch_size, output, c, t, e](int64 start, int64 end) { + for (size_t i = start; i < end; ++i) { + size_t offset = i * batch_size; + port::prefetch<port::PREFETCH_HINT_NTA>( + reinterpret_cast<const void*>(&t[offset + batch_size])); + port::prefetch<port::PREFETCH_HINT_NTA>( + reinterpret_cast<const void*>(&e[offset + batch_size])); + port::prefetch<port::PREFETCH_HINT_NTA>( + reinterpret_cast<const void*>(&c[i + 1])); + if (c[i]) { + for (size_t j = 0; j < batch_size; ++j) { + output[offset + j] = t[offset + j]; + } + } else { + for (size_t j = 0; j < batch_size; ++j) { + output[offset + j] = e[offset + j]; + } + } + } + }; + auto cost = Eigen::TensorOpCost(sizeof(T) * batch_size * 2, // ld bytes + sizeof(T) * batch_size, // st bytes + batch_size); // compute cycles + d.parallelFor(batch, cost, work); + } }; + #ifdef TENSORFLOW_USE_SYCL template <typename T> struct BatchSelectFunctor<SYCLDevice, T> |