aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/concat_lib_cpu.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/concat_lib_cpu.h')
-rw-r--r--tensorflow/core/kernels/concat_lib_cpu.h35
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/concat_lib_cpu.h b/tensorflow/core/kernels/concat_lib_cpu.h
index 9d37cafb4e..6a933efde4 100644
--- a/tensorflow/core/kernels/concat_lib_cpu.h
+++ b/tensorflow/core/kernels/concat_lib_cpu.h
@@ -126,4 +126,39 @@ void ConcatCPUImpl(
cost_per_unit, work);
}
+#ifdef TENSORFLOW_USE_SYCL
+// Concatenates `inputs` along dimension 1 into `output`, issuing one
+// d.memcpy per (row, input) pair. `cost_per_unit` and `copier` are unused.
+template <typename T, typename ElementCopier>
+void ConcatSYCLImpl(
+    const Eigen::SyclDevice& d,
+    const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
+        inputs,
+    int64 cost_per_unit, ElementCopier copier,
+    typename TTypes<T, 2>::Matrix* output) {
+  const size_t num_inputs = inputs.size();
+
+  // Per-input chunk length (dimension 1) and a read cursor into each input.
+  std::vector<ptrdiff_t> sizes;
+  std::vector<const T*> inp;
+  sizes.reserve(num_inputs);
+  inp.reserve(num_inputs);
+  for (const auto& input : inputs) {
+    sizes.push_back(input->dimension(1));
+    inp.push_back(&(*input)(0, 0));
+  }
+
+  // Write cursor: advances by each chunk, so rows are laid out back to back.
+  T* out = &(*output)(0, 0);
+  const int64 dim0 = output->dimension(0);
+  for (int64 i = 0; i < dim0; ++i) {
+    for (size_t j = 0; j < num_inputs; ++j) {
+      const ptrdiff_t size = sizes[j];
+      d.memcpy(out, inp[j], size * sizeof(T));
+      out += size;
+      inp[j] += size;
+    }
+  }
+}
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow