about | summary | refs | log | tree | commit | diff | homepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h
diff options
context:
space:
mode:
author: Benoit Steiner <benoit.steiner.goog@gmail.com> 2015-01-14 15:38:48 -0800
committer: Benoit Steiner <benoit.steiner.goog@gmail.com> 2015-01-14 15:38:48 -0800
commit: f697df723798779bc29d9f7299bb5398767d5db0 (patch)
tree: c155c21ad9ef0e6269f6af83fe2f29f97a0c0e21 /unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h
parent: 6559d09c60fb4acfc7ee5197284f576ac14926f1 (diff)
Improved support for RowMajor tensors
Misc fixes and API cleanups.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h')
-rw-r--r-- unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h 171
1 files changed, 151 insertions, 20 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h b/unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h
index d6347b054..9b14e01f4 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h
@@ -24,11 +24,14 @@ template<typename PaddingDimensions, typename XprType>
struct traits<TensorPaddingOp<PaddingDimensions, XprType> > : public traits<XprType>
{
typedef typename XprType::Scalar Scalar;
- typedef typename internal::packet_traits<Scalar>::type Packet;
- typedef typename traits<XprType>::StorageKind StorageKind;
- typedef typename traits<XprType>::Index Index;
+ typedef traits<XprType> XprTraits;
+ typedef typename packet_traits<Scalar>::type Packet;
+ typedef typename XprTraits::StorageKind StorageKind;
+ typedef typename XprTraits::Index Index;
typedef typename XprType::Nested Nested;
typedef typename remove_reference<Nested>::type _Nested;
+ static const int NumDimensions = XprTraits::NumDimensions;
+ static const int Layout = XprTraits::Layout;
};
template<typename PaddingDimensions, typename XprType>
@@ -88,6 +91,8 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device
enum {
IsAligned = false,
PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
+ Layout = TensorEvaluator<ArgType, Device>::Layout,  // forwarded unchanged from the evaluator of ArgType
+ CoordAccess = true,  // presumably advertises the coeff(const array<Index, NumDims>&) overload defined in this evaluator -- confirm against users of the flag
};
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
@@ -99,13 +104,23 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device
m_dimensions[i] += m_padding[i].first + m_padding[i].second;
}
const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
- m_inputStrides[0] = 1;
- m_outputStrides[0] = 1;
- for (int i = 1; i < NumDims; ++i) {
- m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
- m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
+ if (Layout == ColMajor) {
+ m_inputStrides[0] = 1;
+ m_outputStrides[0] = 1;
+ for (int i = 1; i < NumDims; ++i) {
+ m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
+ m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
+ }
+ m_outputStrides[NumDims] = m_outputStrides[NumDims-1] * m_dimensions[NumDims-1];
+ } else {
+ m_inputStrides[NumDims - 1] = 1;
+ m_outputStrides[NumDims] = 1;
+ for (int i = NumDims - 2; i >= 0; --i) {
+ m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
+ m_outputStrides[i+1] = m_outputStrides[i+2] * m_dimensions[i+1];
+ }
+ m_outputStrides[0] = m_outputStrides[1] * m_dimensions[0];
}
- m_outputStrides[NumDims] = m_outputStrides[NumDims-1] * m_dimensions[NumDims-1];
}
typedef typename XprType::Scalar Scalar;
@@ -126,24 +141,85 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device
{
eigen_assert(index < dimensions().TotalSize());
Index inputIndex = 0;
- for (int i = NumDims - 1; i > 0; --i) {
- const Index idx = index / m_outputStrides[i];
- if (idx < m_padding[i].first || idx >= m_dimensions[i] - m_padding[i].second) {
+ if (Layout == ColMajor) {
+ for (int i = NumDims - 1; i > 0; --i) {
+ const Index idx = index / m_outputStrides[i];
+ if (idx < m_padding[i].first || idx >= m_dimensions[i] - m_padding[i].second) {
+ return Scalar(0);
+ }
+ inputIndex += (idx - m_padding[i].first) * m_inputStrides[i];
+ index -= idx * m_outputStrides[i];
+ }
+ if (index < m_padding[0].first || index >= m_dimensions[0] - m_padding[0].second) {
return Scalar(0);
}
- inputIndex += (idx - m_padding[i].first) * m_inputStrides[i];
- index -= idx * m_outputStrides[i];
- }
- if (index < m_padding[0].first || index >= m_dimensions[0] - m_padding[0].second) {
- return Scalar(0);
+ inputIndex += (index - m_padding[0].first);
+ } else {
+ for (int i = 0; i < NumDims - 1; ++i) {
+ const Index idx = index / m_outputStrides[i+1];
+ if (idx < m_padding[i].first || idx >= m_dimensions[i] - m_padding[i].second) {
+ return Scalar(0);
+ }
+ inputIndex += (idx - m_padding[i].first) * m_inputStrides[i];
+ index -= idx * m_outputStrides[i+1];
+ }
+ if (index < m_padding[NumDims-1].first ||
+ index >= m_dimensions[NumDims-1] - m_padding[NumDims-1].second) {
+ return Scalar(0);
+ }
+ inputIndex += (index - m_padding[NumDims-1].first);
}
- inputIndex += (index - m_padding[0].first);
return m_impl.coeff(inputIndex);
}
template<int LoadMode>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
{
+ if (Layout == ColMajor) {  // packet() only dispatches on the Layout enum value; LoadMode is not forwarded to either helper
+ return packetColMajor(index);
+ }
+ return packetRowMajor(index);  // every layout other than ColMajor takes the row-major path
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(const array<Index, NumDims>& coords) const  // Coefficient at output coordinates `coords`: Scalar(0) if any coordinate lies in the padding, else the matching m_impl coefficient.
+ {
+ Index inputIndex;  // assigned in both layout branches before it is read
+ if (Layout == ColMajor) {
+ const Index idx = coords[0];  // dimension 0 first: its input stride is set to 1 in the constructor for ColMajor, so no multiply is needed
+ if (idx < m_padding[0].first || idx >= m_dimensions[0] - m_padding[0].second) {  // inside the leading (.first) or trailing (.second) padding
+ return Scalar(0);
+ }
+ inputIndex = idx - m_padding[0].first;  // shift from padded-output coordinate to input coordinate
+ for (int i = 1; i < NumDims; ++i) {
+ const Index idx = coords[i];  // NOTE(review): shadows the outer `idx`; harmless here but may trigger -Wshadow
+ if (idx < m_padding[i].first || idx >= m_dimensions[i] - m_padding[i].second) {
+ return Scalar(0);
+ }
+ inputIndex += (idx - m_padding[i].first) * m_inputStrides[i];
+ }
+ } else {
+ const Index idx = coords[NumDims-1];  // last dimension first: its input stride is set to 1 in the constructor for the non-ColMajor case
+ if (idx < m_padding[NumDims-1].first || idx >= m_dimensions[NumDims-1] - m_padding[NumDims-1].second) {  // same padding test as the ColMajor branch
+ return Scalar(0);
+ }
+ inputIndex = idx - m_padding[NumDims-1].first;
+ for (int i = NumDims - 2; i >= 0; --i) {
+ const Index idx = coords[i];  // NOTE(review): shadows the outer `idx`, as in the ColMajor branch
+ if (idx < m_padding[i].first || idx >= m_dimensions[i] - m_padding[i].second) {
+ return Scalar(0);
+ }
+ inputIndex += (idx - m_padding[i].first) * m_inputStrides[i];
+ }
+ }
+ return m_impl.coeff(inputIndex);  // no coordinate was in the padding: read m_impl at the computed linear input index
+ }
+
+ Scalar* data() const { return NULL; }  // always NULL: this evaluator hands out no raw data pointer
+
+ protected:
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetColMajor(Index index) const
+ {
const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
eigen_assert(index+packetSize-1 < dimensions().TotalSize());
@@ -200,9 +276,64 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device
return packetWithPossibleZero(initialIndex);
}
- Scalar* data() const { return NULL; }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetRowMajor(Index index) const  // RowMajor packet at linear output `index`: all-zero packet, direct m_impl packet load, or deferral to packetWithPossibleZero.
+ {
+ const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
+ EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ eigen_assert(index+packetSize-1 < dimensions().TotalSize());  // the last coefficient of the packet must still be inside the padded output
- protected:
+ const Index initialIndex = index;  // `index` is reduced dimension by dimension below; the fallback needs the original value
+ Index inputIndex = 0;
+
+ for (int i = 0; i < NumDims - 1; ++i) {  // every dimension except the last, which is handled after the loop
+ const Index first = index;
+ const Index last = index + packetSize - 1;
+ const Index lastPaddedLeft = m_padding[i].first * m_outputStrides[i+1];  // exclusive end of the leading padding of dimension i, as an output offset
+ const Index firstPaddedRight = (m_dimensions[i] - m_padding[i].second) * m_outputStrides[i+1];  // offset at which the trailing padding of dimension i starts
+ const Index lastPaddedRight = m_outputStrides[i];  // exclusive end of the trailing padding: the constructor sets this to m_outputStrides[i+1] * m_dimensions[i]
+
+ if (last < lastPaddedLeft) {
+ // all the coefficients are in the leading padding zone.
+ return internal::pset1<PacketReturnType>(Scalar(0));
+ }
+ else if (first >= firstPaddedRight && last < lastPaddedRight) {
+ // all the coefficients are in the trailing padding zone.
+ return internal::pset1<PacketReturnType>(Scalar(0));
+ }
+ else if (first >= lastPaddedLeft && last < firstPaddedRight) {
+ // all the coefficients are between the 2 padding zones.
+ const Index idx = index / m_outputStrides[i+1];  // coordinate of the packet start along dimension i
+ inputIndex += (idx - m_padding[i].first) * m_inputStrides[i];
+ index -= idx * m_outputStrides[i+1];  // keep only the offset inside the slice selected by idx
+ }
+ else {
+ // Every other case: the packet straddles a padding boundary of dimension i.
+ return packetWithPossibleZero(initialIndex);
+ }
+ }
+
+ const Index last = index + packetSize - 1;  // same classification for the last dimension, whose output stride m_outputStrides[NumDims] is 1
+ const Index first = index;
+ const Index lastPaddedLeft = m_padding[NumDims-1].first;
+ const Index firstPaddedRight = (m_dimensions[NumDims-1] - m_padding[NumDims-1].second);
+ const Index lastPaddedRight = m_outputStrides[NumDims-1];
+
+ if (last < lastPaddedLeft) {
+ // all the coefficients are in the leading padding zone.
+ return internal::pset1<PacketReturnType>(Scalar(0));
+ }
+ else if (first >= firstPaddedRight && last < lastPaddedRight) {
+ // all the coefficients are in the trailing padding zone.
+ return internal::pset1<PacketReturnType>(Scalar(0));
+ }
+ else if (first >= lastPaddedLeft && last < firstPaddedRight) {
+ // all the coefficients are between the 2 padding zones.
+ inputIndex += (index - m_padding[NumDims-1].first);
+ return m_impl.template packet<Unaligned>(inputIndex);  // whole packet maps to consecutive input coefficients; loaded with the Unaligned mode
+ }
+ // Every other case: the packet straddles a padding boundary of the last dimension.
+ return packetWithPossibleZero(initialIndex);
+ }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetWithPossibleZero(Index index) const
{