aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-06-06 16:25:16 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-06-06 16:25:16 -0700
commita961d72e65fc537fe571845407b4e2ee0554bd49 (patch)
treeef3f6ac79862925587a857efd00202dc612da198 /unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
parent8998f4099e20ebc80db0aba2582301cd48d31c5a (diff)
Added support for convolution and reshaping of tensors.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h119
1 files changed, 119 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
new file mode 100644
index 000000000..3e089fe1e
--- /dev/null
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
@@ -0,0 +1,119 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_CXX11_TENSOR_TENSOR_MORPHING_H
+#define EIGEN_CXX11_TENSOR_TENSOR_MORPHING_H
+
+namespace Eigen {
+
+/** \class TensorReshaping
+ * \ingroup CXX11_Tensor_Module
+ *
+ * \brief Tensor reshaping class.
+ *
+ *
+ */
+namespace internal {
+template<typename XprType, typename NewDimensions>
+struct traits<TensorReshapingOp<XprType, NewDimensions> >
+{
+ // Type promotion to handle the case where the types of the lhs and the rhs are different.
+ typedef typename XprType::Scalar Scalar;
+ typedef typename internal::packet_traits<Scalar>::type Packet;
+ typedef typename traits<XprType>::StorageKind StorageKind;
+ typedef typename traits<XprType>::Index Index;
+ typedef typename XprType::Nested Nested;
+ typedef typename remove_reference<Nested>::type _Nested;
+};
+
+template<typename XprType, typename NewDimensions>
+struct eval<TensorReshapingOp<XprType, NewDimensions>, Eigen::Dense>
+{
+ typedef const TensorReshapingOp<XprType, NewDimensions>& type;
+};
+
+template<typename XprType, typename NewDimensions>
+struct nested<TensorReshapingOp<XprType, NewDimensions>, 1, typename eval<TensorReshapingOp<XprType, NewDimensions> >::type>
+{
+ typedef TensorReshapingOp<XprType, NewDimensions> type;
+};
+
+} // end namespace internal
+
+
+
+template<typename XprType, typename NewDimensions>
+class TensorReshapingOp : public TensorBase<TensorReshapingOp<XprType, NewDimensions> >
+{
+ public:
+ typedef typename Eigen::internal::traits<TensorReshapingOp>::Scalar Scalar;
+ typedef typename Eigen::internal::traits<TensorReshapingOp>::Packet Packet;
+ typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
+ typedef typename XprType::CoeffReturnType CoeffReturnType;
+ typedef typename XprType::PacketReturnType PacketReturnType;
+ typedef typename Eigen::internal::nested<TensorReshapingOp>::type Nested;
+ typedef typename Eigen::internal::traits<TensorReshapingOp>::StorageKind StorageKind;
+ typedef typename Eigen::internal::traits<TensorReshapingOp>::Index Index;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorReshapingOp(const XprType& expr, const NewDimensions& dims)
+ : m_xpr(expr), m_dims(dims) {}
+
+ EIGEN_DEVICE_FUNC
+ const NewDimensions& dimensions() const { return m_dims; }
+
+ EIGEN_DEVICE_FUNC
+ const typename internal::remove_all<typename XprType::Nested>::type&
+ expression() const { return m_xpr; }
+
+ protected:
+ typename XprType::Nested m_xpr;
+ const NewDimensions m_dims;
+};
+
+
+template<typename ArgType, typename NewDimensions>
+struct TensorEvaluator<const TensorReshapingOp<ArgType, NewDimensions> >
+{
+ typedef TensorReshapingOp<ArgType, NewDimensions> XprType;
+
+ enum {
+ IsAligned = TensorEvaluator<ArgType>::IsAligned,
+ PacketAccess = TensorEvaluator<ArgType>::PacketAccess,
+ };
+
+ TensorEvaluator(const XprType& op)
+ : m_impl(op.expression()), m_dimensions(op.dimensions())
+ { }
+
+ typedef typename XprType::Index Index;
+ typedef typename XprType::CoeffReturnType CoeffReturnType;
+ typedef typename XprType::PacketReturnType PacketReturnType;
+
+ const NewDimensions& dimensions() const { return m_dimensions; }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
+ {
+ return m_impl.coeff(index);
+ }
+
+ template<int LoadMode>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
+ {
+ return m_impl.template packet<LoadMode>(index);
+ }
+
+ private:
+ NewDimensions m_dimensions;
+ TensorEvaluator<ArgType> m_impl;
+};
+
+
+} // end namespace Eigen
+
+#endif // EIGEN_CXX11_TENSOR_TENSOR_MORPHING_H