diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h index 2c7ba961c..363df876c 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h @@ -37,6 +37,8 @@ struct traits<TensorConcatenationOp<Axis, LhsXprType, RhsXprType> > static const int NumDimensions = traits<LhsXprType>::NumDimensions; static const int Layout = traits<LhsXprType>::Layout; enum { Flags = 0 }; + typedef typename conditional<::Eigen::internal::Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val, + typename traits<LhsXprType>::PointerType, typename traits<RhsXprType>::PointerType>::type PointerType; }; template<typename Axis, typename LhsXprType, typename RhsXprType> @@ -275,7 +277,7 @@ struct TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgTy TensorOpCost(0, 0, compute_cost); } - EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; } + EIGEN_DEVICE_FUNC typename Eigen::internal::traits<XprType>::PointerType data() const { return NULL; } /// required by sycl in order to extract the accessor const TensorEvaluator<LeftArgType, Device>& left_impl() const { return m_leftImpl; } /// required by sycl in order to extract the accessor |