aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cwise_op_select.cc
diff options
context:
space:
mode:
authorGravatar Cao Zongyan <zongyan.cao@alibaba-inc.com>2018-07-26 11:27:40 +0800
committerGravatar Cao Zongyan <zongyan.cao@alibaba-inc.com>2018-07-26 12:20:47 +0800
commit23f826271a5956982df17980bca3ac7513ec4ee4 (patch)
tree8aabd4443c311164af9431aae66a29c70a4ce2d5 /tensorflow/core/kernels/cwise_op_select.cc
parent15b155e929f2eb3e30c1194fa9afc1ea40e330a4 (diff)
A faster BatchSelectFunctor for tf.where on CPU.
Op 'tf.where(c, t, e)' supports that 't' and 'e' are N-D tensors while 'c' is a 1D tensor, which would call BatchSelectFunctor to get the result. But its basic implementation broadcasts 'c' to the same dimension with 't' and 'e', which would get bad efficiency on CPU for large tensors. Here a loop-based implementation would be adopted to make this operation faster on CPU.
Diffstat (limited to 'tensorflow/core/kernels/cwise_op_select.cc')
-rw-r--r--tensorflow/core/kernels/cwise_op_select.cc42
1 files changed, 41 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc
index e259daaba4..0d6d83fc3a 100644
--- a/tensorflow/core/kernels/cwise_op_select.cc
+++ b/tensorflow/core/kernels/cwise_op_select.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/cwise_ops_common.h"
+#include "tensorflow/core/platform/prefetch.h"
namespace tensorflow {
@@ -254,9 +255,48 @@ struct BatchSelectFunctorBase {
}
};
+// A fast implementation on CPU, using loop to get rid of broadcasting.
template <typename T>
-struct BatchSelectFunctor<CPUDevice, T> : BatchSelectFunctorBase<CPUDevice, T> {
+struct BatchSelectFunctor<CPUDevice, T> {
+ void operator()(const CPUDevice& d,
+ typename TTypes<T>::Matrix output_flat_outer_dims,
+ TTypes<bool>::ConstVec cond_vec,
+ typename TTypes<T>::ConstMatrix then_flat_outer_dims,
+ typename TTypes<T>::ConstMatrix else_flat_outer_dims) {
+ const size_t batch = cond_vec.size();
+ const size_t batch_size = then_flat_outer_dims.size() / batch;
+ T* output = output_flat_outer_dims.data();
+ const bool* c = cond_vec.data();
+ const T* t = then_flat_outer_dims.data();
+ const T* e = else_flat_outer_dims.data();
+
+ auto work = [batch_size, output, c, t, e](int64 start, int64 end) {
+ for (size_t i = start; i < end; ++i) {
+ size_t offset = i * batch_size;
+ port::prefetch<port::PREFETCH_HINT_NTA>(
+ reinterpret_cast<const void*>(&t[offset + batch_size]));
+ port::prefetch<port::PREFETCH_HINT_NTA>(
+ reinterpret_cast<const void*>(&e[offset + batch_size]));
+ port::prefetch<port::PREFETCH_HINT_NTA>(
+ reinterpret_cast<const void*>(&c[i + 1]));
+ if (c[i]) {
+ for (size_t j = 0; j < batch_size; ++j) {
+ output[offset + j] = t[offset + j];
+ }
+ } else {
+ for (size_t j = 0; j < batch_size; ++j) {
+ output[offset + j] = e[offset + j];
+ }
+ }
+ }
+ };
+ auto cost = Eigen::TensorOpCost(sizeof(T) * batch_size * 2, // ld bytes
+ sizeof(T) * batch_size, // st bytes
+ batch_size); // compute cycles
+ d.parallelFor(batch, cost, work);
+ }
};
+
#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct BatchSelectFunctor<SYCLDevice, T>