aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2018-07-17 14:09:37 -0700
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2018-07-17 14:09:37 -0700
commitc95aacab90e9d8bb9f9e082395b3b843a530fa41 (patch)
tree1bf812626899fb13c0e1cfc350ecbac95bd135c2 /unsupported/Eigen/CXX11
parent038b55464b1d43612b88789f26006163ca638928 (diff)
Fix TensorContractionOp evaluators for GPU and SYCL
Diffstat (limited to 'unsupported/Eigen/CXX11')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h10
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h13
2 files changed, 13 insertions, 10 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h
index b9956cd43..6d3aa24c8 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h
@@ -505,9 +505,9 @@ template<typename Scalar, typename Index, typename LhsMapper,
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(512, 1)
-#else
+#else
__launch_bounds__(512)
-#endif
+#endif
EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
const OutputMapper output,
const Index m_size, const Index n_size, const Index k_size) {
@@ -698,7 +698,7 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
#undef prefetch_lhs
#undef add_vals
-
+
Index horiz_base = threadIdx.y*4+base_n;
if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
for (int i = 0; i < 4; i++) {
@@ -1137,7 +1137,7 @@ template<typename Index, typename LhsMapper,
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(256, 1)
-#else
+#else
__launch_bounds__(256)
#endif
EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
@@ -1184,7 +1184,7 @@ template<typename Index, typename LhsMapper,
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(256, 1)
-#else
+#else
__launch_bounds__(256)
#endif
EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs,
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h
index e6840bc87..35f931c53 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h
@@ -23,15 +23,18 @@
namespace Eigen {
template <typename Index, typename LhsScalar, typename RhsScalar,bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered> struct LaunchSyclKernels;
-template<typename Indices, typename LeftArgType, typename RightArgType>
-struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, const Eigen::SyclDevice> :
- public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, const Eigen::SyclDevice> > {
+template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
+struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, const Eigen::SyclDevice> :
+ public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, const Eigen::SyclDevice> > {
+
+ static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value,
+ "SYCL tensor contraction does not support output kernels.");
typedef const Eigen::SyclDevice Device;
- typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
+ typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
typedef TensorContractionEvaluatorBase<Self> Base;
- typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
+ typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType;