diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2018-07-17 14:09:37 -0700 |
---|---|---|
committer | Eugene Zhulenev <ezhulenev@google.com> | 2018-07-17 14:09:37 -0700 |
commit | c95aacab90e9d8bb9f9e082395b3b843a530fa41 (patch) | |
tree | 1bf812626899fb13c0e1cfc350ecbac95bd135c2 /unsupported | |
parent | 038b55464b1d43612b88789f26006163ca638928 (diff) |
Fix TensorContractionOp evaluators for GPU and SYCL
Diffstat (limited to 'unsupported')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h | 10 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h | 13 |
2 files changed, 13 insertions, 10 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h index b9956cd43..6d3aa24c8 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h @@ -505,9 +505,9 @@ template<typename Scalar, typename Index, typename LhsMapper, __global__ void #if defined(EIGEN_HIPCC) __launch_bounds__(512, 1) -#else +#else __launch_bounds__(512) -#endif +#endif EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs, const OutputMapper output, const Index m_size, const Index n_size, const Index k_size) { @@ -698,7 +698,7 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh #undef prefetch_lhs #undef add_vals - + Index horiz_base = threadIdx.y*4+base_n; if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) { for (int i = 0; i < 4; i++) { @@ -1137,7 +1137,7 @@ template<typename Index, typename LhsMapper, __global__ void #if defined(EIGEN_HIPCC) __launch_bounds__(256, 1) -#else +#else __launch_bounds__(256) #endif EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs, @@ -1184,7 +1184,7 @@ template<typename Index, typename LhsMapper, __global__ void #if defined(EIGEN_HIPCC) __launch_bounds__(256, 1) -#else +#else __launch_bounds__(256) #endif EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs, diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h index e6840bc87..35f931c53 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h @@ -23,15 +23,18 @@ namespace Eigen { template <typename Index, typename LhsScalar, typename RhsScalar,bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered> struct LaunchSyclKernels; -template<typename Indices, typename LeftArgType, typename RightArgType> -struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, const Eigen::SyclDevice> : - public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, const Eigen::SyclDevice> > { +template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType> +struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, const Eigen::SyclDevice> : + public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, const Eigen::SyclDevice> > { + + static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value, + "SYCL tensor contraction does not support output kernels."); typedef const Eigen::SyclDevice Device; - typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self; + typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self; typedef TensorContractionEvaluatorBase<Self> Base; - typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType; + typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType; typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar; typedef typename XprType::Index Index; typedef typename XprType::CoeffReturnType CoeffReturnType; |