11 #ifndef EIGEN_GENERAL_PRODUCT_H
12 #define EIGEN_GENERAL_PRODUCT_H
26 #ifndef EIGEN_GEMM_TO_COEFFBASED_THRESHOLD
28 #define EIGEN_GEMM_TO_COEFFBASED_THRESHOLD 20
// NOTE(review): fragment of internal::product_type — the enclosing struct header
// and several interior lines are not visible in this chunk; the fused leading
// numbers (38, 39, ...) are extraction artifacts, not code.
// Heuristic classifying one dimension of the product as "large": either the
// max size is unbounded (Dynamic), or the compile-time size already exceeds
// the cache-friendly threshold, or the size is Dynamic but bounded above the
// threshold. Guarded out in GPU device compilation.
38 #ifndef EIGEN_GPU_COMPILE_PHASE
39 is_large = MaxSize ==
Dynamic ||
40 Size >= EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD ||
41 (Size==
Dynamic && MaxSize>=EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD),
45 value = is_large ? Large
// Strip references/const to get the bare lhs/rhs expression types.
53 typedef typename remove_all<Lhs>::type _Lhs;
54 typedef typename remove_all<Rhs>::type _Rhs;
// Final product kind comes from a nested selector (definition not visible here).
78 value = selector::ret,
// Optional compile-time debugging: dump every classification input/output
// when EIGEN_DEBUG_PRODUCT is defined.
81 #ifdef EIGEN_DEBUG_PRODUCT
84 EIGEN_DEBUG_VAR(Rows);
85 EIGEN_DEBUG_VAR(Cols);
86 EIGEN_DEBUG_VAR(Depth);
87 EIGEN_DEBUG_VAR(rows_select);
88 EIGEN_DEBUG_VAR(cols_select);
89 EIGEN_DEBUG_VAR(depth_select);
90 EIGEN_DEBUG_VAR(value);
/// Forward declaration of the dense matrix*vector (GEMV) dispatcher.
///
/// Specializations — selected by the product Side (OnTheLeft/OnTheRight),
/// the StorageOrder of the matrix operand, and whether the scalar types are
/// BlasCompatible (i.e. the optimized gemv kernel can be called) — implement
/// dest += alpha * lhs * rhs.
///
/// NOTE(review): the original extraction had the file's line numbers fused
/// into the tokens ("154 template<int S ide, ..."); reconstructed here as a
/// clean declaration.
template<int Side, int StorageOrder, bool BlasCompatible>
struct gemv_dense_selector;
// NOTE(review): fragments of two gemv_static_vector_if specializations — the
// struct headers are not visible in this chunk; leading numbers are artifacts.
// Case 1 (presumably the "no static buffer needed" specialization): data()
// must never be called; assert in debug builds and return a null buffer.
163 template<
typename Scalar,
int Size,
int MaxSize>
166 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Scalar* data() { eigen_internal_assert(
false &&
"should never be called");
return 0; }
// Case 2 (presumably MaxSize==0): there is no storage at all, so data()
// legitimately returns a null pointer without asserting.
169 template<
typename Scalar,
int Size>
172 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Scalar* data() {
return 0; }
// NOTE(review): fragment of the gemv_static_vector_if specialization that
// actually provides a static buffer; struct header and several lines missing.
175 template<
typename Scalar,
int Size,
int MaxSize>
// With compile-time alignment available, let plain_array align the buffer
// itself (min of AlignedMax and the packet size requirement).
182 #if EIGEN_MAX_STATIC_ALIGN_BYTES!=0
183 internal::plain_array<Scalar,EIGEN_SIZE_MIN_PREFER_FIXED(Size,MaxSize),0,EIGEN_PLAIN_ENUM_MIN(AlignedMax,PacketSize)> m_data;
184 EIGEN_STRONG_INLINE Scalar* data() {
return m_data.array; }
// Without static alignment: over-allocate by EIGEN_MAX_ALIGN_BYTES when
// ForceAlignment is set, then round the pointer up to the next aligned
// address at runtime (mask off the low bits and bump past the raw start).
188 internal::plain_array<Scalar,EIGEN_SIZE_MIN_PREFER_FIXED(Size,MaxSize)+(ForceAlignment?EIGEN_MAX_ALIGN_BYTES:0),0> m_data;
189 EIGEN_STRONG_INLINE Scalar* data() {
190 return ForceAlignment
191 ?
reinterpret_cast<Scalar*
>((internal::UIntPtr(m_data.array) & ~(std::size_t(EIGEN_MAX_ALIGN_BYTES-1))) + EIGEN_MAX_ALIGN_BYTES)
// NOTE(review): fragment — presumably the OnTheLeft specialization of
// gemv_dense_selector (struct header not visible). It reduces vector*matrix
// to the OnTheRight case by transposing everything:
// dest^T += alpha * rhs^T * lhs^T.
198 template<
int StorageOrder,
bool BlasCompatible>
201 template<
typename Lhs,
typename Rhs,
typename Dest>
202 static void run(
const Lhs &lhs,
const Rhs &rhs, Dest& dest,
const typename Dest::Scalar& alpha)
// Delegate to the transposed selector; destT is defined on a line not
// visible in this chunk.
207 ::run(rhs.transpose(), lhs.transpose(), destT, alpha);
// NOTE(review): fragment of gemv_dense_selector<OnTheRight,ColMajor,true>::run
// — the BLAS-compatible column-major gemv path. Many interior lines (enum
// definitions, mapper typedefs, the tail of both kernel calls) are missing
// from this chunk; the fused leading numbers are extraction artifacts.
213 template<
typename Lhs,
typename Rhs,
typename Dest>
214 static inline void run(
const Lhs &lhs,
const Rhs &rhs, Dest& dest,
const typename Dest::Scalar& alpha)
216 typedef typename Lhs::Scalar LhsScalar;
217 typedef typename Rhs::Scalar RhsScalar;
218 typedef typename Dest::Scalar ResScalar;
219 typedef typename Dest::RealScalar RealScalar;
// Blas-traits peel conjugate/scalar-multiple wrappers off the operands so the
// kernel sees the raw underlying matrices.
222 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
224 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
228 ActualLhsType actualLhs = LhsBlasTraits::extract(lhs);
229 ActualRhsType actualRhs = RhsBlasTraits::extract(rhs);
// Fold the extracted scalar factors of both operands into alpha.
231 ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
232 * RhsBlasTraits::extractScalarFactor(rhs);
// The kernel can write straight into dest only when dest has unit inner
// stride and alpha is representable (no complex-result-from-real-alpha issue).
240 EvalToDestAtCompileTime = (ActualDest::InnerStrideAtCompileTime==1),
242 MightCannotUseDest = ((!EvalToDestAtCompileTime) || ComplexByReal) && (ActualDest::MaxSizeAtCompileTime!=0)
// Fast path: write directly into dest.
249 if(!MightCannotUseDest)
254 <
Index,LhsScalar,LhsMapper,
ColMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsMapper,RhsBlasTraits::NeedToConjugate>::run(
255 actualLhs.rows(), actualLhs.cols(),
256 LhsMapper(actualLhs.data(), actualLhs.outerStride()),
257 RhsMapper(actualRhs.data(), actualRhs.innerStride()),
// Slow path: decide at runtime whether a temporary destination is needed.
265 const bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0));
266 const bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
// Stack-allocated (or heap-fallback) temporary; aliases dest when usable.
268 ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
269 evalToDest ? dest.data() : static_dest.data());
273 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
274 Index size = dest.size();
275 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
// If alpha cannot be passed to the kernel, run the kernel with alpha=1 on a
// zeroed temporary and apply the true alpha afterwards.
277 if(!alphaIsCompatible)
279 MappedDest(actualDestPtr, dest.size()).setZero();
280 compatibleAlpha = RhsScalar(1);
283 MappedDest(actualDestPtr, dest.size()) = dest;
287 <
Index,LhsScalar,LhsMapper,
ColMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsMapper,RhsBlasTraits::NeedToConjugate>::run(
288 actualLhs.rows(), actualLhs.cols(),
289 LhsMapper(actualLhs.data(), actualLhs.outerStride()),
290 RhsMapper(actualRhs.data(), actualRhs.innerStride()),
// Copy the temporary back into dest, applying alpha if it was deferred.
296 if(!alphaIsCompatible)
297 dest.matrix() += actualAlpha * MappedDest(actualDestPtr, dest.size());
299 dest = MappedDest(actualDestPtr, dest.size());
// NOTE(review): fragment of gemv_dense_selector<OnTheRight,RowMajor,true>::run
// — BLAS-compatible row-major gemv. Interior lines (traits typedefs, operand
// extraction, enum, kernel-call tail) are missing from this chunk.
307 template<
typename Lhs,
typename Rhs,
typename Dest>
308 static void run(
const Lhs &lhs,
const Rhs &rhs, Dest& dest,
const typename Dest::Scalar& alpha)
310 typedef typename Lhs::Scalar LhsScalar;
311 typedef typename Rhs::Scalar RhsScalar;
312 typedef typename Dest::Scalar ResScalar;
315 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
317 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
318 typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
// Fold both operands' extracted scalar factors into alpha.
323 ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
324 * RhsBlasTraits::extractScalarFactor(rhs);
// The rhs vector can be used in place only if it is unit-inner-stride
// (the kernel is called with an rhs stride of 1 below).
329 DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1 || ActualRhsTypeCleaned::MaxSizeAtCompileTime==0
// Otherwise copy rhs into a stack temporary; const_cast is only for the
// aliasing (directly-used) case, which is never written through.
334 ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
335 DirectlyUseRhs ?
const_cast<RhsScalar*
>(actualRhs.data()) : static_rhs.data());
339 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
340 Index size = actualRhs.size();
341 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
// Row-major kernel writes directly into dest (unit rhs stride, dest column
// inner stride passed through).
349 <
Index,LhsScalar,LhsMapper,
RowMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsMapper,RhsBlasTraits::NeedToConjugate>::run(
350 actualLhs.rows(), actualLhs.cols(),
351 LhsMapper(actualLhs.data(), actualLhs.outerStride()),
352 RhsMapper(actualRhsPtr, 1),
353 dest.data(), dest.col(0).innerStride(),
// NOTE(review): fragment — presumably the non-BLAS ColMajor specialization.
// The definition of actual_rhs (original lines ~362-365) is not visible here.
// Coefficient-based fallback: accumulate dest as a sum of scaled lhs columns
// (an axpy per rhs coefficient), which matches column-major lhs storage.
360 template<
typename Lhs,
typename Rhs,
typename Dest>
361 static void run(
const Lhs &lhs,
const Rhs &rhs, Dest& dest,
const typename Dest::Scalar& alpha)
366 const Index size = rhs.rows();
367 for(
Index k=0; k<size; ++k)
368 dest += (alpha*actual_rhs.coeff(k)) * lhs.col(k);
// NOTE(review): fragment — presumably the non-BLAS RowMajor specialization.
// The definition of actual_rhs (original lines ~376-378) is not visible here.
// Coefficient-based fallback: each dest entry is a dot product of an lhs row
// with rhs (computed via cwiseProduct + sum), matching row-major lhs storage.
374 template<
typename Lhs,
typename Rhs,
typename Dest>
375 static void run(
const Lhs &lhs,
const Rhs &rhs, Dest& dest,
const typename Dest::Scalar& alpha)
379 const Index rows = dest.rows();
380 for(
Index i=0; i<rows; ++i)
381 dest.coeffRef(i) += alpha * (lhs.row(i).cwiseProduct(actual_rhs.transpose())).sum();
// NOTE(review): fragment — presumably MatrixBase::operator* (the signature and
// return statement are missing from this chunk).
// Compile-time sanity checks for a matrix product: the inner dimensions must
// match (or be Dynamic), with tailored error messages distinguishing
// "you probably wanted dot()/cwiseProduct()" from a plain size mismatch.
397 template<
typename Derived>
398 template<
typename OtherDerived>
399 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
408 ProductIsValid = Derived::ColsAtCompileTime==
Dynamic
409 || OtherDerived::RowsAtCompileTime==
Dynamic
410 || int(Derived::ColsAtCompileTime)==int(OtherDerived::RowsAtCompileTime),
411 AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime,
412 SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(Derived,OtherDerived)
// Two same-size vectors: almost certainly a dot or coefficient-wise product.
417 EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes),
418 INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
// Two same-size matrices: probably a coefficient-wise product was intended.
419 EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
420 INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
421 EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)
422 #ifdef EIGEN_DEBUG_PRODUCT
// NOTE(review): fragment — presumably MatrixBase::lazyProduct (signature and
// return statement missing from this chunk). Same compile-time dimension
// checks as operator* above, duplicated because this is a separate entry
// point for the non-evaluating (lazy) product.
440 template<
typename Derived>
441 template<
typename OtherDerived>
442 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
447 ProductIsValid = Derived::ColsAtCompileTime==
Dynamic
448 || OtherDerived::RowsAtCompileTime==
Dynamic
449 || int(Derived::ColsAtCompileTime)==int(OtherDerived::RowsAtCompileTime),
450 AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime,
451 SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(Derived,OtherDerived)
456 EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes),
457 INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
458 EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
459 INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
460 EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)
467 #endif // EIGEN_GENERAL_PRODUCT_H