diff options
Diffstat (limited to 'tensorflow/core/kernels/cwise_op_select.cc')
-rw-r--r-- | tensorflow/core/kernels/cwise_op_select.cc | 59 |
1 files changed, 51 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc
index add26b5ac8..709628da13 100644
--- a/tensorflow/core/kernels/cwise_op_select.cc
+++ b/tensorflow/core/kernels/cwise_op_select.cc
@@ -28,6 +28,10 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
+
 template <typename Device, typename T>
 class SelectOp : public OpKernel {
  public:
@@ -169,12 +173,24 @@ REGISTER_SELECT_GPU(complex128);
 
 #endif // GOOGLE_CUDA
 
+#ifdef TENSORFLOW_USE_SYCL
+// Registration of the SYCL implementations.
+#define REGISTER_SELECT_SYCL(type) \
+  REGISTER_KERNEL_BUILDER( \
+      Name("Select").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+      SelectOp<SYCLDevice, type>);
+
+REGISTER_SELECT_SYCL(float);
+REGISTER_SELECT_SYCL(int32);
+#undef REGISTER_SELECT_SYCL
+#endif // TENSORFLOW_USE_SYCL
+
 namespace functor {
 
 // CPU Specializations of Select functors.
-template <typename T>
-struct SelectFunctor<CPUDevice, T> {
-  void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
+template <typename Device, typename T>
+struct SelectFunctorBase {
+  void operator()(const Device& d, typename TTypes<T>::Flat out,
                   typename TTypes<bool>::ConstFlat cond_flat,
                   typename TTypes<T>::ConstFlat then_flat,
                   typename TTypes<T>::ConstFlat else_flat) {
@@ -182,10 +198,18 @@ struct SelectFunctor<CPUDevice, T> {
   }
 };
 
-// CPU Specializations of Select functors with scalar
 template <typename T>
-struct SelectScalarFunctor<CPUDevice, T> {
-  void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
+struct SelectFunctor<CPUDevice, T>
+    : SelectFunctorBase<CPUDevice, T> {};
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+struct SelectFunctor<SYCLDevice, T>
+    : SelectFunctorBase<SYCLDevice, T> {};
+#endif // TENSORFLOW_USE_SYCL
+
+template <typename Device, typename T>
+struct SelectScalarFunctorBase {
+  void operator()(const Device& d, typename TTypes<T>::Flat out,
                   TTypes<bool>::ConstScalar cond,
                   typename TTypes<T>::ConstFlat then_flat,
                   typename TTypes<T>::ConstFlat else_flat) {
@@ -193,9 +217,19 @@ struct SelectScalarFunctor<CPUDevice, T> {
   }
 };
 
+// CPU Specializations of Select functors with scalar
 template <typename T>
-struct BatchSelectFunctor<CPUDevice, T> {
-  void operator()(const CPUDevice& d,
+struct SelectScalarFunctor<CPUDevice, T>
+    : SelectScalarFunctorBase<CPUDevice, T> {};
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+struct SelectScalarFunctor<SYCLDevice, T>
+    : SelectScalarFunctorBase<SYCLDevice, T> {};
+#endif // TENSORFLOW_USE_SYCL
+
+template <typename Device, typename T>
+struct BatchSelectFunctorBase {
+  void operator()(const Device& d,
                   typename TTypes<T>::Matrix output_flat_outer_dims,
                   TTypes<bool>::ConstVec cond_vec,
                   typename TTypes<T>::ConstMatrix then_flat_outer_dims,
@@ -220,6 +254,15 @@ struct BatchSelectFunctor<CPUDevice, T> {
   }
 };
 
+template <typename T>
+struct BatchSelectFunctor<CPUDevice, T>
+    : BatchSelectFunctorBase<CPUDevice, T> {};
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+struct BatchSelectFunctor<SYCLDevice, T>
+    : BatchSelectFunctorBase<SYCLDevice, T> {};
+#endif // TENSORFLOW_USE_SYCL
+
 } // namespace functor
 
 } // namespace tensorflow