From 6126ad801f89cb88d46ed0dae0d8e9448652c506 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Wed, 31 Jul 2013 15:30:50 +0200 Subject: Extend support for nvcc to Array objects and wrappers --- Eigen/src/Core/ArrayWrapper.h | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) (limited to 'Eigen/src/Core/ArrayWrapper.h') diff --git a/Eigen/src/Core/ArrayWrapper.h b/Eigen/src/Core/ArrayWrapper.h index a791bc358..21830745b 100644 --- a/Eigen/src/Core/ArrayWrapper.h +++ b/Eigen/src/Core/ArrayWrapper.h @@ -48,41 +48,54 @@ class ArrayWrapper : public ArrayBase > typedef typename internal::nested::type NestedExpressionType; + EIGEN_DEVICE_FUNC inline ArrayWrapper(ExpressionType& matrix) : m_expression(matrix) {} + EIGEN_DEVICE_FUNC inline Index rows() const { return m_expression.rows(); } + EIGEN_DEVICE_FUNC inline Index cols() const { return m_expression.cols(); } + EIGEN_DEVICE_FUNC inline Index outerStride() const { return m_expression.outerStride(); } + EIGEN_DEVICE_FUNC inline Index innerStride() const { return m_expression.innerStride(); } + EIGEN_DEVICE_FUNC inline ScalarWithConstIfNotLvalue* data() { return m_expression.const_cast_derived().data(); } + EIGEN_DEVICE_FUNC inline const Scalar* data() const { return m_expression.data(); } + EIGEN_DEVICE_FUNC inline CoeffReturnType coeff(Index rowId, Index colId) const { return m_expression.coeff(rowId, colId); } + EIGEN_DEVICE_FUNC inline Scalar& coeffRef(Index rowId, Index colId) { return m_expression.const_cast_derived().coeffRef(rowId, colId); } + EIGEN_DEVICE_FUNC inline const Scalar& coeffRef(Index rowId, Index colId) const { return m_expression.const_cast_derived().coeffRef(rowId, colId); } + EIGEN_DEVICE_FUNC inline CoeffReturnType coeff(Index index) const { return m_expression.coeff(index); } + EIGEN_DEVICE_FUNC inline Scalar& coeffRef(Index index) { return m_expression.const_cast_derived().coeffRef(index); } + EIGEN_DEVICE_FUNC inline const Scalar& coeffRef(Index index) const { return m_expression.const_cast_derived().coeffRef(index); @@ -113,9 +126,11 @@ class ArrayWrapper : public ArrayBase > } template + EIGEN_DEVICE_FUNC inline void evalTo(Dest& dst) const { dst = m_expression; } const typename internal::remove_all::type& + EIGEN_DEVICE_FUNC nestedExpression() const { return m_expression; @@ -123,9 +138,11 @@ class ArrayWrapper : public ArrayBase > /** Forwards the resizing request to the nested expression * \sa DenseBase::resize(Index) */ + EIGEN_DEVICE_FUNC void resize(Index newSize) { m_expression.const_cast_derived().resize(newSize); } /** Forwards the resizing request to the nested expression * \sa DenseBase::resize(Index,Index)*/ + EIGEN_DEVICE_FUNC void resize(Index nbRows, Index nbCols) { m_expression.const_cast_derived().resize(nbRows,nbCols); } protected: @@ -168,41 +185,54 @@ class MatrixWrapper : public MatrixBase > typedef typename internal::nested::type NestedExpressionType; + EIGEN_DEVICE_FUNC inline MatrixWrapper(ExpressionType& a_matrix) : m_expression(a_matrix) {} + EIGEN_DEVICE_FUNC inline Index rows() const { return m_expression.rows(); } + EIGEN_DEVICE_FUNC inline Index cols() const { return m_expression.cols(); } + EIGEN_DEVICE_FUNC inline Index outerStride() const { return m_expression.outerStride(); } + EIGEN_DEVICE_FUNC inline Index innerStride() const { return m_expression.innerStride(); } + EIGEN_DEVICE_FUNC inline ScalarWithConstIfNotLvalue* data() { return m_expression.const_cast_derived().data(); } + EIGEN_DEVICE_FUNC inline const Scalar* data() const { return m_expression.data(); } + EIGEN_DEVICE_FUNC inline CoeffReturnType coeff(Index rowId, Index colId) const { return m_expression.coeff(rowId, colId); } + EIGEN_DEVICE_FUNC inline Scalar& coeffRef(Index rowId, Index colId) { return m_expression.const_cast_derived().coeffRef(rowId, colId); } + EIGEN_DEVICE_FUNC inline const Scalar& coeffRef(Index rowId, Index colId) const { return m_expression.derived().coeffRef(rowId, colId); } + EIGEN_DEVICE_FUNC inline CoeffReturnType coeff(Index index) const { return m_expression.coeff(index); } + EIGEN_DEVICE_FUNC inline Scalar& coeffRef(Index index) { return m_expression.const_cast_derived().coeffRef(index); } + EIGEN_DEVICE_FUNC inline const Scalar& coeffRef(Index index) const { return m_expression.const_cast_derived().coeffRef(index); @@ -232,6 +262,7 @@ class MatrixWrapper : public MatrixBase > m_expression.const_cast_derived().template writePacket(index, val); } + EIGEN_DEVICE_FUNC const typename internal::remove_all::type& nestedExpression() const { @@ -240,9 +271,11 @@ class MatrixWrapper : public MatrixBase > /** Forwards the resizing request to the nested expression * \sa DenseBase::resize(Index) */ + EIGEN_DEVICE_FUNC void resize(Index newSize) { m_expression.const_cast_derived().resize(newSize); } /** Forwards the resizing request to the nested expression * \sa DenseBase::resize(Index,Index)*/ + EIGEN_DEVICE_FUNC void resize(Index nbRows, Index nbCols) { m_expression.const_cast_derived().resize(nbRows,nbCols); } protected: -- cgit v1.2.3