From e073de96dc71a2f720eb80bf11023972e9c10bca Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 30 Nov 2016 21:36:52 -0800 Subject: Moved the MemCopyFunctor back to TensorSyclDevice since it's the only caller and it makes TensorFlow compile again --- .../Eigen/CXX11/src/Tensor/TensorDeviceSycl.h | 27 ++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h index 3d53b40ec..1fd00d4f6 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h @@ -170,6 +170,29 @@ struct SyclDevice { // some runtime conditions that can be applied here EIGEN_STRONG_INLINE bool isDeviceSuitable() const { return true; } + template class MemCopyFunctor { + public: + typedef cl::sycl::accessor read_accessor; + typedef cl::sycl::accessor write_accessor; + + MemCopyFunctor(read_accessor src_acc, write_accessor dst_acc, size_t rng, size_t i, size_t offset): m_src_acc(src_acc), m_dst_acc(dst_acc), m_rng(rng), m_i(i), m_offset(offset) {} + + void operator()(cl::sycl::nd_item<1> itemID) { + auto src_ptr = ConvertToActualTypeSycl(T, m_src_acc); + auto dst_ptr = ConvertToActualTypeSycl(T, m_dst_acc); + auto globalid = itemID.get_global_linear_id(); + if (globalid < m_rng) { + dst_ptr[globalid + m_i] = src_ptr[globalid + m_offset]; + } + } + + private: + read_accessor m_src_acc; + write_accessor m_dst_acc; + size_t m_rng; + size_t m_i; + size_t m_offset; + }; /// the memcpy function template EIGEN_STRONG_INLINE void memcpy(void *dst, const T *src, size_t n) const { @@ -184,7 +207,7 @@ struct SyclDevice { sycl_queue().submit([&](cl::sycl::handler &cgh) { auto src_acc =it1->second.template get_access(cgh); auto dst_acc =it2->second.template get_access(cgh); - cgh.parallel_for(cl::sycl::nd_range<1>(cl::sycl::range<1>(GRange), cl::sycl::range<1>(tileSize)), TensorSycl::internal::MemCopyFunctor(src_acc, dst_acc, rng, 0, offset)); + cgh.parallel_for(cl::sycl::nd_range<1>(cl::sycl::range<1>(GRange), cl::sycl::range<1>(tileSize)), MemCopyFunctor(src_acc, dst_acc, rng, 0, offset)); }); synchronize(); } @@ -215,7 +238,7 @@ struct SyclDevice { sycl_queue().submit([&](cl::sycl::handler &cgh) { auto src_acc= it->second.template get_access(cgh); auto dst_acc =dest_buf.template get_access(cgh); - cgh.parallel_for( cl::sycl::nd_range<1>(cl::sycl::range<1>(GRange), cl::sycl::range<1>(tileSize)), TensorSycl::internal::MemCopyFunctor(src_acc, dst_acc, rng, 0, offset)); + cgh.parallel_for( cl::sycl::nd_range<1>(cl::sycl::range<1>(GRange), cl::sycl::range<1>(tileSize)), MemCopyFunctor(src_acc, dst_acc, rng, 0, offset)); }); synchronize(); } -- cgit v1.2.3