aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h54
1 files changed, 41 insertions, 13 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h
index 649bdb308..17f10c07b 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h
@@ -21,8 +21,7 @@ namespace Eigen {
* Example:
* C.device(EIGEN_GPU) = A + B;
*
- * Todo: thread pools.
- * Todo: operator +=, -=, *= and so on.
+ * Todo: operator *= and /=.
*/
template <typename ExpressionType, typename DeviceType> class TensorDevice {
@@ -33,8 +32,7 @@ template <typename ExpressionType, typename DeviceType> class TensorDevice {
EIGEN_STRONG_INLINE TensorDevice& operator=(const OtherDerived& other) {
typedef TensorAssignOp<ExpressionType, const OtherDerived> Assign;
Assign assign(m_expression, other);
- static const bool Vectorize = TensorEvaluator<const Assign, DeviceType>::PacketAccess;
- internal::TensorExecutor<const Assign, DeviceType, Vectorize>::run(assign, m_device);
+ internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device);
return *this;
}
@@ -45,8 +43,18 @@ template <typename ExpressionType, typename DeviceType> class TensorDevice {
Sum sum(m_expression, other);
typedef TensorAssignOp<ExpressionType, const Sum> Assign;
Assign assign(m_expression, sum);
- static const bool Vectorize = TensorEvaluator<const Assign, DeviceType>::PacketAccess;
- internal::TensorExecutor<const Assign, DeviceType, Vectorize>::run(assign, m_device);
+ internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device);
+ return *this;
+ }
+
+ template<typename OtherDerived>
+ EIGEN_STRONG_INLINE TensorDevice& operator-=(const OtherDerived& other) {
+ typedef typename OtherDerived::Scalar Scalar;
+ typedef TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const ExpressionType, const OtherDerived> Difference;
+ Difference difference(m_expression, other);
+ typedef TensorAssignOp<ExpressionType, const Difference> Assign;
+ Assign assign(m_expression, difference);
+ internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device);
return *this;
}
@@ -65,8 +73,7 @@ template <typename ExpressionType> class TensorDevice<ExpressionType, ThreadPool
EIGEN_STRONG_INLINE TensorDevice& operator=(const OtherDerived& other) {
typedef TensorAssignOp<ExpressionType, const OtherDerived> Assign;
Assign assign(m_expression, other);
- static const bool Vectorize = TensorEvaluator<const Assign, ThreadPoolDevice>::PacketAccess;
- internal::TensorExecutor<const Assign, ThreadPoolDevice, Vectorize>::run(assign, m_device);
+ internal::TensorExecutor<const Assign, ThreadPoolDevice>::run(assign, m_device);
return *this;
}
@@ -77,8 +84,18 @@ template <typename ExpressionType> class TensorDevice<ExpressionType, ThreadPool
Sum sum(m_expression, other);
typedef TensorAssignOp<ExpressionType, const Sum> Assign;
Assign assign(m_expression, sum);
- static const bool Vectorize = TensorEvaluator<const Assign, ThreadPoolDevice>::PacketAccess;
- internal::TensorExecutor<const Assign, ThreadPoolDevice, Vectorize>::run(assign, m_device);
+ internal::TensorExecutor<const Assign, ThreadPoolDevice>::run(assign, m_device);
+ return *this;
+ }
+
+ template<typename OtherDerived>
+ EIGEN_STRONG_INLINE TensorDevice& operator-=(const OtherDerived& other) {
+ typedef typename OtherDerived::Scalar Scalar;
+ typedef TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const ExpressionType, const OtherDerived> Difference;
+ Difference difference(m_expression, other);
+ typedef TensorAssignOp<ExpressionType, const Difference> Assign;
+ Assign assign(m_expression, difference);
+ internal::TensorExecutor<const Assign, ThreadPoolDevice>::run(assign, m_device);
return *this;
}
@@ -99,7 +116,7 @@ template <typename ExpressionType> class TensorDevice<ExpressionType, GpuDevice>
EIGEN_STRONG_INLINE TensorDevice& operator=(const OtherDerived& other) {
typedef TensorAssignOp<ExpressionType, const OtherDerived> Assign;
Assign assign(m_expression, other);
- internal::TensorExecutor<const Assign, GpuDevice, false>::run(assign, m_device);
+ internal::TensorExecutor<const Assign, GpuDevice>::run(assign, m_device);
return *this;
}
@@ -110,13 +127,24 @@ template <typename ExpressionType> class TensorDevice<ExpressionType, GpuDevice>
Sum sum(m_expression, other);
typedef TensorAssignOp<ExpressionType, const Sum> Assign;
Assign assign(m_expression, sum);
- internal::TensorExecutor<const Assign, GpuDevice, false>::run(assign, m_device);
+ internal::TensorExecutor<const Assign, GpuDevice>::run(assign, m_device);
+ return *this;
+ }
+
+ template<typename OtherDerived>
+ EIGEN_STRONG_INLINE TensorDevice& operator-=(const OtherDerived& other) {
+ typedef typename OtherDerived::Scalar Scalar;
+ typedef TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const ExpressionType, const OtherDerived> Difference;
+ Difference difference(m_expression, other);
+ typedef TensorAssignOp<ExpressionType, const Difference> Assign;
+ Assign assign(m_expression, difference);
+ internal::TensorExecutor<const Assign, GpuDevice>::run(assign, m_device);
return *this;
}
protected:
const GpuDevice& m_device;
- ExpressionType m_expression;
+ ExpressionType& m_expression;
};
#endif