diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-06-30 15:36:29 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-06-30 15:36:29 -0700 |
commit | f1f480b116913b9c90fce0626f1643eb9f021003 (patch) | |
tree | b45dbfadc6cc8540d2fbd62904747ead0020b38b /unsupported/Eigen/CXX11/src/Tensor | |
parent | dc31fcb9ba064f2124827aefb2ca5857327c005e (diff) |
Added support for user defined custom tensor op.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor')
4 files changed, 328 insertions, 3 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index b5bdf9d8e..1c31b9d28 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -481,6 +481,18 @@ class TensorBase<Derived, ReadOnlyAccessors> return TensorStridingOp<const Strides, const Derived>(derived(), strides); } + // Added support for custom unary and binary operations + template <typename CustomUnaryFunc> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCustomUnaryOp<const CustomUnaryFunc, const Derived> customOp(const CustomUnaryFunc& op) const { + return TensorCustomUnaryOp<const CustomUnaryFunc, const Derived>(derived(), op); + } + template <typename OtherDerived, typename CustomBinaryFunc> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCustomBinaryOp<const CustomBinaryFunc, const Derived, const OtherDerived> customOp(const OtherDerived& other, const CustomBinaryFunc& op) const { + return TensorCustomBinaryOp<const CustomBinaryFunc, const Derived, const OtherDerived>(derived(), other, op); + } + // Force the evaluation of the expression. EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorForcedEvalOp<const Derived> eval() const { diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h b/unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h new file mode 100644 index 000000000..0157f6fab --- /dev/null +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h @@ -0,0 +1,310 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H +#define EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H + +namespace Eigen { + +/** \class TensorCustomUnaryOp + * \ingroup CXX11_Tensor_Module + * + * \brief Tensor custom class. + * + * + */ +namespace internal { +template<typename CustomUnaryFunc, typename XprType> +struct traits<TensorCustomUnaryOp<CustomUnaryFunc, XprType> > +{ + typedef typename XprType::Scalar Scalar; + typedef typename packet_traits<Scalar>::type Packet; + typedef typename XprType::StorageKind StorageKind; + typedef typename XprType::Index Index; + typedef typename XprType::Nested Nested; + typedef typename remove_reference<Nested>::type _Nested; + static const int NumDimensions = traits<XprType>::NumDimensions; + static const int Layout = traits<XprType>::Layout; +}; + +template<typename CustomUnaryFunc, typename XprType> +struct eval<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Eigen::Dense> +{ + typedef const TensorCustomUnaryOp<CustomUnaryFunc, XprType>& type; +}; + +template<typename CustomUnaryFunc, typename XprType> +struct nested<TensorCustomUnaryOp<CustomUnaryFunc, XprType> > +{ + typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> type; +}; + +} // end namespace internal + + + +template<typename CustomUnaryFunc, typename XprType> +class TensorCustomUnaryOp : public TensorBase<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, ReadOnlyAccessors> +{ + public: + typedef typename internal::traits<TensorCustomUnaryOp>::Scalar Scalar; + typedef typename internal::traits<TensorCustomUnaryOp>::Packet Packet; + typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketReturnType PacketReturnType; + typedef typename internal::nested<TensorCustomUnaryOp>::type Nested; + typedef typename internal::traits<TensorCustomUnaryOp>::StorageKind StorageKind; + typedef typename internal::traits<TensorCustomUnaryOp>::Index Index; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomUnaryOp(const XprType& expr, const CustomUnaryFunc& func) + : m_expr(expr), m_func(func) {} + + EIGEN_DEVICE_FUNC + const CustomUnaryFunc& func() const { return m_func; } + + EIGEN_DEVICE_FUNC + const typename internal::remove_all<typename XprType::Nested>::type& + expression() const { return m_expr; } + + protected: + typename XprType::Nested m_expr; + const CustomUnaryFunc m_func; +}; + + +// Eval as rvalue +template<typename CustomUnaryFunc, typename XprType, typename Device> +struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Device> +{ + typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> ArgType; + typedef typename internal::traits<ArgType>::Index Index; + static const int NumDims = internal::traits<ArgType>::NumDimensions; + typedef DSizes<Index, NumDims> Dimensions; + typedef + typename internal::remove_const<typename ArgType::Scalar>::type Scalar; + + enum { + IsAligned = false, + PacketAccess = (internal::packet_traits<Scalar>::size > 1), + BlockAccess = false, + Layout = TensorEvaluator<XprType, Device>::Layout, + CoordAccess = false, // to be implemented + }; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const ArgType& op, const Device& device) + : m_op(op), m_device(device), m_result(NULL) + { + m_dimensions = op.func().dimensions(op.expression()); + } + + typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType; + typedef typename XprType::PacketReturnType PacketReturnType; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) { + if (data) { + evalTo(data); + return false; + } else { + m_result = static_cast<CoeffReturnType*>( + m_device.allocate(dimensions().TotalSize() * sizeof(Scalar))); + evalTo(m_result); + return true; + } + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { + if (m_result != NULL) { + m_device.deallocate(m_result); + m_result = NULL; + } + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { + return m_result[index]; + } + + template<int LoadMode> + EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const { + return internal::ploadt<PacketReturnType, LoadMode>(m_result + index); + } + + EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; } + + protected: + EIGEN_DEVICE_FUNC void evalTo(Scalar* data) { + TensorMap<Tensor<CoeffReturnType, NumDims, Layout, Index> > result( + data, m_dimensions); + m_op.func().eval(m_op.expression(), result, m_device); + } + + Dimensions m_dimensions; + const ArgType m_op; + const Device& m_device; + CoeffReturnType* m_result; +}; + + + +/** \class TensorCustomBinaryOp + * \ingroup CXX11_Tensor_Module + * + * \brief Tensor custom class. + * + * + */ +namespace internal { +template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> +struct traits<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> > +{ + typedef typename internal::promote_storage_type<typename LhsXprType::Scalar, + typename RhsXprType::Scalar>::ret Scalar; + typedef typename packet_traits<Scalar>::type Packet; + typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType, + typename RhsXprType::CoeffReturnType>::ret CoeffReturnType; + typedef typename internal::promote_storage_type<typename LhsXprType::PacketReturnType, + typename RhsXprType::PacketReturnType>::ret PacketReturnType; + typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind, + typename traits<RhsXprType>::StorageKind>::ret StorageKind; + typedef typename promote_index_type<typename traits<LhsXprType>::Index, + typename traits<RhsXprType>::Index>::type Index; + typedef typename LhsXprType::Nested LhsNested; + typedef typename RhsXprType::Nested RhsNested; + typedef typename remove_reference<LhsNested>::type _LhsNested; + typedef typename remove_reference<RhsNested>::type _RhsNested; + static const int NumDimensions = traits<LhsXprType>::NumDimensions; + static const int Layout = traits<LhsXprType>::Layout; +}; + +template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> +struct eval<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Eigen::Dense> +{ + typedef const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>& type; +}; + +template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> +struct nested<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> > +{ + typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> type; +}; + +} // end namespace internal + + + +template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> +class TensorCustomBinaryOp : public TensorBase<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, ReadOnlyAccessors> +{ + public: + typedef typename internal::traits<TensorCustomBinaryOp>::Scalar Scalar; + typedef typename internal::traits<TensorCustomBinaryOp>::Packet Packet; + typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; + typedef typename internal::traits<TensorCustomBinaryOp>::CoeffReturnType CoeffReturnType; + typedef typename internal::traits<TensorCustomBinaryOp>::PacketReturnType PacketReturnType; + typedef typename internal::nested<TensorCustomBinaryOp>::type Nested; + typedef typename internal::traits<TensorCustomBinaryOp>::StorageKind StorageKind; + typedef typename internal::traits<TensorCustomBinaryOp>::Index Index; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const CustomBinaryFunc& func) + + : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_func(func) {} + + EIGEN_DEVICE_FUNC + const CustomBinaryFunc& func() const { return m_func; } + + EIGEN_DEVICE_FUNC + const typename internal::remove_all<typename LhsXprType::Nested>::type& + lhsExpression() const { return m_lhs_xpr; } + + EIGEN_DEVICE_FUNC + const typename internal::remove_all<typename RhsXprType::Nested>::type& + rhsExpression() const { return m_rhs_xpr; } + + protected: + typename LhsXprType::Nested m_lhs_xpr; + typename RhsXprType::Nested m_rhs_xpr; + const CustomBinaryFunc m_func; +}; + + +// Eval as rvalue +template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, typename Device> +struct TensorEvaluator<const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Device> +{ + typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> XprType; + typedef typename internal::traits<XprType>::Index Index; + static const int NumDims = internal::traits<XprType>::NumDimensions; + typedef DSizes<Index, NumDims> Dimensions; + typedef typename XprType::Scalar Scalar; + + enum { + IsAligned = false, + PacketAccess = (internal::packet_traits<Scalar>::size > 1), + BlockAccess = false, + Layout = TensorEvaluator<LhsXprType, Device>::Layout, + CoordAccess = false, // to be implemented + }; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) + : m_op(op), m_device(device), m_result(NULL) + { + m_dimensions = op.func().dimensions(op.lhsExpression(), op.rhsExpression()); + } + + typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType; + typedef typename XprType::PacketReturnType PacketReturnType; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) { + if (data) { + evalTo(data); + return false; + } else { + m_result = static_cast<Scalar *>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar))); + evalTo(m_result); + return true; + } + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { + if (m_result != NULL) { + m_device.deallocate(m_result); + m_result = NULL; + } + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { + return m_result[index]; + } + + template<int LoadMode> + EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const { + return internal::ploadt<PacketReturnType, LoadMode>(m_result + index); + } + + EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; } + + protected: + EIGEN_DEVICE_FUNC void evalTo(Scalar* data) { + TensorMap<Tensor<Scalar, NumDims, Layout> > result(data, m_dimensions); + m_op.func().eval(m_op.lhsExpression(), m_op.rhsExpression(), result, m_device); + } + + Dimensions m_dimensions; + const XprType m_op; + const Device& m_device; + CoeffReturnType* m_result; +}; + + +} // end namespace Eigen + +#endif // EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h index a3fb26acc..7df8d1453 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h @@ -42,6 +42,9 @@ template<typename Strides, typename XprType> class TensorStridingOp; template<typename Generator, typename XprType> class TensorGeneratorOp; template<typename LeftXprType, typename RightXprType> class TensorAssignOp; +template<typename CustomUnaryFunc, typename XprType> class TensorCustomUnaryOp; +template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> class TensorCustomBinaryOp; + template<typename XprType> class TensorEvalToOp; template<typename XprType> class TensorForcedEvalOp; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h b/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h index 757d6cb38..2e6b93bfb 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h @@ -156,9 +156,9 @@ struct eval<const TensorRef<PlainObjectType>, Eigen::Dense> }; // TODO nested<> does not exist anymore in Eigen/Core, and it thus has to be removed in favor of ref_selector. -template<typename T, int n=1, typename PlainObject = void> struct nested -{ - typedef typename ref_selector<T>::type type; +template<typename T, int n=1, typename PlainObject = void> struct nested +{ + typedef typename ref_selector<T>::type type; }; template <typename Scalar_, std::size_t NumIndices_, int Options_, typename IndexType_> |