Line data Source code
1 : #ifndef E4324F70_8D26_4E47_B84E_1A44466A0FBB
2 : #define E4324F70_8D26_4E47_B84E_1A44466A0FBB
3 :
4 : #include "math/linalg/covariance_matrix_full.h"
5 :
6 : #include "math/linalg/errors.h"
7 : #include "math/linalg/square_matrix.hpp" // IWYU pragma: keep
8 : #include "math/linalg/triangular_matrix.hpp" // IWYU pragma: keep
9 :
10 : namespace tracking
11 : {
12 : namespace math
13 : {
14 :
15 : template <typename ValueType_, sint32 Size_>
16 13 : auto CovarianceMatrixFull<ValueType_, Size_>::FromDiagonal(const DiagonalMatrix<ValueType_, Size_>& diag) -> CovarianceMatrixFull
17 : {
18 13 : assert(diag.isPositiveDefinite() && "Bad diagonal matrix not fullfilling the constraint isPositiveDefinite");
19 13 : return CovarianceMatrixFull{diag};
20 : }
21 :
22 : template <typename ValueType_, sint32 Size_>
23 37 : inline auto CovarianceMatrixFull<ValueType_, Size_>::inverse() const -> tl::expected<CovarianceMatrixFull, Errors>
24 : {
25 37 : const auto retVal = BaseSquareMatrix::decomposeLLT();
26 37 : if (retVal.has_value())
27 : {
28 32 : const auto& L = *retVal;
29 : // A * Ainv = eye(n,n) with A=L*L' from Cholesky decomposition
30 : // L*(L'*Ainv) = eye(n,n)
31 : // L*u = eye(n,n) -> solve for u using forward substitution on each column vector of eye(n,n)
32 32 : const auto u = L.solve(CovarianceMatrixFull::Identity());
33 : // L'*Ainv = u -> solve for Ainv using backward substitution
34 32 : math::SquareMatrix<ValueType_, Size_, true> cov{L.transpose().solve(u)};
35 32 : cov.symmetrize();
36 32 : return CovarianceMatrixFull{std::move(cov)};
37 32 : }
38 5 : return tl::unexpected<Errors>{retVal.error()};
39 37 : }
40 :
41 : template <typename ValueType_, sint32 Size_>
42 : template <bool IsRowMajor_>
43 60 : inline void CovarianceMatrixFull<ValueType_, Size_>::apaT(const tracking::math::SquareMatrix<ValueType_, Size_, IsRowMajor_>& A)
44 : {
45 60 : assert(this->isSymmetric() && "Covariance currently not symmetric");
46 : // calculate only the upper triangle part of P and fill lower triangle part
47 60 : BaseSquareMatrix cov{};
48 329 : for (sint32 i = 0; i < Size_; ++i)
49 : {
50 1048 : for (sint32 j = i; j < Size_; ++j)
51 : {
52 : ValueType_ element = static_cast<ValueType_>(0);
53 4699 : for (sint32 k = 0; k < Size_; ++k)
54 : {
55 24658 : for (sint32 l = 0; l < Size_; ++l)
56 : {
57 20738 : element += A.at_unsafe(i, k) * this->at_unsafe(k, l) * A.at_unsafe(j, l);
58 : }
59 : }
60 : // construct symmetric covariance matrix by filling both upper and lower triangle
61 779 : cov.at_unsafe(i, j) = element;
62 779 : cov.at_unsafe(j, i) = element;
63 : }
64 : }
65 60 : *this = CovarianceMatrixFull{std::move(cov)};
66 60 : }
67 :
68 : template <typename ValueType_, sint32 Size_>
69 : template <bool IsRowMajor_>
70 2 : inline auto CovarianceMatrixFull<ValueType_, Size_>::apaT(
71 : const tracking::math::SquareMatrix<ValueType_, Size_, IsRowMajor_>& A) const -> CovarianceMatrixFull
72 : {
73 3 : auto copy(*this);
74 2 : copy.apaT(A);
75 : return copy;
76 : }
77 :
78 : template <typename ValueType_, sint32 Size_>
79 2 : inline void CovarianceMatrixFull<ValueType_, Size_>::setVariance(const sint32 idx, const ValueType_ val)
80 : {
81 2 : constexpr auto zero = static_cast<ValueType_>(0.0);
82 8 : for (sint32 j = 0; j < Size_; ++j)
83 : {
84 6 : BaseSquareMatrix::at_unsafe(idx, j) = zero;
85 6 : BaseSquareMatrix::at_unsafe(j, idx) = zero;
86 : }
87 2 : BaseSquareMatrix::at_unsafe(idx, idx) = val;
88 2 : }
89 :
90 : } // namespace math
91 : } // namespace tracking
92 :
93 : #endif // E4324F70_8D26_4E47_B84E_1A44466A0FBB
|