diff options
Diffstat (limited to 'tensorflow/core/kernels/concat_lib_cpu.h')
-rw-r--r-- | tensorflow/core/kernels/concat_lib_cpu.h | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/concat_lib_cpu.h b/tensorflow/core/kernels/concat_lib_cpu.h index 9d37cafb4e..6a933efde4 100644 --- a/tensorflow/core/kernels/concat_lib_cpu.h +++ b/tensorflow/core/kernels/concat_lib_cpu.h @@ -126,4 +126,39 @@ void ConcatCPUImpl( cost_per_unit, work); } +#ifdef TENSORFLOW_USE_SYCL +template <typename T, typename ElementCopier> +void ConcatSYCLImpl( + const Eigen::SyclDevice& d, + const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& + inputs, + int64 cost_per_unit, ElementCopier copier, + typename TTypes<T, 2>::Matrix* output) { + size_t num_inputs = inputs.size(); + + std::vector<ptrdiff_t> sizes; + sizes.reserve(num_inputs); + int64 row_size = 0; + for (const auto& input : inputs) { + sizes.push_back(input->dimension(1)); + row_size += sizes.back(); + } + + T* out = &(*output)(0, 0); + std::vector<const T*> inp; + inp.reserve(num_inputs); + for (const auto& input : inputs) { + inp.push_back(&(*input)(0, 0)); + } + const int64 dim0 = output->dimension(0); + for (int64 i = 0; i < dim0; ++i) { + for (int64 j = 0; j < num_inputs; ++j) { + auto size = sizes[j]; + d.memcpy(out, inp[j], size * sizeof(T)); + out += size; + inp[j] += size; + } + } +} +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow |