From f9ad25e4d8453c4265a5fd6d4962a76a386564df Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Tue, 22 Mar 2016 09:30:23 -0700 Subject: Fixed contractions of 16 bit floats --- unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h index a4a06ab5f..dbff660a9 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h @@ -20,7 +20,7 @@ template __device__ EIGEN_STRONG_INLINE void EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, - const OutputMapper output, volatile Scalar* lhs_shmem, volatile Scalar* rhs_shmem, + const OutputMapper output, Scalar* lhs_shmem, Scalar* rhs_shmem, const Index m_size, const Index n_size, const Index k_size) { const Index m_block_idx = blockIdx.x; @@ -319,8 +319,8 @@ EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, Scalar rrow(7); // Now x corresponds to k, y to m, and z to n - const volatile Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y]; - const volatile Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z]; + const Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y]; + const Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z]; #define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))] #define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))] @@ -503,8 +503,8 @@ __launch_bounds__(512) EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs, const OutputMapper output, const Index m_size, const Index n_size, const Index k_size) { - __shared__ volatile Scalar lhs_shmem[72 * 64]; - __shared__ volatile Scalar rhs_shmem[72 * 64]; + __shared__ Scalar lhs_shmem[72 * 64]; + __shared__ Scalar rhs_shmem[72 * 64]; const Index m_block_idx = blockIdx.x; const Index n_block_idx = blockIdx.y; -- cgit v1.2.3