diff options
Diffstat (limited to 'tensorflow/core/kernels/concat_lib_cpu.cc')
-rw-r--r-- | tensorflow/core/kernels/concat_lib_cpu.cc | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/concat_lib_cpu.cc b/tensorflow/core/kernels/concat_lib_cpu.cc index f83aed6aef..f89948350c 100644 --- a/tensorflow/core/kernels/concat_lib_cpu.cc +++ b/tensorflow/core/kernels/concat_lib_cpu.cc @@ -74,4 +74,23 @@ REGISTER(qint16) REGISTER(qint32) REGISTER(bfloat16) +#ifdef TENSORFLOW_USE_SYCL +template <typename T> +void ConcatSYCL(const Eigen::SyclDevice& d, + const std::vector< + std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& inputs, + typename TTypes<T, 2>::Matrix* output) { + ConcatSYCLImpl<T>(d, inputs, sizeof(T) /* cost_per_unit */, MemCpyCopier<T>(), + output); +} +#define REGISTER_SYCL(T) \ + template void ConcatSYCL<T>( \ + const Eigen::SyclDevice&, \ + const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&, \ + typename TTypes<T, 2>::Matrix* output); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_SYCL) + +#undef REGISTER_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow |