aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-01-14 15:38:48 -0800
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-01-14 15:38:48 -0800
commitf697df723798779bc29d9f7299bb5398767d5db0 (patch)
treec155c21ad9ef0e6269f6af83fe2f29f97a0c0e21 /unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h
parent6559d09c60fb4acfc7ee5197284f576ac14926f1 (diff)
Improved support for RowMajor tensors
Misc fixes and API cleanups.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h24
1 files changed, 17 insertions, 7 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h b/unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h
index cb14cc7f7..a9501336e 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h
@@ -25,11 +25,14 @@ struct traits<TensorForcedEvalOp<XprType> >
{
// Type promotion to handle the case where the types of the lhs and the rhs are different.
typedef typename XprType::Scalar Scalar;
- typedef typename internal::packet_traits<Scalar>::type Packet;
+ typedef traits<XprType> XprTraits;
+ typedef typename packet_traits<Scalar>::type Packet;
typedef typename traits<XprType>::StorageKind StorageKind;
typedef typename traits<XprType>::Index Index;
typedef typename XprType::Nested Nested;
typedef typename remove_reference<Nested>::type _Nested;
+ static const int NumDimensions = XprTraits::NumDimensions;
+ static const int Layout = XprTraits::Layout;
enum {
Flags = 0,
@@ -59,8 +62,8 @@ class TensorForcedEvalOp : public TensorBase<TensorForcedEvalOp<XprType> >
typedef typename Eigen::internal::traits<TensorForcedEvalOp>::Scalar Scalar;
typedef typename Eigen::internal::traits<TensorForcedEvalOp>::Packet Packet;
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
- typedef typename XprType::CoeffReturnType CoeffReturnType;
- typedef typename XprType::PacketReturnType PacketReturnType;
+ typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
+ typedef typename internal::remove_const<typename XprType::PacketReturnType>::type PacketReturnType;
typedef typename Eigen::internal::nested<TensorForcedEvalOp>::type Nested;
typedef typename Eigen::internal::traits<TensorForcedEvalOp>::StorageKind StorageKind;
typedef typename Eigen::internal::traits<TensorForcedEvalOp>::Index Index;
@@ -88,6 +91,7 @@ struct TensorEvaluator<const TensorForcedEvalOp<ArgType>, Device>
enum {
IsAligned = true,
PacketAccess = (internal::packet_traits<Scalar>::size > 1),
+ Layout = TensorEvaluator<ArgType, Device>::Layout,
};
EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
@@ -100,10 +104,16 @@ struct TensorEvaluator<const TensorForcedEvalOp<ArgType>, Device>
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_impl.dimensions(); }
- EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) {
+ EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) {
m_impl.evalSubExprsIfNeeded(NULL);
- m_buffer = (Scalar*)m_device.allocate(m_impl.dimensions().TotalSize() * sizeof(Scalar));
-
+ const Index numValues = m_impl.dimensions().TotalSize();
+ m_buffer = (CoeffReturnType*)m_device.allocate(numValues * sizeof(CoeffReturnType));
+ // Should initialize the memory in case we're dealing with non POD types.
+ if (!internal::is_arithmetic<CoeffReturnType>::value) {
+ for (Index i = 0; i < numValues; ++i) {
+ new(m_buffer+i) CoeffReturnType();
+ }
+ }
typedef TensorEvalToOp<const ArgType> EvalTo;
EvalTo evalToTmp(m_buffer, m_op);
internal::TensorExecutor<const EvalTo, Device, TensorEvaluator<ArgType, Device>::PacketAccess>::run(evalToTmp, m_device);
@@ -132,7 +142,7 @@ struct TensorEvaluator<const TensorForcedEvalOp<ArgType>, Device>
TensorEvaluator<ArgType, Device> m_impl;
const ArgType m_op;
const Device& m_device;
- Scalar* m_buffer;
+ CoeffReturnType* m_buffer;
};