Line data Source code
1 : #ifndef ADB29DD2_C5B0_4217_8728_B612EFF95F07
2 : #define ADB29DD2_C5B0_4217_8728_B612EFF95F07
3 :
#include "math/linalg/square_matrix.h"

#include "math/linalg/diagonal_matrix.hpp" // IWYU pragma: keep
#include "math/linalg/matrix_column_view.hpp" // IWYU pragma: keep
#include "math/linalg/matrix_row_view.hpp" // IWYU pragma: keep
#include "math/linalg/square_matrix_decompositions.hpp" // IWYU pragma: keep
#include "math/linalg/triangular_matrix.hpp" // IWYU pragma: keep
#include "math/linalg/vector.hpp" // IWYU pragma: keep

#include <cmath>   // std::abs
#include <limits>  // std::numeric_limits
#include <utility> // std::swap
13 :
14 : namespace tracking
15 : {
16 : namespace math
17 : {
18 :
19 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
20 592 : SquareMatrix<ValueType_, Size_, IsRowMajor_>::SquareMatrix(const DiagonalMatrix<ValueType_, Size_>& other)
21 : {
22 3192 : for (auto idx = 0; idx < Size_; ++idx)
23 : {
24 2600 : this->at_unsafe(idx, idx) = other.at_unsafe(idx);
25 : }
26 592 : }
27 :
28 :
29 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
30 204 : inline void SquareMatrix<ValueType_, Size_, IsRowMajor_>::setIdentity()
31 : {
32 204 : *this = SquareMatrix{DiagonalMatrix<ValueType_, Size_>::Identity()};
33 : }
34 :
35 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
36 375 : inline auto SquareMatrix<ValueType_, Size_, IsRowMajor_>::Identity() -> SquareMatrix
37 : {
38 375 : return SquareMatrix{DiagonalMatrix<ValueType_, Size_>::Identity()};
39 : }
40 :
41 :
/// @brief Creates a square matrix from a nested initializer list.
///
/// Parsing/validation of @p list is delegated entirely to BaseMatrix::FromList; the result is then
/// wrapped into a SquareMatrix.
/// NOTE(review): the row-vs-column interpretation of the outer list and the handling of size mismatches
/// are defined by BaseMatrix::FromList, which is not visible here — confirm there.
/// @param list nested list of values.
/// @return the resulting square matrix.
template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
inline auto SquareMatrix<ValueType_, Size_, IsRowMajor_>::FromList(
    const std::initializer_list<std::initializer_list<ValueType_>>& list) -> SquareMatrix
{
  return SquareMatrix{BaseMatrix::FromList(list)};
}
48 :
49 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
50 6 : inline auto SquareMatrix<ValueType_, Size_, IsRowMajor_>::trace() const -> ValueType_
51 : {
52 6 : ValueType_ result{0};
53 24 : for (auto i = 0; i < Size_; ++i)
54 : {
55 18 : result += this->at_unsafe(i, i);
56 : }
57 6 : return result;
58 : }
59 :
60 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
61 16 : inline auto SquareMatrix<ValueType_, Size_, IsRowMajor_>::determinant() const -> ValueType_
62 : {
63 : // Create a copy of the matrix for LU decomposition
64 16 : SquareMatrix<ValueType_, Size_, IsRowMajor_> luMatrix{*this};
65 16 : sint32 permutationCount{0};
66 :
67 : // Perform LU decomposition with partial pivoting
68 80 : for (auto k = 0; k < Size_; ++k)
69 : {
70 : // Partial pivoting: find the row with maximum element in current column
71 66 : auto maxRow = k;
72 66 : ValueType_ maxVal = std::abs(luMatrix.at_unsafe(k, k));
73 :
74 187 : for (auto i = k + 1; i < Size_; ++i)
75 : {
76 121 : const auto absVal = std::abs(luMatrix.at_unsafe(i, k));
77 121 : if (absVal > maxVal)
78 : {
79 7 : maxVal = absVal;
80 7 : maxRow = i;
81 : }
82 : }
83 :
84 : // Swap rows if necessary
85 66 : if (maxRow != k)
86 : {
87 26 : for (auto j = 0; j < Size_; ++j)
88 : {
89 20 : std::swap(luMatrix.at_unsafe(k, j), luMatrix.at_unsafe(maxRow, j));
90 : }
91 6 : permutationCount++;
92 : }
93 :
94 : // Check for singular matrix
95 66 : if (std::abs(luMatrix.at_unsafe(k, k)) < std::numeric_limits<ValueType_>::epsilon())
96 : {
97 : return static_cast<ValueType_>(0);
98 : }
99 :
100 : // Perform elimination for rows below the current one
101 184 : for (auto i = k + 1; i < Size_; ++i)
102 : {
103 120 : const auto factor = luMatrix.at_unsafe(i, k) / luMatrix.at_unsafe(k, k);
104 :
105 610 : for (auto j = k; j < Size_; ++j)
106 : {
107 490 : luMatrix.at_unsafe(i, j) -= factor * luMatrix.at_unsafe(k, j);
108 : }
109 : }
110 : }
111 :
112 : // Calculate determinant as product of diagonal elements, multiplied by (-1)^permutationCount
113 : ValueType_ det{1};
114 75 : for (auto i = 0; i < Size_; ++i)
115 : {
116 61 : det *= luMatrix.at_unsafe(i, i);
117 : }
118 :
119 : // Apply sign based on number of row permutations
120 14 : if (permutationCount % 2 != 0)
121 : {
122 4 : det = -det;
123 : }
124 :
125 : return det;
126 16 : }
127 :
128 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
129 : template <sint32 Cols_, bool IsRowMajor2_>
130 122 : inline auto SquareMatrix<ValueType_, Size_, IsRowMajor_>::qrSolve(const Matrix<ValueType_, Size_, Cols_, IsRowMajor2_>& b) const
131 : -> Matrix<ValueType_, Size_, Cols_, !IsRowMajor_>
132 : {
133 122 : const auto [Q, R] = householderQR();
134 122 : return R.solve(Q.transpose() * b);
135 122 : }
136 :
137 :
138 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
139 42 : inline auto SquareMatrix<ValueType_, Size_, IsRowMajor_>::inverse() const -> SquareMatrix<ValueType_, Size_, !IsRowMajor_>
140 : {
141 : // TODO(matthias): optimization - implement SPD detection and prefer Cholesky over QR for small sizes (Size_ <= 10)
142 : // saves time without complexity overhead; add in-place variant to avoid copies and improve cache efficiency
143 42 : return qrSolve(SquareMatrix::Identity());
144 : }
145 :
146 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
147 391 : inline void SquareMatrix<ValueType_, Size_, IsRowMajor_>::symmetrize()
148 : {
149 : // keep diagonal elements unchanged, average upper and lower triangle elements to enforce symmetry
150 2231 : for (sint32 row = 0; row < Size_; ++row)
151 : {
152 5530 : for (sint32 col = row + 1; col < Size_; ++col)
153 : {
154 3690 : const ValueType_ avg = (this->at_unsafe(row, col) + this->at_unsafe(col, row)) * static_cast<ValueType_>(0.5);
155 3690 : this->at_unsafe(row, col) = avg;
156 3690 : this->at_unsafe(col, row) = avg;
157 : }
158 : }
159 391 : }
160 :
161 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
162 820 : inline auto SquareMatrix<ValueType_, Size_, IsRowMajor_>::isSymmetric(ValueType_ tolerance) const -> bool
163 : {
164 : // check all off diagonal elements
165 : // TODO(matthias): optimization - use template-based loop unrolling for fixed small sizes (Size_ <= 10)
166 : // to reduce loop overhead; optimize cache access in matrix ops, prioritize simplicity since n is small
167 4382 : for (auto row = 0; row < Size_; ++row)
168 : {
169 10201 : for (auto col = row + 1; col < Size_; ++col)
170 : {
171 6639 : const auto absDiff = std::abs(this->at_unsafe(row, col) - this->at_unsafe(col, row));
172 6639 : if (absDiff > tolerance)
173 : {
174 : return false;
175 : }
176 : }
177 : }
178 : return true;
179 : }
180 :
181 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
182 51 : inline auto SquareMatrix<ValueType_, Size_, IsRowMajor_>::isPositiveDefinite() const -> bool
183 : {
184 : // Try Cholesky decomposition - if it succeeds, matrix is positive definite
185 51 : const auto choleskyResult = this->decomposeLLT();
186 51 : const bool result = choleskyResult.has_value();
187 51 : return result;
188 51 : }
189 :
190 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
191 51 : inline auto SquareMatrix<ValueType_, Size_, IsRowMajor_>::isPositiveSemiDefinite() const -> bool
192 : {
193 : // we can only use Cholesky decomposition which has more strict checks
194 51 : return isPositiveDefinite();
195 : }
196 :
197 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
198 115 : inline auto SquareMatrix<ValueType_, Size_, IsRowMajor_>::hasStrictlyPositiveDiagonalElems() const -> bool
199 : {
200 115 : sint32 j{0};
201 : // check all diagonal elements
202 548 : while ((j < Size_) && (this->at_unsafe(j, j) > static_cast<ValueType_>(0)))
203 : {
204 433 : ++j;
205 : }
206 115 : return j == Size_;
207 : }
208 :
209 : // ============================================================================
210 : // Matrix Property Check Functions
211 : // ============================================================================
212 :
213 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
214 7 : inline auto SquareMatrix<ValueType_, Size_, IsRowMajor_>::isOrthogonal(ValueType_ tolerance) const -> bool
215 : {
216 7 : const auto QtQ = this->transpose() * (*this);
217 7 : const auto identity = SquareMatrix<ValueType_, Size_, !IsRowMajor_>::Identity();
218 :
219 20 : for (auto i = 0; i < Size_; ++i)
220 : {
221 50 : for (auto j = 0; j < Size_; ++j)
222 : {
223 37 : const auto diff = std::abs(QtQ.at_unsafe(i, j) - identity.at_unsafe(i, j));
224 37 : if (diff > tolerance)
225 : {
226 : return false;
227 : }
228 : }
229 : }
230 : return true;
231 7 : }
232 :
233 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
234 7 : inline auto SquareMatrix<ValueType_, Size_, IsRowMajor_>::isUpperTriangular(ValueType_ tolerance) const -> bool
235 : {
236 24 : for (auto i = 0; i < Size_; ++i)
237 : {
238 34 : for (auto j = 0; j < i; ++j)
239 : {
240 17 : if (std::abs(this->at_unsafe(i, j)) > tolerance)
241 : {
242 : return false;
243 : }
244 : }
245 : }
246 : return true;
247 : }
248 :
249 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
250 3 : inline auto SquareMatrix<ValueType_, Size_, IsRowMajor_>::isLowerTriangular(ValueType_ tolerance) const -> bool
251 : {
252 9 : for (auto i = 0; i < Size_; ++i)
253 : {
254 13 : for (auto j = i + 1; j < Size_; ++j)
255 : {
256 7 : if (std::abs(this->at_unsafe(i, j)) > tolerance)
257 : {
258 : return false;
259 : }
260 : }
261 : }
262 : return true;
263 : }
264 :
265 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
266 6 : inline auto SquareMatrix<ValueType_, Size_, IsRowMajor_>::hasUnitDiagonal(ValueType_ tolerance) const -> bool
267 : {
268 21 : for (auto i = 0; i < Size_; ++i)
269 : {
270 16 : if (std::abs(this->at_unsafe(i, i) - static_cast<ValueType_>(1.0)) > tolerance)
271 : {
272 : return false;
273 : }
274 : }
275 : return true;
276 : }
277 :
278 : } // namespace math
279 : } // namespace tracking
280 :
281 : #endif // ADB29DD2_C5B0_4217_8728_B612EFF95F07
|