aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/Swap.h
blob: 8fd94b3c7af398c5a873d2a73c796bfd2255bcd9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2006-2008 Benoit Jacob <jacob.benoit.1@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_SWAP_H
#define EIGEN_SWAP_H

namespace Eigen { 

// #ifndef EIGEN_TEST_EVALUATORS
  
/** \class SwapWrapper
  * \ingroup Core_Module
  *
  * \internal
  *
  * \brief Internal helper class for swapping two expressions
  */
namespace internal {
// The traits of a SwapWrapper are inherited, unchanged, from the traits of
// the expression type it wraps.
template<typename ExpressionType>
struct traits<SwapWrapper<ExpressionType> >
  : traits<ExpressionType>
{};
} // end namespace internal

template<typename ExpressionType> class SwapWrapper
  : public internal::dense_xpr_base<SwapWrapper<ExpressionType> >::type
{
  public:

    typedef typename internal::dense_xpr_base<SwapWrapper>::type Base;
    EIGEN_DENSE_PUBLIC_INTERFACE(SwapWrapper)
    typedef typename internal::packet_traits<Scalar>::type Packet;

    /** Element type pointed to by the non-const data(): \c Scalar when
      * \c ExpressionType is an lvalue, \c const \c Scalar otherwise.
      */
    typedef typename internal::conditional<
                       internal::is_lvalue<ExpressionType>::value,
                       Scalar,
                       const Scalar
                     >::type ScalarWithConstIfNotLvalue;

    /** Wraps \a xpr. Only a reference is stored, no copy is made. */
    EIGEN_DEVICE_FUNC
    inline SwapWrapper(ExpressionType& xpr)
      : m_expression(xpr)
    {}

    // --- Size and stride queries: all forwarded to the wrapped expression ---

    /** \returns the number of rows of the wrapped expression */
    EIGEN_DEVICE_FUNC
    inline Index rows() const
    {
      return m_expression.rows();
    }

    /** \returns the number of columns of the wrapped expression */
    EIGEN_DEVICE_FUNC
    inline Index cols() const
    {
      return m_expression.cols();
    }

    /** \returns the outer stride of the wrapped expression */
    EIGEN_DEVICE_FUNC
    inline Index outerStride() const
    {
      return m_expression.outerStride();
    }

    /** \returns the inner stride of the wrapped expression */
    EIGEN_DEVICE_FUNC
    inline Index innerStride() const
    {
      return m_expression.innerStride();
    }

    // --- Raw data access ---

    /** \returns the data pointer of the wrapped expression */
    EIGEN_DEVICE_FUNC
    inline ScalarWithConstIfNotLvalue* data()
    {
      return m_expression.data();
    }

    /** \returns the data pointer of the wrapped expression, as pointer-to-const */
    EIGEN_DEVICE_FUNC
    inline const Scalar* data() const
    {
      return m_expression.data();
    }

    // --- Coefficient access ---
    // The non-const overloads go through const_cast_derived() on the wrapped
    // expression; the const overloads call coeffRef() on it directly.

    /** \returns a reference to the coefficient at (\a rowId, \a colId) of the wrapped expression */
    EIGEN_DEVICE_FUNC
    inline Scalar& coeffRef(Index rowId, Index colId)
    {
      return m_expression.const_cast_derived().coeffRef(rowId, colId);
    }

    /** \returns a reference to the coefficient at linear position \a index of the wrapped expression */
    EIGEN_DEVICE_FUNC
    inline Scalar& coeffRef(Index index)
    {
      return m_expression.const_cast_derived().coeffRef(index);
    }

    /** Const-qualified counterpart of coeffRef(Index,Index); still yields a non-const reference. */
    EIGEN_DEVICE_FUNC
    inline Scalar& coeffRef(Index rowId, Index colId) const
    {
      return m_expression.coeffRef(rowId, colId);
    }

    /** Const-qualified counterpart of coeffRef(Index); still yields a non-const reference. */
    EIGEN_DEVICE_FUNC
    inline Scalar& coeffRef(Index index) const
    {
      return m_expression.coeffRef(index);
    }

    // --- "Copy" operations, implemented as exchanges ---
    // Each of these writes to both operands: \a other is received by const
    // reference but is modified through const_cast_derived().

    /** Exchanges the coefficient at (\a rowId, \a colId) of the wrapped
      * expression with the coefficient at the same position of \a other.
      */
    template<typename OtherDerived>
    EIGEN_DEVICE_FUNC
    void copyCoeff(Index rowId, Index colId, const DenseBase<OtherDerived>& other)
    {
      eigen_internal_assert(rowId >= 0 && rowId < rows()
                         && colId >= 0 && colId < cols());
      OtherDerived& rhs = other.const_cast_derived();

      // Three-step exchange through a temporary holding our old value.
      Scalar saved = m_expression.coeff(rowId, colId);
      m_expression.coeffRef(rowId, colId) = rhs.coeff(rowId, colId);
      rhs.coeffRef(rowId, colId) = saved;
    }

    /** Exchanges the coefficient at linear position \a index of the wrapped
      * expression with the coefficient at the same position of \a other.
      */
    template<typename OtherDerived>
    EIGEN_DEVICE_FUNC
    void copyCoeff(Index index, const DenseBase<OtherDerived>& other)
    {
      eigen_internal_assert(index >= 0 && index < m_expression.size());
      OtherDerived& rhs = other.const_cast_derived();

      // Three-step exchange through a temporary holding our old value.
      Scalar saved = m_expression.coeff(index);
      m_expression.coeffRef(index) = rhs.coeff(index);
      rhs.coeffRef(index) = saved;
    }

    /** Exchanges the packet at (\a rowId, \a colId) of the wrapped expression
      * with the packet at the same position of \a other.
      *
      * \c StoreMode is used for both the read and the write on the wrapped
      * expression, \c LoadMode for both the read and the write on \a other.
      */
    template<typename OtherDerived, int StoreMode, int LoadMode>
    void copyPacket(Index rowId, Index colId, const DenseBase<OtherDerived>& other)
    {
      eigen_internal_assert(rowId >= 0 && rowId < rows()
                        && colId >= 0 && colId < cols());
      OtherDerived& rhs = other.const_cast_derived();

      // Three-step exchange through a temporary holding our old packet.
      Packet saved = m_expression.template packet<StoreMode>(rowId, colId);
      m_expression.template writePacket<StoreMode>(
        rowId, colId, rhs.template packet<LoadMode>(rowId, colId));
      rhs.template writePacket<LoadMode>(rowId, colId, saved);
    }

    /** Exchanges the packet at linear position \a index of the wrapped
      * expression with the packet at the same position of \a other.
      *
      * \c StoreMode is used for both the read and the write on the wrapped
      * expression, \c LoadMode for both the read and the write on \a other.
      */
    template<typename OtherDerived, int StoreMode, int LoadMode>
    void copyPacket(Index index, const DenseBase<OtherDerived>& other)
    {
      eigen_internal_assert(index >= 0 && index < m_expression.size());
      OtherDerived& rhs = other.const_cast_derived();

      // Three-step exchange through a temporary holding our old packet.
      Packet saved = m_expression.template packet<StoreMode>(index);
      m_expression.template writePacket<StoreMode>(
        index, rhs.template packet<LoadMode>(index));
      rhs.template writePacket<LoadMode>(index, saved);
    }

    /** \returns the wrapped expression */
    EIGEN_DEVICE_FUNC
    ExpressionType& expression() const
    {
      return m_expression;
    }

  protected:
    // The wrapped expression, held by reference.
    ExpressionType& m_expression;
};

// #endif

#ifdef EIGEN_TEST_EVALUATORS

namespace internal {

// Assignment kernel specialised for swapping: it derives from
// generic_dense_assignment_kernel instantiated with swap_assign_op and
// redefines the packet assignment entry points so that they call the
// functor's swapPacket() on the addresses of the destination and source
// coefficients.
template<typename DstEvaluatorTypeT, typename SrcEvaluatorTypeT>
class dense_swap_kernel
  : public generic_dense_assignment_kernel<DstEvaluatorTypeT, SrcEvaluatorTypeT, swap_assign_op<typename DstEvaluatorTypeT::Scalar> >
{
  typedef generic_dense_assignment_kernel<DstEvaluatorTypeT, SrcEvaluatorTypeT, swap_assign_op<typename DstEvaluatorTypeT::Scalar> > Base;
  typedef typename DstEvaluatorTypeT::PacketScalar PacketScalar;
  using Base::m_dst;
  using Base::m_src;
  using Base::m_functor;

public:
  typedef typename Base::Scalar Scalar;
  typedef typename Base::Index Index;
  typedef typename Base::DstXprType DstXprType;

  // Forwards both evaluators and the destination expression to the base
  // kernel, together with a default-constructed swap_assign_op functor.
  dense_swap_kernel(DstEvaluatorTypeT &dst, const SrcEvaluatorTypeT &src, DstXprType& dstExpr)
    : Base(dst, src, swap_assign_op<Scalar>(), dstExpr)
  {}

  // Packet swap addressed by (row, col).
  // The source evaluator is received as const, so constness is cast away to
  // obtain a writable reference to its coefficient.
  template<int StoreMode, int LoadMode>
  void assignPacket(Index row, Index col)
  {
    SrcEvaluatorTypeT& writableSrc = const_cast<SrcEvaluatorTypeT&>(m_src);
    m_functor.template swapPacket<StoreMode,LoadMode,PacketScalar>(
      &m_dst.coeffRef(row,col),
      &writableSrc.coeffRef(row,col));
  }

  // Packet swap addressed by linear index; same scheme as the (row, col) overload.
  template<int StoreMode, int LoadMode>
  void assignPacket(Index index)
  {
    SrcEvaluatorTypeT& writableSrc = const_cast<SrcEvaluatorTypeT&>(m_src);
    m_functor.template swapPacket<StoreMode,LoadMode,PacketScalar>(
      &m_dst.coeffRef(index),
      &writableSrc.coeffRef(index));
  }

  // TODO find a simple way not to have to copy/paste this function from generic_dense_assignment_kernel, by simple I mean no CRTP (Gael)
  // Maps (outer, inner) to (row, col) with the base kernel's helpers, then
  // dispatches to this class's assignPacket(row, col).
  template<int StoreMode, int LoadMode>
  void assignPacketByOuterInner(Index outer, Index inner)
  {
    Index row = Base::rowIndexByOuterInner(outer, inner);
    Index col = Base::colIndexByOuterInner(outer, inner);
    this->template assignPacket<StoreMode,LoadMode>(row, col);
  }
};
  
template<typename DstXprType, typename SrcXprType>
void call_dense_swap_loop(const DstXprType& dst, const SrcXprType& src)
{
  // TODO there is too much redundancy with call_dense_assignment_loop
  
  eigen_assert(dst.rows() == src.rows() && dst.cols() == src.cols());
  
  typedef typename evaluator<DstXprType>::type DstEvaluatorType;
  typedef typename evaluator<SrcXprType>::type SrcEvaluatorType;

  DstEvaluatorType dstEvaluator(dst);
  SrcEvaluatorType srcEvaluator(src);
    
  typedef dense_swap_kernel<DstEvaluatorType,SrcEvaluatorType> Kernel;
  Kernel kernel(dstEvaluator, srcEvaluator, dst.const_cast_derived());
  
  dense_assignment_loop<Kernel>::run(kernel);
}

} // namespace internal

#endif

} // end namespace Eigen

#endif // EIGEN_SWAP_H