aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/aggregate_ops_cpu.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/aggregate_ops_cpu.h')
-rw-r--r--tensorflow/core/kernels/aggregate_ops_cpu.h113
1 files changed, 113 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/aggregate_ops_cpu.h b/tensorflow/core/kernels/aggregate_ops_cpu.h
index ba5ebb7f0f..dfa3fe585e 100644
--- a/tensorflow/core/kernels/aggregate_ops_cpu.h
+++ b/tensorflow/core/kernels/aggregate_ops_cpu.h
@@ -23,6 +23,10 @@ limitations under the License.
typedef Eigen::ThreadPoolDevice CPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
+
namespace tensorflow {
// Partial specializations for a CPUDevice, that uses the Eigen implementation
@@ -133,6 +137,115 @@ struct Add9Functor<CPUDevice, T> {
}
};
+#ifdef TENSORFLOW_USE_SYCL
+// Partial specializations for a SYCLDevice, that uses the Eigen implementation
+// from AddNEigenImpl.
+template <typename T>
+struct Add2Functor<SYCLDevice, T> {
+ void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2) {
+ Add2EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2);
+ }
+};
+template <typename T>
+struct Add3Functor<SYCLDevice, T> {
+ void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3) {
+ Add3EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3);
+ }
+};
+template <typename T>
+struct Add4Functor<SYCLDevice, T> {
+ void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4) {
+ Add4EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4);
+ }
+};
+template <typename T>
+struct Add5Functor<SYCLDevice, T> {
+ void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5) {
+ Add5EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5);
+ }
+};
+template <typename T>
+struct Add6Functor<SYCLDevice, T> {
+ void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5,
+ typename TTypes<T>::ConstFlat in6) {
+ Add6EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6);
+ }
+};
+template <typename T>
+struct Add7Functor<SYCLDevice, T> {
+ void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5,
+ typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7) {
+ Add7EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
+ in7);
+ }
+};
+
+template <typename T>
+struct Add8Functor<SYCLDevice, T> {
+ void operator()(
+ const SYCLDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
+ Add8EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
+ in7, in8);
+ }
+};
+
+template <typename T>
+struct Add8pFunctor<SYCLDevice, T> {
+ void operator()(
+ const SYCLDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
+ Add8pEigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
+ in7, in8);
+ }
+};
+
+template <typename T>
+struct Add9Functor<SYCLDevice, T> {
+ void operator()(
+ const SYCLDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8,
+ typename TTypes<T>::ConstFlat in9) {
+ Add9EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
+ in7, in8, in9);
+ }
+};
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace functor
} // namespace tensorflow