diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h | 54 |
1 files changed, 41 insertions, 13 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h index 649bdb308..17f10c07b 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h @@ -21,8 +21,7 @@ namespace Eigen { * Example: * C.device(EIGEN_GPU) = A + B; * - * Todo: thread pools. - * Todo: operator +=, -=, *= and so on. + * Todo: operator *= and /=. */ template <typename ExpressionType, typename DeviceType> class TensorDevice { @@ -33,8 +32,7 @@ template <typename ExpressionType, typename DeviceType> class TensorDevice { EIGEN_STRONG_INLINE TensorDevice& operator=(const OtherDerived& other) { typedef TensorAssignOp<ExpressionType, const OtherDerived> Assign; Assign assign(m_expression, other); - static const bool Vectorize = TensorEvaluator<const Assign, DeviceType>::PacketAccess; - internal::TensorExecutor<const Assign, DeviceType, Vectorize>::run(assign, m_device); + internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device); return *this; } @@ -45,8 +43,18 @@ template <typename ExpressionType, typename DeviceType> class TensorDevice { Sum sum(m_expression, other); typedef TensorAssignOp<ExpressionType, const Sum> Assign; Assign assign(m_expression, sum); - static const bool Vectorize = TensorEvaluator<const Assign, DeviceType>::PacketAccess; - internal::TensorExecutor<const Assign, DeviceType, Vectorize>::run(assign, m_device); + internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device); + return *this; + } + + template<typename OtherDerived> + EIGEN_STRONG_INLINE TensorDevice& operator-=(const OtherDerived& other) { + typedef typename OtherDerived::Scalar Scalar; + typedef TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const ExpressionType, const OtherDerived> Difference; + Difference difference(m_expression, other); + typedef TensorAssignOp<ExpressionType, const Difference> Assign; + Assign assign(m_expression, difference); + internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device); return *this; } @@ -65,8 +73,7 @@ template <typename ExpressionType> class TensorDevice<ExpressionType, ThreadPool EIGEN_STRONG_INLINE TensorDevice& operator=(const OtherDerived& other) { typedef TensorAssignOp<ExpressionType, const OtherDerived> Assign; Assign assign(m_expression, other); - static const bool Vectorize = TensorEvaluator<const Assign, ThreadPoolDevice>::PacketAccess; - internal::TensorExecutor<const Assign, ThreadPoolDevice, Vectorize>::run(assign, m_device); + internal::TensorExecutor<const Assign, ThreadPoolDevice>::run(assign, m_device); return *this; } @@ -77,8 +84,18 @@ template <typename ExpressionType> class TensorDevice<ExpressionType, ThreadPool Sum sum(m_expression, other); typedef TensorAssignOp<ExpressionType, const Sum> Assign; Assign assign(m_expression, sum); - static const bool Vectorize = TensorEvaluator<const Assign, ThreadPoolDevice>::PacketAccess; - internal::TensorExecutor<const Assign, ThreadPoolDevice, Vectorize>::run(assign, m_device); + internal::TensorExecutor<const Assign, ThreadPoolDevice>::run(assign, m_device); + return *this; + } + + template<typename OtherDerived> + EIGEN_STRONG_INLINE TensorDevice& operator-=(const OtherDerived& other) { + typedef typename OtherDerived::Scalar Scalar; + typedef TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const ExpressionType, const OtherDerived> Difference; + Difference difference(m_expression, other); + typedef TensorAssignOp<ExpressionType, const Difference> Assign; + Assign assign(m_expression, difference); + internal::TensorExecutor<const Assign, ThreadPoolDevice>::run(assign, m_device); return *this; } @@ -99,7 +116,7 @@ template <typename ExpressionType> class TensorDevice<ExpressionType, GpuDevice> EIGEN_STRONG_INLINE TensorDevice& operator=(const OtherDerived& other) { typedef TensorAssignOp<ExpressionType, const OtherDerived> Assign; Assign assign(m_expression, other); - internal::TensorExecutor<const Assign, GpuDevice, false>::run(assign, m_device); + internal::TensorExecutor<const Assign, GpuDevice>::run(assign, m_device); return *this; } @@ -110,13 +127,24 @@ template <typename ExpressionType> class TensorDevice<ExpressionType, GpuDevice> Sum sum(m_expression, other); typedef TensorAssignOp<ExpressionType, const Sum> Assign; Assign assign(m_expression, sum); - internal::TensorExecutor<const Assign, GpuDevice, false>::run(assign, m_device); + internal::TensorExecutor<const Assign, GpuDevice>::run(assign, m_device); + return *this; + } + + template<typename OtherDerived> + EIGEN_STRONG_INLINE TensorDevice& operator-=(const OtherDerived& other) { + typedef typename OtherDerived::Scalar Scalar; + typedef TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const ExpressionType, const OtherDerived> Difference; + Difference difference(m_expression, other); + typedef TensorAssignOp<ExpressionType, const Difference> Assign; + Assign assign(m_expression, difference); + internal::TensorExecutor<const Assign, GpuDevice>::run(assign, m_device); return *this; } protected: const GpuDevice& m_device; - ExpressionType m_expression; + ExpressionType& m_expression; }; #endif |