aboutsummaryrefslogtreecommitdiffhomepage
path: root/tvmet-1.7.1/include/tvmet/xpr/MMProduct.h
blob: 7fe4f746a4b2412a205f000f9df44506e31f3ad5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
/*
 * Tiny Vector Matrix Library
 * Dense Vector Matrix Library of Tiny size using Expression Templates
 *
 * Copyright (C) 2001 - 2003 Olaf Petzold <opetzold@users.sourceforge.net>
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 *
 * $Id: MMProduct.h,v 1.20 2004/09/16 09:14:18 opetzold Exp $
 */

#ifndef TVMET_XPR_MMPRODUCT_H
#define TVMET_XPR_MMPRODUCT_H

#include <cassert>

#include <tvmet/meta/Gemm.h>
#include <tvmet/loop/Gemm.h>

namespace tvmet {


/**
 * \class XprMMProduct MMProduct.h "tvmet/xpr/MMProduct.h"
 * \brief Expression for matrix-matrix product.
 *        Using formula:
 *        \f[
 *        M_1\,M_2
 *        \f]
 * \param E1    Expression type of the left operand (Rows1 x Cols1).
 * \param Rows1 Number of rows of the left operand and of the result.
 * \param Cols1 Number of columns of the left operand; this is also the
 *              inner (summation) dimension of the product.
 * \param E2    Expression type of the right operand (Cols1 x Cols2).
 * \param Cols2 Number of columns of the right operand and of the result.
 * \note The number of rows of the right operand (Rows2) has to be equal
 *       to Cols1, which is why it is not a separate template parameter.
 */
template<class E1, int Rows1, int Cols1,
	 class E2, int Cols2>
class XprMMProduct
  : public TvmetBase< XprMMProduct<E1, Rows1, Cols1, E2, Cols2> >
{
private:
  // Declared but intentionally not defined: an expression node can neither
  // be default-constructed nor assigned.
  XprMMProduct();
  XprMMProduct& operator=(const XprMMProduct&);

public:
  /** Element type of the product, promoted from both operand element types. */
  typedef typename PromoteTraits<
    typename E1::value_type,
    typename E2::value_type
  >::value_type							value_type;

public:
  /**
   * Complexity counter.
   * Each of the Rows1*Cols2 result elements is a dot product of length
   * Cols1, i.e. Cols1 multiplications and (Cols1 - 1) additions. Hence
   * M is the total number of multiplications and N the total number of
   * additions; each is weighted by the per-operation cost that
   * NumericTraits reports for value_type.
   * use_meta selects meta::gemm when the result has fewer than
   * TVMET_COMPLEXITY_MM_TRIGGER elements, and loop::gemm otherwise.
   */
  enum {
    ops_lhs   = E1::ops,
    ops_rhs   = E2::ops,
    M         = Rows1 * Cols1 * Cols2,		// number of multiplications
    N         = Rows1 * (Cols1 - 1) * Cols2,	// number of additions
    ops_plus  = N * NumericTraits<value_type>::ops_plus,
    ops_muls  = M * NumericTraits<value_type>::ops_muls,
    ops       = ops_plus + ops_muls,
    use_meta  = Rows1*Cols2 < TVMET_COMPLEXITY_MM_TRIGGER ? true : false
  };

public:
  /**
   * Constructor.
   * \param lhs Left operand; stored by value (copied).
   * \param rhs Right operand; stored by value (copied).
   */
  explicit XprMMProduct(const E1& lhs, const E2& rhs)
    : m_lhs(lhs), m_rhs(rhs)
  { }

  /** Copy Constructor. Not explicit! */
#if defined(TVMET_OPTIMIZE_XPR_MANUAL_CCTOR)
  XprMMProduct(const XprMMProduct& e)
    : m_lhs(e.m_lhs), m_rhs(e.m_rhs)
  { }
#endif

private:
  /** Wrapper for meta gemm; selected when use_meta is true. */
  static inline
  value_type do_gemm(dispatch<true>, const E1& lhs, const E2& rhs, int i, int j) {
    return meta::gemm<Rows1, Cols1,
                      Cols2,
                      0>::prod(lhs, rhs, i, j);
  }

  /** Wrapper for loop gemm; selected when use_meta is false. */
  static inline
  value_type do_gemm(dispatch<false>, const E1& lhs, const E2& rhs, int i, int j) {
    return loop::gemm<Rows1, Cols1, Cols2>::prod(lhs, rhs, i, j);
  }

public:
  /**
   * Index operator for arrays/matrices.
   * \param i Row index of the result, expected in [0, Rows1).
   * \param j Column index of the result, expected in [0, Cols2).
   * \return Element (i, j) of the product, computed on demand.
   */
  value_type operator()(int i, int j) const {
    // Indices are signed, so negative values must be rejected as well.
    assert((i >= 0) && (i < Rows1) && (j >= 0) && (j < Cols2));
    return do_gemm(dispatch<use_meta>(), m_lhs, m_rhs, i, j);
  }

public: // debugging Xpr parse tree
  /**
   * Write this node and its operand sub-expressions to \p os.
   * \param os Output stream.
   * \param l  Indentation level of this node; operands are printed at l+1.
   */
  void print_xpr(std::ostream& os, int l=0) const {
    os << IndentLevel(l++)
       << "XprMMProduct["
       << (use_meta ? "M" :  "L") << ", O=" << ops
       << ", (O1=" << ops_lhs << ", O2=" << ops_rhs << ")]<"
       << std::endl;
    m_lhs.print_xpr(os, l);
    os << IndentLevel(l)
       << "R1=" << Rows1 << ", C1=" << Cols1 << ",\n";
    m_rhs.print_xpr(os, l);
    os << IndentLevel(l)
       << "C2=" << Cols2 << ",\n";
    os << IndentLevel(--l)
       << ">," << std::endl;
  }

private:
  const E1		 				m_lhs;	// left operand (held by value)
  const E2		 				m_rhs;	// right operand (held by value)
};


} // namespace tvmet

#endif // TVMET_XPR_MMPRODUCT_H

// Local Variables:
// mode:C++
// End: