diff options
Diffstat (limited to 'tensorflow/core/kernels/cwise_op_select.cc')
-rw-r--r-- | tensorflow/core/kernels/cwise_op_select.cc | 56 |
1 files changed, 41 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc index 98df0844ea..d6988a562c 100644 --- a/tensorflow/core/kernels/cwise_op_select.cc +++ b/tensorflow/core/kernels/cwise_op_select.cc @@ -33,6 +33,11 @@ typedef Eigen::GpuDevice GPUDevice; typedef Eigen::SyclDevice SYCLDevice; #endif // TENSORFLOW_USE_SYCL +namespace functor { +template <typename Device, typename T> +struct SelectScalarHandler; +} // namespace functor + template <typename Device, typename T> class SelectOp : public OpKernel { public: @@ -131,16 +136,8 @@ class SelectOp : public OpKernel { then->shape().DebugString(), " vs. ", else_->shape().DebugString())); - Tensor* output = nullptr; - OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output( - {"t", "e"}, "output", then->shape(), &output)); - - if (output->NumElements() > 0) { - functor::SelectScalarFunctor<Device, T> func; - TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>(); - func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar, - then->flat<T>(), else_->flat<T>()); - } + functor::SelectScalarHandler<Device, T> handler; + handler(ctx, cond, then, else_); } private: @@ -209,6 +206,40 @@ struct SelectFunctor<SYCLDevice, T> : SelectFunctorBase<SYCLDevice, T> {}; #endif // TENSORFLOW_USE_SYCL template <typename Device, typename T> +struct SelectScalarHandler { + void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then, + const Tensor* else_) { + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output( + {"t", "e"}, "output", then->shape(), &output)); + + if (output->NumElements() > 0) { + functor::SelectScalarFunctor<Device, T> func; + TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>(); + func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar, + then->flat<T>(), else_->flat<T>()); + } + } +}; + +// Specialization for CPU device. Forward input to output depending on the `cond` +// value. 
+// TODO(sjhwang): Consider specializing for GPUDevice as well by using +// GPUDevice::memcpyDeviceToHost() to fetch bool value. +template <typename T> +struct SelectScalarHandler<CPUDevice, T> { + void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then, + const Tensor* else_) { + if (cond->scalar<bool>()()) { + OP_REQUIRES_OK(ctx, ctx->set_output("output", *then)); + } else { + OP_REQUIRES_OK(ctx, ctx->set_output("output", *else_)); + } + } +}; + +#ifdef TENSORFLOW_USE_SYCL +template <typename Device, typename T> struct SelectScalarFunctorBase { void operator()(const Device& d, typename TTypes<T>::Flat out, TTypes<bool>::ConstScalar cond, @@ -218,11 +249,6 @@ struct SelectScalarFunctorBase { } }; -// CPU Specializations of Select functors with scalar -template <typename T> -struct SelectScalarFunctor<CPUDevice, T> - : SelectScalarFunctorBase<CPUDevice, T> {}; -#ifdef TENSORFLOW_USE_SYCL template <typename T> struct SelectScalarFunctor<SYCLDevice, T> : SelectScalarFunctorBase<SYCLDevice, T> {}; |