diff options
Diffstat (limited to 'tensorflow/core/kernels/aggregate_ops_cpu.h')
-rw-r--r-- | tensorflow/core/kernels/aggregate_ops_cpu.h | 113 |
1 files changed, 113 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/aggregate_ops_cpu.h b/tensorflow/core/kernels/aggregate_ops_cpu.h index ba5ebb7f0f..dfa3fe585e 100644 --- a/tensorflow/core/kernels/aggregate_ops_cpu.h +++ b/tensorflow/core/kernels/aggregate_ops_cpu.h @@ -23,6 +23,10 @@ limitations under the License. typedef Eigen::ThreadPoolDevice CPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL + namespace tensorflow { // Partial specializations for a CPUDevice, that uses the Eigen implementation @@ -133,6 +137,115 @@ struct Add9Functor<CPUDevice, T> { } }; +#ifdef TENSORFLOW_USE_SYCL +// Partial specializations for a SYCLDevice, that uses the Eigen implementation +// from AddNEigenImpl. +template <typename T> +struct Add2Functor<SYCLDevice, T> { + void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, + typename TTypes<T>::ConstFlat in1, + typename TTypes<T>::ConstFlat in2) { + Add2EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2); + } +}; +template <typename T> +struct Add3Functor<SYCLDevice, T> { + void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, + typename TTypes<T>::ConstFlat in1, + typename TTypes<T>::ConstFlat in2, + typename TTypes<T>::ConstFlat in3) { + Add3EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3); + } +}; +template <typename T> +struct Add4Functor<SYCLDevice, T> { + void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, + typename TTypes<T>::ConstFlat in1, + typename TTypes<T>::ConstFlat in2, + typename TTypes<T>::ConstFlat in3, + typename TTypes<T>::ConstFlat in4) { + Add4EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4); + } +}; +template <typename T> +struct Add5Functor<SYCLDevice, T> { + void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, + typename TTypes<T>::ConstFlat in1, + typename TTypes<T>::ConstFlat in2, + typename TTypes<T>::ConstFlat in3, + typename TTypes<T>::ConstFlat in4, + typename TTypes<T>::ConstFlat in5) { + Add5EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5); + } +}; +template <typename T> +struct Add6Functor<SYCLDevice, T> { + void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, + typename TTypes<T>::ConstFlat in1, + typename TTypes<T>::ConstFlat in2, + typename TTypes<T>::ConstFlat in3, + typename TTypes<T>::ConstFlat in4, + typename TTypes<T>::ConstFlat in5, + typename TTypes<T>::ConstFlat in6) { + Add6EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6); + } +}; +template <typename T> +struct Add7Functor<SYCLDevice, T> { + void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, + typename TTypes<T>::ConstFlat in1, + typename TTypes<T>::ConstFlat in2, + typename TTypes<T>::ConstFlat in3, + typename TTypes<T>::ConstFlat in4, + typename TTypes<T>::ConstFlat in5, + typename TTypes<T>::ConstFlat in6, + typename TTypes<T>::ConstFlat in7) { + Add7EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, + in7); + } +}; + +template <typename T> +struct Add8Functor<SYCLDevice, T> { + void operator()( + const SYCLDevice& d, typename TTypes<T>::Flat out, + typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, + typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, + typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, + typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) { + Add8EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, + in7, in8); + } +}; + +template <typename T> +struct Add8pFunctor<SYCLDevice, T> { + void operator()( + const SYCLDevice& d, typename TTypes<T>::Flat out, + typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, + typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, + typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, + typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) { + Add8pEigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, + in7, in8); + } +}; + +template <typename T> +struct Add9Functor<SYCLDevice, T> { + void operator()( + const SYCLDevice& d, typename TTypes<T>::Flat out, + typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, + typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, + typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, + typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8, + typename TTypes<T>::ConstFlat in9) { + Add9EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, + in7, in8, in9); + } +}; +#endif // TENSORFLOW_USE_SYCL + } // namespace functor } // namespace tensorflow |