From a6a628ca6b3c0d0dd6716d200ba8e7740847168a Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Thu, 19 Mar 2015 23:22:19 -0700 Subject: Added the -= operator to the device classes --- unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h | 39 +++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h index 649bdb308..7a67c56b3 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h @@ -21,8 +21,7 @@ namespace Eigen { * Example: * C.device(EIGEN_GPU) = A + B; * - * Todo: thread pools. - * Todo: operator +=, -=, *= and so on. + * Todo: operator *= and /=. */ template class TensorDevice { @@ -50,6 +49,18 @@ template class TensorDevice { return *this; } + template + EIGEN_STRONG_INLINE TensorDevice& operator-=(const OtherDerived& other) { + typedef typename OtherDerived::Scalar Scalar; + typedef TensorCwiseBinaryOp, const ExpressionType, const OtherDerived> Difference; + Difference difference(m_expression, other); + typedef TensorAssignOp Assign; + Assign assign(m_expression, difference); + static const bool Vectorize = TensorEvaluator::PacketAccess; + internal::TensorExecutor::run(assign, m_device); + return *this; + } + protected: const DeviceType& m_device; ExpressionType& m_expression; @@ -82,6 +93,18 @@ template class TensorDevice + EIGEN_STRONG_INLINE TensorDevice& operator-=(const OtherDerived& other) { + typedef typename OtherDerived::Scalar Scalar; + typedef TensorCwiseBinaryOp, const ExpressionType, const OtherDerived> Difference; + Difference difference(m_expression, other); + typedef TensorAssignOp Assign; + Assign assign(m_expression, difference); + static const bool Vectorize = TensorEvaluator::PacketAccess; + internal::TensorExecutor::run(assign, m_device); + return *this; + } + protected: const ThreadPoolDevice& m_device; ExpressionType& m_expression; @@ -114,6 +137,18 @@ template class TensorDevice return *this; } + template + EIGEN_STRONG_INLINE TensorDevice& operator-=(const OtherDerived& other) { + typedef typename OtherDerived::Scalar Scalar; + typedef TensorCwiseBinaryOp, const ExpressionType, const OtherDerived> Difference; + Difference difference(m_expression, other); + typedef TensorAssignOp Assign; + Assign assign(m_expression, difference); + static const bool Vectorize = TensorEvaluator::PacketAccess; + internal::TensorExecutor::run(assign, m_device); + return *this; + } + protected: const GpuDevice& m_device; ExpressionType m_expression; -- cgit v1.2.3