1 /* -*- C++ -*- ------------------------------------------------------------
3 Copyright (c) 2007 Jesse Anders and Demian Nave http://cmldev.net/
5 The Configurable Math Library (CML) is distributed under the terms of the
6 Boost Software License, v1.0 (see cml/LICENSE for details).
8 *-----------------------------------------------------------------------*/
/** @file
 *  @brief Multiply a matrix and a vector.
 *
 *  @todo Implement smarter temporary generation.
 *
 *  @todo Does it make sense to represent mat-vec multiplication as a node in
 *  the expression tree?
 *
 *  @internal This does not need to return an expression type, since the
 *  temporary generation for the vector result is handled automatically by the
 *  compiler. I.e., when used in an expression, the result is automatically
 *  included in the expression tree as a temporary by the compiler.
 */
26 #include <cml/core/cml_meta.h>
27 #include <cml/vector/vector_expr.h>
28 #include <cml/matrix/matrix_expr.h>
29 #include <cml/matvec/matvec_promotions.h>
/* These types are only declared, never defined. Their names are passed to
 * CML_STATIC_REQUIRE_M() below so that a mat-vec mul called with the wrong
 * kinds of arguments produces a more meaningful compile-time error:
 */
struct mvmul_expects_one_matrix_and_one_vector_arg_error;
struct mvmul_expects_one_vector_and_one_matrix_arg_error;
40 /* For choosing the proper multiplication order: */
41 typedef true_type mul_Ax
;
42 typedef false_type mul_xA
;
44 /** Compute y = A*x. */
45 template<typename LeftT
, typename RightT
> inline
46 typename
et::MatVecPromote
<
47 typename
et::ExprTraits
<LeftT
>::result_type
,
48 typename
et::ExprTraits
<RightT
>::result_type
50 mul(const LeftT
& A
, const RightT
& x
, mul_Ax
)
53 typedef et::ExprTraits
<LeftT
> left_traits
;
54 typedef et::ExprTraits
<RightT
> right_traits
;
55 typedef typename
left_traits::result_tag left_result
;
56 typedef typename
right_traits::result_tag right_result
;
58 /* mul()[A*x] requires a matrix and a vector expression: */
60 (same_type
<left_result
, et::matrix_result_tag
>::is_true
61 && same_type
<right_result
, et::vector_result_tag
>::is_true
),
62 mvmul_expects_one_matrix_and_one_vector_arg_error
);
63 /* Note: parens are required here so that the preprocessor ignores the
67 /* Get result type: */
68 typedef typename
et::MatVecPromote
<
69 typename
left_traits::result_type
,
70 typename
right_traits::result_type
71 >::temporary_type result_type
;
73 /* Record size type: */
74 typedef typename
result_type::size_tag size_tag
;
77 size_t N
= et::CheckedSize(A
, x
, size_tag());
79 /* Initialize the new vector: */
80 result_type y
; cml::et::detail::Resize(y
, N
);
82 /* Compute y = A*x: */
83 typedef typename
result_type::value_type sum_type
;
84 for(size_t i
= 0; i
< N
; ++i
) {
85 /* XXX This should be unrolled. */
86 sum_type
sum(A(i
,0)*x
[0]);
87 for(size_t k
= 1; k
< x
.size(); ++k
) {
96 /** Compute y = x*A. */
97 template<typename LeftT
, typename RightT
> inline
98 typename
et::MatVecPromote
<
99 typename
et::ExprTraits
<LeftT
>::result_type
,
100 typename
et::ExprTraits
<RightT
>::result_type
102 mul(const LeftT
& x
, const RightT
& A
, mul_xA
)
105 typedef et::ExprTraits
<LeftT
> left_traits
;
106 typedef et::ExprTraits
<RightT
> right_traits
;
107 typedef typename
left_traits::result_tag left_result
;
108 typedef typename
right_traits::result_tag right_result
;
110 /* mul()[x*A] requires a vector and a matrix expression: */
111 CML_STATIC_REQUIRE_M(
112 (same_type
<left_result
, et::vector_result_tag
>::is_true
113 && same_type
<right_result
, et::matrix_result_tag
>::is_true
),
114 mvmul_expects_one_vector_and_one_matrix_arg_error
);
115 /* Note: parens are required here so that the preprocessor ignores the
119 /* Get result type: */
120 typedef typename
et::MatVecPromote
<
121 typename
left_traits::result_type
,
122 typename
right_traits::result_type
123 >::temporary_type result_type
;
125 /* Record size type: */
126 typedef typename
result_type::size_tag size_tag
;
128 /* Check the size: */
129 size_t N
= et::CheckedSize(x
, A
, size_tag());
131 /* Initialize the new vector: */
132 result_type y
; cml::et::detail::Resize(y
, N
);
134 /* Compute y = x*A: */
135 typedef typename
result_type::value_type sum_type
;
136 for(size_t i
= 0; i
< N
; ++i
) {
137 /* XXX This should be unrolled. */
138 sum_type
sum(x
[0]*A(0,i
));
139 for(size_t k
= 1; k
< x
.size(); ++k
) {
140 sum
+= (x
[k
]*A(k
,i
));
148 } // namespace detail
151 /** operator*() for a matrix and a vector. */
152 template<typename E1
, class AT1
, typename BO
, class L
,
153 typename E2
, class AT2
>
154 inline typename
et::MatVecPromote
<
155 matrix
<E1
,AT1
,BO
,L
>, vector
<E2
,AT2
>
157 operator*(const matrix
<E1
,AT1
,BO
,L
>& left
,
158 const vector
<E2
,AT2
>& right
)
160 return detail::mul(left
,right
,detail::mul_Ax());
163 /** operator*() for a matrix and a VectorXpr. */
164 template<typename E
, class AT
, class L
, typename BO
, typename XprT
>
165 inline typename
et::MatVecPromote
<
166 matrix
<E
,AT
,BO
,L
>, typename
XprT::result_type
168 operator*(const matrix
<E
,AT
,BO
,L
>& left
,
169 const et::VectorXpr
<XprT
>& right
)
171 /* Generate a temporary, and compute the right-hand expression: */
172 typename
et::VectorXpr
<XprT
>::temporary_type right_tmp
;
173 cml::et::detail::Resize(right_tmp
,right
.size());
176 return detail::mul(left
,right_tmp
,detail::mul_Ax());
179 /** operator*() for a MatrixXpr and a vector. */
180 template<typename XprT
, typename E
, class AT
>
181 inline typename
et::MatVecPromote
<
182 typename
XprT::result_type
, vector
<E
,AT
>
184 operator*(const et::MatrixXpr
<XprT
>& left
,
185 const vector
<E
,AT
>& right
)
187 /* Generate a temporary, and compute the left-hand expression: */
188 typename
et::MatrixXpr
<XprT
>::temporary_type left_tmp
;
189 cml::et::detail::Resize(left_tmp
,left
.rows(),left
.cols());
192 return detail::mul(left_tmp
,right
,detail::mul_Ax());
195 /** operator*() for a MatrixXpr and a VectorXpr. */
196 template<typename XprT1
, typename XprT2
>
197 inline typename
et::MatVecPromote
<
198 typename
XprT1::result_type
, typename
XprT2::result_type
200 operator*(const et::MatrixXpr
<XprT1
>& left
,
201 const et::VectorXpr
<XprT2
>& right
)
203 /* Generate a temporary, and compute the left-hand expression: */
204 typename
et::MatrixXpr
<XprT1
>::temporary_type left_tmp
;
205 cml::et::detail::Resize(left_tmp
,left
.rows(),left
.cols());
208 /* Generate a temporary, and compute the right-hand expression: */
209 typename
et::VectorXpr
<XprT2
>::temporary_type right_tmp
;
210 cml::et::detail::Resize(right_tmp
,right
.size());
213 return detail::mul(left_tmp
,right_tmp
,detail::mul_Ax());
216 /** operator*() for a vector and a matrix. */
217 template<typename E1
, class AT1
, typename E2
, class AT2
, typename BO
, class L
>
218 inline typename
et::MatVecPromote
<
219 vector
<E1
,AT1
>, matrix
<E2
,AT2
,BO
,L
>
221 operator*(const vector
<E1
,AT1
>& left
,
222 const matrix
<E2
,AT2
,BO
,L
>& right
)
224 return detail::mul(left
,right
,detail::mul_xA());
227 /** operator*() for a vector and a MatrixXpr. */
228 template<typename XprT
, typename E
, class AT
>
229 inline typename
et::MatVecPromote
<
230 typename
XprT::result_type
, vector
<E
,AT
>
232 operator*(const vector
<E
,AT
>& left
,
233 const et::MatrixXpr
<XprT
>& right
)
235 /* Generate a temporary, and compute the right-hand expression: */
236 typename
et::MatrixXpr
<XprT
>::temporary_type right_tmp
;
237 cml::et::detail::Resize(right_tmp
,right
.rows(),right
.cols());
240 return detail::mul(left
,right_tmp
,detail::mul_xA());
243 /** operator*() for a VectorXpr and a matrix. */
244 template<typename XprT
, typename E
, class AT
, typename BO
, class L
>
245 inline typename
et::MatVecPromote
<
246 typename
XprT::result_type
, matrix
<E
,AT
,BO
,L
>
248 operator*(const et::VectorXpr
<XprT
>& left
,
249 const matrix
<E
,AT
,BO
,L
>& right
)
251 /* Generate a temporary, and compute the left-hand expression: */
252 typename
et::VectorXpr
<XprT
>::temporary_type left_tmp
;
253 cml::et::detail::Resize(left_tmp
,left
.size());
256 return detail::mul(left_tmp
,right
,detail::mul_xA());
259 /** operator*() for a VectorXpr and a MatrixXpr. */
260 template<typename XprT1
, typename XprT2
>
261 inline typename
et::MatVecPromote
<
262 typename
XprT1::result_type
, typename
XprT2::result_type
264 operator*(const et::VectorXpr
<XprT1
>& left
,
265 const et::MatrixXpr
<XprT2
>& right
)
267 /* Generate a temporary, and compute the left-hand expression: */
268 typename
et::VectorXpr
<XprT1
>::temporary_type left_tmp
;
269 cml::et::detail::Resize(left_tmp
,left
.size());
272 /* Generate a temporary, and compute the right-hand expression: */
273 typename
et::MatrixXpr
<XprT2
>::temporary_type right_tmp
;
274 cml::et::detail::Resize(right_tmp
,right
.rows(),right
.cols());
277 return detail::mul(left_tmp
,right_tmp
,detail::mul_xA());
284 // -------------------------------------------------------------------------