diff options
Diffstat (limited to 'unsupported/Eigen')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBase.h | 7 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h | 45 |
2 files changed, 31 insertions, 21 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index 5679e58cf..66772a3ad 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -643,7 +643,12 @@ class TensorBase<Derived, ReadOnlyAccessors> template <typename PaddingDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorPaddingOp<const PaddingDimensions, const Derived> pad(const PaddingDimensions& padding) const { - return TensorPaddingOp<const PaddingDimensions, const Derived>(derived(), padding); + return TensorPaddingOp<const PaddingDimensions, const Derived>(derived(), padding, internal::scalar_cast_op<int, Scalar>()(0)); + } + template <typename PaddingDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorPaddingOp<const PaddingDimensions, const Derived> + pad(const PaddingDimensions& padding, const Scalar padding_value) const { + return TensorPaddingOp<const PaddingDimensions, const Derived>(derived(), padding, padding_value); } template <typename Shuffle> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorShufflingOp<const Shuffle, const Derived> diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h b/unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h index c3f25f0df..eaaf4dc86 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h @@ -16,7 +16,7 @@ namespace Eigen { * \ingroup CXX11_Tensor_Module * * \brief Tensor padding class. - * At the moment only 0-padding is supported. + * At the moment only padding with a constant value is supported. * */ namespace internal { @@ -63,11 +63,13 @@ class TensorPaddingOp : public TensorBase<TensorPaddingOp<PaddingDimensions, Xpr typedef typename Eigen::internal::traits<TensorPaddingOp>::StorageKind StorageKind; typedef typename Eigen::internal::traits<TensorPaddingOp>::Index Index; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorPaddingOp(const XprType& expr, const PaddingDimensions& padding_dims) - : m_xpr(expr), m_padding_dims(padding_dims) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorPaddingOp(const XprType& expr, const PaddingDimensions& padding_dims, const Scalar padding_value) + : m_xpr(expr), m_padding_dims(padding_dims), m_padding_value(padding_value) {} EIGEN_DEVICE_FUNC const PaddingDimensions& padding() const { return m_padding_dims; } + EIGEN_DEVICE_FUNC + Scalar padding_value() const { return m_padding_value; } EIGEN_DEVICE_FUNC const typename internal::remove_all<typename XprType::Nested>::type& @@ -76,6 +78,7 @@ class TensorPaddingOp : public TensorBase<TensorPaddingOp<PaddingDimensions, Xpr protected: typename XprType::Nested m_xpr; const PaddingDimensions m_padding_dims; + const Scalar m_padding_value; }; @@ -97,7 +100,7 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device }; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) - : m_impl(op.expression(), device), m_padding(op.padding()) + : m_impl(op.expression(), device), m_padding(op.padding()), m_paddingValue(op.padding_value()) { // The padding op doesn't change the rank of the tensor. Directly padding a scalar would lead // to a vector, which doesn't make sense. Instead one should reshape the scalar into a vector @@ -151,27 +154,27 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device for (int i = NumDims - 1; i > 0; --i) { const Index idx = index / m_outputStrides[i]; if (idx < m_padding[i].first || idx >= m_dimensions[i] - m_padding[i].second) { - return internal::scalar_cast_op<int, Scalar>()(0); + return m_paddingValue; } inputIndex += (idx - m_padding[i].first) * m_inputStrides[i]; index -= idx * m_outputStrides[i]; } if (index < m_padding[0].first || index >= m_dimensions[0] - m_padding[0].second) { - return internal::scalar_cast_op<int, Scalar>()(0); + return m_paddingValue; } inputIndex += (index - m_padding[0].first); } else { for (int i = 0; i < NumDims - 1; ++i) { const Index idx = index / m_outputStrides[i+1]; if (idx < m_padding[i].first || idx >= m_dimensions[i] - m_padding[i].second) { - return internal::scalar_cast_op<int, Scalar>()(0); + return m_paddingValue; } inputIndex += (idx - m_padding[i].first) * m_inputStrides[i]; index -= idx * m_outputStrides[i+1]; } if (index < m_padding[NumDims-1].first || index >= m_dimensions[NumDims-1] - m_padding[NumDims-1].second) { - return internal::scalar_cast_op<int, Scalar>()(0); + return m_paddingValue; } inputIndex += (index - m_padding[NumDims-1].first); } @@ -194,14 +197,14 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device { const Index idx = coords[0]; if (idx < m_padding[0].first || idx >= m_dimensions[0] - m_padding[0].second) { - return internal::scalar_cast_op<int, Scalar>()(0); + return m_paddingValue; } inputIndex = idx - m_padding[0].first; } for (int i = 1; i < NumDims; ++i) { const Index idx = coords[i]; if (idx < m_padding[i].first || idx >= m_dimensions[i] - m_padding[i].second) { - return internal::scalar_cast_op<int, Scalar>()(0); + return m_paddingValue; } inputIndex += (idx - m_padding[i].first) * m_inputStrides[i]; } @@ -209,14 +212,14 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device { const Index idx = coords[NumDims-1]; if (idx < m_padding[NumDims-1].first || idx >= m_dimensions[NumDims-1] - m_padding[NumDims-1].second) { - return internal::scalar_cast_op<int, Scalar>()(0); + return m_paddingValue; } inputIndex = idx - m_padding[NumDims-1].first; } for (int i = NumDims - 2; i >= 0; --i) { const Index idx = coords[i]; if (idx < m_padding[i].first || idx >= m_dimensions[i] - m_padding[i].second) { - return internal::scalar_cast_op<int, Scalar>()(0); + return m_paddingValue; } inputIndex += (idx - m_padding[i].first) * m_inputStrides[i]; } @@ -245,11 +248,11 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device if (last < lastPaddedLeft) { // all the coefficient are in the padding zone. - return internal::pset1<PacketReturnType>(internal::scalar_cast_op<int, Scalar>()(0)); + return internal::pset1<PacketReturnType>(m_paddingValue); } else if (first >= firstPaddedRight && last < lastPaddedRight) { // all the coefficient are in the padding zone. - return internal::pset1<PacketReturnType>(internal::scalar_cast_op<int, Scalar>()(0)); + return internal::pset1<PacketReturnType>(m_paddingValue); } else if (first >= lastPaddedLeft && last < firstPaddedRight) { // all the coefficient are between the 2 padding zones. @@ -271,11 +274,11 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device if (last < lastPaddedLeft) { // all the coefficient are in the padding zone. - return internal::pset1<PacketReturnType>(internal::scalar_cast_op<int, Scalar>()(0)); + return internal::pset1<PacketReturnType>(m_paddingValue); } else if (first >= firstPaddedRight && last < lastPaddedRight) { // all the coefficient are in the padding zone. - return internal::pset1<PacketReturnType>(internal::scalar_cast_op<int, Scalar>()(0)); + return internal::pset1<PacketReturnType>(m_paddingValue); } else if (first >= lastPaddedLeft && last < firstPaddedRight) { // all the coefficient are between the 2 padding zones. @@ -304,11 +307,11 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device if (last < lastPaddedLeft) { // all the coefficient are in the padding zone. - return internal::pset1<PacketReturnType>(internal::scalar_cast_op<int, Scalar>()(0)); + return internal::pset1<PacketReturnType>(m_paddingValue); } else if (first >= firstPaddedRight && last < lastPaddedRight) { // all the coefficient are in the padding zone. - return internal::pset1<PacketReturnType>(internal::scalar_cast_op<int, Scalar>()(0)); + return internal::pset1<PacketReturnType>(m_paddingValue); } else if (first >= lastPaddedLeft && last < firstPaddedRight) { // all the coefficient are between the 2 padding zones. @@ -330,11 +333,11 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device if (last < lastPaddedLeft) { // all the coefficient are in the padding zone. - return internal::pset1<PacketReturnType>(internal::scalar_cast_op<int, Scalar>()(0)); + return internal::pset1<PacketReturnType>(m_paddingValue); } else if (first >= firstPaddedRight && last < lastPaddedRight) { // all the coefficient are in the padding zone. - return internal::pset1<PacketReturnType>(internal::scalar_cast_op<int, Scalar>()(0)); + return internal::pset1<PacketReturnType>(m_paddingValue); } else if (first >= lastPaddedLeft && last < firstPaddedRight) { // all the coefficient are between the 2 padding zones. @@ -361,6 +364,8 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device array<Index, NumDims> m_inputStrides; TensorEvaluator<ArgType, Device> m_impl; PaddingDimensions m_padding; + + Scalar m_paddingValue; }; |