aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/concat_lib_cpu.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/concat_lib_cpu.cc')
-rw-r--r--tensorflow/core/kernels/concat_lib_cpu.cc19
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/concat_lib_cpu.cc b/tensorflow/core/kernels/concat_lib_cpu.cc
index f83aed6aef..f89948350c 100644
--- a/tensorflow/core/kernels/concat_lib_cpu.cc
+++ b/tensorflow/core/kernels/concat_lib_cpu.cc
@@ -74,4 +74,23 @@ REGISTER(qint16)
REGISTER(qint32)
REGISTER(bfloat16)
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+void ConcatSYCL(const Eigen::SyclDevice& d,
+ const std::vector<
+ std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& inputs,
+ typename TTypes<T, 2>::Matrix* output) {
+ ConcatSYCLImpl<T>(d, inputs, sizeof(T) /* cost_per_unit */, MemCpyCopier<T>(),
+ output);
+}
+#define REGISTER_SYCL(T) \
+ template void ConcatSYCL<T>( \
+ const Eigen::SyclDevice&, \
+ const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&, \
+ typename TTypes<T, 2>::Matrix* output);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_SYCL)
+
+#undef REGISTER_SYCL
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow