Line data Source code
1 : #ifndef D68AD8A7_E600_4155_9580_2BA4AE63E9D4
2 : #define D68AD8A7_E600_4155_9580_2BA4AE63E9D4
3 :
4 : #include "math/linalg/covariance_matrix_factored.h"
5 :
6 : #include "math/linalg/covariance_matrix_full.hpp" // IWYU pragma: keep
7 : #include "math/linalg/diagonal_matrix.hpp" // IWYU pragma: keep
8 : #include "math/linalg/matrix.hpp" // IWYU pragma: keep
9 : #include "math/linalg/matrix_column_view.hpp" // IWYU pragma: keep
10 : #include "math/linalg/matrix_row_view.hpp" // IWYU pragma: keep
11 : #include "math/linalg/modified_gram_schmidt.hpp" // IWYU pragma: keep
12 : #include "math/linalg/rank1_update.hpp" // IWYU pragma: keep
13 : #include "math/linalg/square_matrix.hpp" // IWYU pragma: keep
14 : #include "math/linalg/triangular_matrix.hpp" // IWYU pragma: keep
15 :
16 : namespace tracking
17 : {
18 : namespace math
19 : {
20 :
21 : template <typename ValueType_, sint32 Size_>
22 114 : CovarianceMatrixFactored<ValueType_, Size_>::CovarianceMatrixFactored(const TriangularMatrix<ValueType_, Size_, false, true>& u,
23 : const DiagonalMatrix<ValueType_, Size_>& d)
24 114 : : _u{u}
25 114 : , _d{d}
26 : {
27 114 : assert(_u.isUnitUpperTriangular() && "Bad triangular matrix not fullfilling the constraint IsUnitUpperTriangular");
28 114 : assert(_d.isPositiveDefinite() && "Bad diagonal matrix not fullfilling the constraint isPositiveDefinite");
29 114 : }
30 :
31 : template <typename ValueType_, sint32 Size_>
32 30 : auto CovarianceMatrixFactored<ValueType_, Size_>::Identity() -> CovarianceMatrixFactored
33 : {
34 46 : CovarianceMatrixFactored cov{TriangularMatrix<ValueType_, Size_, false, true>::Identity(),
35 : DiagonalMatrix<ValueType_, Size_>::Identity()};
36 30 : return cov;
37 : }
38 :
39 : template <typename ValueType_, sint32 Size_>
40 13 : auto CovarianceMatrixFactored<ValueType_, Size_>::FromDiagonal(const DiagonalMatrix<ValueType_, Size_>& diag)
41 : -> CovarianceMatrixFactored
42 : {
43 13 : assert(diag.isPositiveDefinite() && "Bad diagonal matrix not fullfilling the constraint isPositiveDefinite");
44 13 : return CovarianceMatrixFactored{TriangularMatrix<ValueType_, Size_, false, true>::Identity(), diag};
45 : }
46 :
47 : template <typename ValueType_, sint32 Size_>
48 19 : auto CovarianceMatrixFactored<ValueType_, Size_>::FromList(const std::initializer_list<std::initializer_list<ValueType_>>& u,
49 : const std::initializer_list<ValueType_>& d) -> CovarianceMatrixFactored
50 : {
51 19 : auto&& u_ = TriangularMatrix<ValueType_, Size_, false, true>::FromList(u);
52 19 : auto&& d_ = DiagonalMatrix<ValueType_, Size_>::FromList(d);
53 19 : return CovarianceMatrixFactored{std::move(u_), std::move(d_)};
54 19 : }
55 :
56 :
57 : template <typename ValueType_, sint32 Size_>
58 3 : void CovarianceMatrixFactored<ValueType_, Size_>::setIdentity()
59 : {
60 3 : _u.setIdentity();
61 3 : _d.setIdentity();
62 3 : }
63 :
64 : template <typename ValueType_, sint32 Size_>
65 87 : inline auto CovarianceMatrixFactored<ValueType_, Size_>::operator()(sint32 row,
66 : sint32 col) const -> tl::expected<value_type, Errors>
67 : {
68 81 : if (!(row >= 0 && row < dim))
69 : {
70 4 : return tl::unexpected<Errors>{Errors::invalid_access_row};
71 : }
72 81 : if (!(col >= 0 && col < dim))
73 : {
74 2 : return tl::unexpected<Errors>{Errors::invalid_access_col};
75 : }
76 :
77 81 : return at_unsafe(row, col);
78 : }
79 :
80 : template <typename ValueType_, sint32 Size_>
81 498 : inline auto CovarianceMatrixFactored<ValueType_, Size_>::at_unsafe(sint32 row, sint32 col) const -> ValueType_
82 : {
83 : // cov(row,col) == cov(col,row), so we swap row and col if row > col
84 498 : if (row > col)
85 : {
86 195 : std::swap(row, col);
87 : }
88 :
89 498 : ValueType_ result{};
90 : // calc row, col element of _u * _d * _u.transpose()
91 : // calc relevant elements of u*d to be rhs multiplied with uT(col::, col)==u(col, col::)
92 498 : Vector<ValueType_, Size_> ud{};
93 1642 : for (auto i = col; i < Size_; ++i)
94 : {
95 1144 : ud.at_unsafe(i) = _u.at_unsafe(row, i) * _d.at_unsafe(i);
96 : }
97 498 : MatrixColumnView<ValueType_, Size_, 1, true> udView{ud, 0, col, Size_ - 1};
98 498 : MatrixRowView<ValueType_, Size_, Size_, true> uTView{_u, col, col, Size_ - 1};
99 498 : result = uTView * udView; // calc the scalar product of ud*uT on relevant elements
100 498 : return result;
101 498 : }
102 :
103 : template <typename ValueType_, sint32 Size_>
104 279 : inline auto CovarianceMatrixFactored<ValueType_, Size_>::operator()() const -> compose_type
105 : {
106 279 : math::SquareMatrix cov{_u * _d * _u.transpose()};
107 279 : cov.symmetrize();
108 279 : return compose_type{std::move(cov)};
109 279 : }
110 :
111 : template <typename ValueType_, sint32 Size_>
112 0 : inline auto CovarianceMatrixFactored<ValueType_, Size_>::trace() const -> ValueType_
113 : {
114 0 : auto res = static_cast<ValueType_>(0);
115 0 : for (sint32 i = 0; i < Size_; ++i)
116 : {
117 0 : res += at_unsafe(i, i);
118 : }
119 0 : return res;
120 : }
121 :
122 : template <typename ValueType_, sint32 Size_>
123 12 : inline auto CovarianceMatrixFactored<ValueType_, Size_>::inverse() const -> tl::expected<CovarianceMatrixFactored, Errors>
124 : {
125 : // we use the transpose of the input matrix to get a column major matrix
126 12 : const auto composed = tracking::math::SquareMatrix<ValueType_, Size_, false>{this->operator()().transpose()};
127 : // we decompose the input matrix into LDLt form, with L being a column major lower triangular matrix
128 12 : const auto ldlt = composed.decomposeLDLT();
129 12 : if (ldlt.has_value())
130 : {
131 8 : const auto [l, d] = ldlt.value();
132 : // we calc the inverse of L and D and map the inv(L).transpose() to U being again a row major upper triangular matrix
133 : // the resulting UDUt matrix describes the covariance matrix in information form, i.e. the inverse covariance matrix
134 8 : return CovarianceMatrixFactored{std::move(l.inverse().transpose()), std::move(d.inverse())};
135 8 : }
136 : else
137 : {
138 4 : return tl::unexpected<Errors>{ldlt.error()};
139 : }
140 12 : }
141 :
142 : template <typename ValueType_, sint32 Size_>
143 8 : inline auto CovarianceMatrixFactored<ValueType_, Size_>::composed_inverse() const -> tl::expected<compose_type, Errors>
144 : {
145 8 : auto inv_u = _u.inverse();
146 8 : auto inv_d = _d.inverse();
147 8 : math::SquareMatrix cov{inv_u.transpose() * inv_d * inv_u};
148 8 : cov.symmetrize();
149 8 : return compose_type{std::move(cov)};
150 8 : }
151 :
152 : template <typename ValueType_, sint32 Size_>
153 : template <bool IsRowMajor_>
154 41 : inline void CovarianceMatrixFactored<ValueType_, Size_>::apaT(const SquareMatrix<ValueType_, Size_, IsRowMajor_>& A)
155 : {
156 41 : math::ModifiedGramSchmidt<ValueType_, Size_>::run(_u, _d, A);
157 41 : assert(_u.isUnitUpperTriangular() && "Bad triangular matrix not fullfilling the constraint IsUnitUpperTriangular");
158 41 : assert(_d.isPositiveDefinite() && "Bad diagonal matrix not fullfilling the constraint isPositiveDefinite");
159 41 : }
160 :
161 : template <typename ValueType_, sint32 Size_>
162 : template <bool IsRowMajor_>
163 2 : inline auto CovarianceMatrixFactored<ValueType_, Size_>::apaT(const SquareMatrix<ValueType_, Size_, IsRowMajor_>& A) const
164 : -> CovarianceMatrixFactored
165 : {
166 2 : CovarianceMatrixFactored cov{*this};
167 2 : cov.apaT(A);
168 2 : return cov;
169 : }
170 :
171 : template <typename ValueType_, sint32 Size_>
172 : template <sint32 SizeQ_>
173 25 : inline void CovarianceMatrixFactored<ValueType_, Size_>::thornton(const SquareMatrix<ValueType_, Size_, true>& Phi,
174 : const Matrix<ValueType_, Size_, SizeQ_, true>& G,
175 : const DiagonalMatrix<ValueType_, SizeQ_>& Q)
176 : {
177 25 : math::ModifiedGramSchmidt<ValueType_, Size_>::run(_u, _d, Phi, G, Q);
178 25 : assert(_u.isUnitUpperTriangular() && "Bad triangular matrix not fullfilling the constraint IsUnitUpperTriangular");
179 25 : assert(_d.isPositiveDefinite() && "Bad diagonal matrix not fullfilling the constraint isPositiveDefinite");
180 25 : }
181 :
182 : template <typename ValueType_, sint32 Size_>
183 73 : inline void CovarianceMatrixFactored<ValueType_, Size_>::rank1Update(const ValueType_ c, const Vector<ValueType_, Size_>& x)
184 : {
185 73 : math::Rank1Update<ValueType_, Size_, true>::run(_u, _d, c, x);
186 :
187 73 : assert(_u.isUnitUpperTriangular() && "Bad triangular matrix not fullfilling the constraint IsUnitUpperTriangular");
188 73 : assert(_d.isPositiveDefinite() && "Bad diagonal matrix not fullfilling the constraint isPositiveDefinite");
189 73 : }
190 :
191 : template <typename ValueType_, sint32 Size_>
192 3 : inline void CovarianceMatrixFactored<ValueType_, Size_>::setVariance(const sint32 idx, const ValueType_ val)
193 : {
194 3 : assert(val > static_cast<ValueType_>(0.0) && "Expected variance value greater than 0.0");
195 3 : auto A = SquareMatrix<ValueType_, Size_, true>::Identity();
196 3 : A.at_unsafe(idx, idx) = static_cast<ValueType_>(0.0);
197 3 : apaT(A);
198 3 : _d.at_unsafe(idx) = val;
199 3 : assert(_u.isUnitUpperTriangular() && "Bad triangular matrix not fullfilling the constraint IsUnitUpperTriangular");
200 3 : assert(_d.isPositiveDefinite() && "Bad diagonal matrix not fullfilling the constraint isPositiveDefinite");
201 3 : }
202 :
203 : template <typename ValueType_, sint32 Size_>
204 : template <sint32 SrcSize_, sint32 SrcCount_>
205 2 : inline void CovarianceMatrixFactored<ValueType_, Size_>::fill(const CovarianceMatrixFactored<ValueType_, SrcSize_>& other)
206 : {
207 2 : _u.template setBlock<SrcSize_, SrcCount_, 0, 0, 0, 0>(other._u);
208 2 : _d.template setBlock<SrcSize_, SrcCount_, 0, 0>(other._d);
209 2 : assert(_u.isUnitUpperTriangular() && "Bad triangular matrix not fullfilling the constraint IsUnitUpperTriangular");
210 2 : assert(_d.isPositiveDefinite() && "Bad diagonal matrix not fullfilling the constraint isPositiveDefinite");
211 2 : }
212 :
213 : template <typename ValueType_, sint32 Size_>
214 4 : inline void CovarianceMatrixFactored<ValueType_, Size_>::D(const sint32 idx, const ValueType_ val)
215 : {
216 4 : assert(val > static_cast<ValueType_>(0.0) && "Expected variance value greater than 0.0");
217 4 : _d.at_unsafe(idx) = val;
218 4 : }
219 : } // namespace math
220 : } // namespace tracking
221 :
222 : #endif // D68AD8A7_E600_4155_9580_2BA4AE63E9D4
|