/* -*- C++ -*- ------------------------------------------------------------

Copyright (c) 2007 Jesse Anders and Demian Nave http://cmldev.net/

The Configurable Math Library (CML) is distributed under the terms of the
Boost Software License, v1.0 (see cml/LICENSE for details).

 *-----------------------------------------------------------------------*/
/** @file
 *  @brief Multiply two matrices.
 *
 *  @todo Does it make sense to put mat-mat multiplication as a node into the
 *  expression tree?
 *
 *  @internal This does not need to return an expression type, since the
 *  temporary generation for the matrix result is handled automatically by the
 *  compiler; i.e., when used in an expression, the result is automatically
 *  included in the expression tree as a temporary by the compiler.
 */
24 #include <cml/et/size_checking.h>
25 #include <cml/matrix/matrix_expr.h>
/* This is used below to create a more meaningful compile-time error when
 * mul is not provided with matrix or MatrixXpr arguments.  It is only
 * declared here; just its name is needed for the diagnostic.
 */
struct mul_expects_matrix_args_error;
/* This is used below to create a more meaningful compile-time error when
 * fixed-size arguments to mul() have the wrong size.  It is only declared
 * here; just its name is needed for the diagnostic.
 */
struct mul_expressions_have_wrong_size_error;
40 /** Verify the sizes of the argument matrices for matrix multiplication.
42 * @returns a matrix_size containing the size of the resulting matrix.
44 template<typename LeftT
, typename RightT
> inline matrix_size
45 MatMulCheckedSize(const LeftT
&, const RightT
&, fixed_size_tag
)
48 ((size_t)LeftT::array_cols
== (size_t)RightT::array_rows
),
49 mul_expressions_have_wrong_size_error
);
50 return matrix_size(LeftT::array_rows
,RightT::array_cols
);
53 /** Verify the sizes of the argument matrices for matrix multiplication.
55 * @returns a matrix_size containing the size of the resulting matrix.
57 template<typename LeftT
, typename RightT
> inline matrix_size
58 MatMulCheckedSize(const LeftT
& left
, const RightT
& right
, dynamic_size_tag
)
60 matrix_size left_N
= left
.size(), right_N
= right
.size();
61 et::GetCheckedSize
<LeftT
,RightT
,dynamic_size_tag
>()
62 .equal_or_fail(left_N
.second
, right_N
.first
); /* cols,rows */
63 return matrix_size(left_N
.first
, right_N
.second
); /* rows,cols */
67 /** Matrix multiplication.
69 * Computes C = A x B (O(N^3), non-blocked algorithm).
71 template<class LeftT
, class RightT
>
72 inline typename
et::MatrixPromote
<
73 typename
et::ExprTraits
<LeftT
>::result_type
,
74 typename
et::ExprTraits
<RightT
>::result_type
76 mul(const LeftT
& left
, const RightT
& right
)
79 typedef et::ExprTraits
<LeftT
> left_traits
;
80 typedef et::ExprTraits
<RightT
> right_traits
;
81 typedef typename
left_traits::result_type left_result
;
82 typedef typename
right_traits::result_type right_result
;
84 /* First, require matrix expressions: */
86 (et::MatrixExpressions
<LeftT
,RightT
>::is_true
),
87 mul_expects_matrix_args_error
);
88 /* Note: parens are required here so that the preprocessor ignores the
92 /* Deduce size type to ensure that a run-time check is performed if
95 typedef typename
et::MatrixPromote
<
96 typename
left_traits::result_type
,
97 typename
right_traits::result_type
99 typedef typename
result_type::size_tag size_tag
;
101 /* Require that left has the same number of columns as right has rows.
102 * This automatically checks fixed-size matrices at compile time, and
103 * throws at run-time if the sizes don't match:
105 matrix_size N
= detail::MatMulCheckedSize(left
, right
, size_tag());
107 /* Create an array with the right size (resize() is a no-op for
108 * fixed-size matrices):
111 cml::et::detail::Resize(C
, N
);
113 /* XXX Specialize this for fixed-size matrices: */
114 typedef typename
result_type::value_type value_type
;
115 for(size_t i
= 0; i
< left
.rows(); ++i
) { /* rows */
116 for(size_t j
= 0; j
< right
.cols(); ++j
) { /* cols */
117 value_type
sum(left(i
,0)*right(0,j
));
118 for(size_t k
= 1; k
< right
.rows(); ++k
) {
119 sum
+= (left(i
,k
)*right(k
,j
));
128 } // namespace detail
131 /** operator*() for two matrices. */
132 template<typename E1
, class AT1
, typename L1
,
133 typename E2
, class AT2
, typename L2
,
135 inline typename
et::MatrixPromote
<
136 matrix
<E1
,AT1
,BO
,L1
>, matrix
<E2
,AT2
,BO
,L2
>
138 operator*(const matrix
<E1
,AT1
,BO
,L1
>& left
,
139 const matrix
<E2
,AT2
,BO
,L2
>& right
)
141 return detail::mul(left
,right
);
144 /** operator*() for a matrix and a MatrixXpr. */
145 template<typename E
, class AT
, typename BO
, typename L
, typename XprT
>
146 inline typename
et::MatrixPromote
<
147 matrix
<E
,AT
,BO
,L
>, typename
XprT::result_type
149 operator*(const matrix
<E
,AT
,BO
,L
>& left
,
150 const et::MatrixXpr
<XprT
>& right
)
152 /* Generate a temporary, and compute the right-hand expression: */
153 typedef typename
et::MatrixXpr
<XprT
>::temporary_type expr_tmp
;
155 cml::et::detail::Resize(tmp
,right
.rows(),right
.cols());
158 return detail::mul(left
,tmp
);
161 /** operator*() for a MatrixXpr and a matrix. */
162 template<typename XprT
, typename E
, class AT
, typename BO
, typename L
>
163 inline typename
et::MatrixPromote
<
164 typename
XprT::result_type
, matrix
<E
,AT
,BO
,L
>
166 operator*(const et::MatrixXpr
<XprT
>& left
,
167 const matrix
<E
,AT
,BO
,L
>& right
)
169 /* Generate a temporary, and compute the left-hand expression: */
170 typedef typename
et::MatrixXpr
<XprT
>::temporary_type expr_tmp
;
172 cml::et::detail::Resize(tmp
,left
.rows(),left
.cols());
175 return detail::mul(tmp
,right
);
178 /** operator*() for two MatrixXpr's. */
179 template<typename XprT1
, typename XprT2
>
180 inline typename
et::MatrixPromote
<
181 typename
XprT1::result_type
, typename
XprT2::result_type
183 operator*(const et::MatrixXpr
<XprT1
>& left
,
184 const et::MatrixXpr
<XprT2
>& right
)
186 /* Generate temporaries and compute expressions: */
187 typedef typename
et::MatrixXpr
<XprT1
>::temporary_type left_tmp
;
189 cml::et::detail::Resize(ltmp
,left
.rows(),left
.cols());
192 typedef typename
et::MatrixXpr
<XprT2
>::temporary_type right_tmp
;
194 cml::et::detail::Resize(rtmp
,right
.rows(),right
.cols());
197 return detail::mul(ltmp
,rtmp
);
204 // -------------------------------------------------------------------------