Line data Source code
1 : #ifndef E4D3E13A_DB2E_427E_BA99_4F251275B082
2 : #define E4D3E13A_DB2E_427E_BA99_4F251275B082
3 :
4 : #include "math/linalg/triangular_matrix.h"
5 :
6 : #include "math/linalg/matrix.hpp" // IWYU pragma: keep
7 :
8 : namespace tracking
9 : {
10 : namespace math
11 : {
12 :
13 : // Forward declarations to prevent cyclic includes
14 : template <typename ValueType_, sint32 Size_>
15 : class DiagonalMatrix;
16 :
17 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
18 : class SquareMatrix;
19 :
20 : template <typename ValueType_, sint32 Size_, bool IsLower_, bool IsRowMajor_>
21 13 : inline TriangularMatrix<ValueType_, Size_, IsLower_, IsRowMajor_>::TriangularMatrix(const BaseSquareMatrix& other)
22 13 : : BaseSquareMatrix{}
23 : {
24 : // copy triangular elements from other
25 58 : for (sint32 row = 0; row < Size_; ++row)
26 : {
27 45 : this->at_unsafe(row, row) = other.at_unsafe(row, row);
28 111 : for (sint32 col = row + 1; col < Size_; ++col)
29 : {
30 66 : const sint32 rowIdx = IsLower_ ? col : row;
31 66 : const sint32 colIdx = IsLower_ ? row : col;
32 :
33 66 : this->at_unsafe(rowIdx, colIdx) = other.at_unsafe(rowIdx, colIdx);
34 : }
35 : }
36 13 : }
37 :
38 : template <typename ValueType_, sint32 Size_, bool IsLower_, bool IsRowMajor_>
39 126 : inline auto TriangularMatrix<ValueType_, Size_, IsLower_, IsRowMajor_>::FromList(
40 : const std::initializer_list<std::initializer_list<ValueType_>>& list) -> TriangularMatrix
41 : {
42 126 : return TriangularMatrix{BaseSquareMatrix::FromList(list)};
43 : }
44 :
45 : template <typename ValueType_, sint32 Size_, bool IsLower_, bool IsRowMajor_>
46 : template <sint32 SrcSize, sint32 SrcCount, sint32 SrcRowBeg, sint32 SrcColBeg, sint32 DstRowBeg, sint32 DstColBeg>
47 8 : inline void TriangularMatrix<ValueType_, Size_, IsLower_, IsRowMajor_>::setBlock(
48 : const TriangularMatrix<ValueType_, SrcSize, IsLower_, IsRowMajor_>& block)
49 : {
50 : static_assert(SrcCount > 1, "use scalar access operator for block copy size == 1");
51 : static_assert(SrcRowBeg + SrcCount <= SrcSize, "copy to many rows from src");
52 : static_assert(SrcColBeg + SrcCount <= SrcSize, "copy to many cols from src");
53 :
54 : static_assert(DstRowBeg + SrcCount <= Size_, "copy to many rows to dst");
55 : static_assert(DstColBeg + SrcCount <= Size_, "copy to many cols to dst");
56 :
57 8 : constexpr bool checkSrcAccess = IsLower_ ? SrcRowBeg >= SrcColBeg : SrcRowBeg <= SrcColBeg;
58 : static_assert(checkSrcAccess, "accessing off diagonal part of src triangular matrix");
59 8 : constexpr bool checkDstAccess = IsLower_ ? DstRowBeg >= DstColBeg : DstRowBeg <= DstColBeg;
60 : static_assert(checkDstAccess, "accessing off diagonal part of dst triangular matrix");
61 :
62 28 : for (sint32 row = 0; row < SrcCount; ++row)
63 : {
64 58 : for (sint32 col = 0; col <= row; ++col)
65 : {
66 : if (IsLower_) // will be optimized out because isLower can be deduced at compile time
67 : {
68 9 : this->at_unsafe(DstRowBeg + row, DstColBeg + col) = block.at_unsafe(SrcRowBeg + row, SrcColBeg + col);
69 : }
70 : else
71 : {
72 29 : this->at_unsafe(DstRowBeg + col, DstColBeg + row) = block.at_unsafe(SrcRowBeg + col, SrcColBeg + row);
73 : }
74 : }
75 : }
76 8 : }
77 :
78 : template <typename ValueType_, sint32 Size_, bool IsLower_, bool IsRowMajor_>
79 : template <sint32 Cols_, bool IsRowMajor2_>
80 325 : inline auto TriangularMatrix<ValueType_, Size_, IsLower_, IsRowMajor_>::operator*(
81 : const Matrix<ValueType_, Size_, Cols_, IsRowMajor2_>& mat) const -> Matrix<ValueType_, Size_, Cols_, IsRowMajor2_>
82 : {
83 325 : Matrix<ValueType_, Size_, Cols_, IsRowMajor2_> result{};
84 : if (IsLower_)
85 : {
86 91 : for (auto i = 0; i < Size_; ++i)
87 : {
88 231 : for (auto k = 0; k <= i; ++k)
89 : {
90 772 : for (auto j = 0; j < Cols_; ++j)
91 : {
92 611 : result.at_unsafe(i, j) += this->at_unsafe(i, k) * mat.at_unsafe(k, j);
93 : }
94 : }
95 : }
96 : }
97 : else
98 : {
99 1754 : for (auto i = 0; i < Size_; ++i)
100 : {
101 5858 : for (auto k = i; k < Size_; ++k)
102 : {
103 27746 : for (auto j = 0; j < Cols_; ++j)
104 : {
105 23338 : result.at_unsafe(i, j) += this->at_unsafe(i, k) * mat.at_unsafe(k, j);
106 : }
107 : }
108 : }
109 : }
110 325 : return result;
111 : }
112 :
113 : template <typename ValueType_, sint32 Size_, bool IsLower_, bool IsRowMajor_>
114 3 : inline auto TriangularMatrix<ValueType_, Size_, IsLower_, IsRowMajor_>::operator*(const TriangularMatrix& mat) const
115 : -> TriangularMatrix
116 : {
117 3 : TriangularMatrix result{};
118 : if (IsLower_)
119 : {
120 7 : for (auto i = 0; i < Size_; ++i)
121 : {
122 14 : for (auto k = 0; k <= i; ++k)
123 : {
124 23 : for (auto j = 0; j <= k; ++j)
125 : {
126 14 : result.at_unsafe(i, j) += this->at_unsafe(i, k) * mat.at_unsafe(k, j);
127 : }
128 : }
129 : }
130 : }
131 : else
132 : {
133 4 : for (auto i = 0; i < Size_; ++i)
134 : {
135 9 : for (auto k = i; k < Size_; ++k)
136 : {
137 16 : for (auto j = k; j < Size_; ++j)
138 : {
139 10 : result.at_unsafe(i, j) += this->at_unsafe(i, k) * mat.at_unsafe(k, j);
140 : }
141 : }
142 : }
143 : }
144 3 : return result;
145 : }
146 :
147 : template <typename ValueType_, sint32 Size_, bool IsLower_, bool IsRowMajor_>
148 1 : inline auto TriangularMatrix<ValueType_, Size_, IsLower_, IsRowMajor_>::operator*(
149 : const TriangularMatrix<ValueType_, Size_, !IsLower_, IsRowMajor_>& mat) const -> BaseSquareMatrix
150 : {
151 1 : BaseSquareMatrix other{mat};
152 1 : return BaseSquareMatrix{this->operator*(other)};
153 1 : }
154 :
155 : template <typename ValueType_, sint32 Size_, bool IsLower_, bool IsRowMajor_>
156 299 : inline auto TriangularMatrix<ValueType_, Size_, IsLower_, IsRowMajor_>::operator*(
157 : const DiagonalMatrix<ValueType_, Size_>& diag) const -> TriangularMatrix
158 : {
159 : // each column is multiplied by the corresponding diagonal column element
160 299 : auto result{*this};
161 : if (IsLower_)
162 : {
163 57 : for (auto col = 0; col < Size_; ++col)
164 : {
165 145 : for (auto row = col; row < Size_; ++row)
166 : {
167 101 : result.at_unsafe(row, col) *= diag.at_unsafe(col);
168 : }
169 : }
170 : }
171 : else
172 : {
173 1670 : for (auto col = 0; col < Size_; ++col)
174 : {
175 5618 : for (auto row = 0; row <= col; ++row)
176 : {
177 4234 : result.at_unsafe(row, col) *= diag.at_unsafe(col);
178 : }
179 : }
180 : }
181 299 : return result;
182 : }
183 :
184 : template <typename ValueType_, sint32 Size_, bool IsLower_, bool IsRowMajor_>
185 1 : inline auto TriangularMatrix<ValueType_, Size_, IsLower_, IsRowMajor_>::operator*(const ValueType_ scalar) const
186 : -> TriangularMatrix
187 : {
188 1 : auto result{*this};
189 1 : result *= scalar;
190 : return result;
191 : }
192 :
193 : template <typename ValueType_, sint32 Size_, bool IsLower_, bool IsRowMajor_>
194 131 : inline void TriangularMatrix<ValueType_, Size_, IsLower_, IsRowMajor_>::operator*=(const ValueType_ scalar)
195 : {
196 : // TODO(matthias): can be optimized as soon as elements are stored in array instead of a SquareMatrix
197 : if (IsLower_)
198 : {
199 4 : for (sint32 row = 0; row < Size_; ++row)
200 : {
201 9 : for (sint32 col = 0; col <= row; ++col)
202 : {
203 6 : this->at_unsafe(row, col) *= scalar;
204 : }
205 : }
206 : }
207 : else
208 : {
209 761 : for (sint32 row = 0; row < Size_; ++row)
210 : {
211 2555 : for (sint32 col = row; col < Size_; ++col)
212 : {
213 1924 : this->at_unsafe(row, col) *= scalar;
214 : }
215 : }
216 : }
217 131 : }
218 :
219 : template <typename ValueType_, sint32 Size_, bool IsLower_, bool IsRowMajor_>
220 : inline auto TriangularMatrix<ValueType_, Size_, IsLower_, IsRowMajor_>::operator()(sint32 row, sint32 col) const
221 : -> tl::expected<ValueType_, Errors>
222 : {
223 : if (!((IsLower_ && (row >= col)) || (!IsLower_ && (row <= col))))
224 : {
225 : return tl::unexpected<Errors>{Errors::invalid_access_idx};
226 : }
227 : return BaseSquareMatrix::operator()(row, col);
228 : }
229 :
230 : template <typename ValueType_, sint32 Size_, bool IsLower_, bool IsRowMajor_>
231 : inline auto TriangularMatrix<ValueType_, Size_, IsLower_, IsRowMajor_>::operator()(sint32 row, sint32 col)
232 : -> tl::expected<std::reference_wrapper<ValueType_>, Errors>
233 : {
234 : if (!((IsLower_ && (row >= col)) || (!IsLower_ && (row <= col))))
235 : {
236 : return tl::unexpected<Errors>{Errors::invalid_access_idx};
237 : }
238 : return BaseSquareMatrix::operator()(row, col);
239 : }
240 :
template <typename ValueType_, sint32 Size_, bool IsLower_, bool IsRowMajor_>
inline auto TriangularMatrix<ValueType_, Size_, IsLower_, IsRowMajor_>::at_unsafe(sint32 row, sint32 col) const -> ValueType_
{
  // Unchecked read of element (row, col), returned by value.
  // The only guard is a debug-build assert that (row, col) lies inside the stored triangle;
  // range checking, if any, is left to BaseSquareMatrix::at_unsafe.
  assert((IsLower_ ? row >= col : row <= col) && "accessing off-triangular elements");
  return BaseSquareMatrix::at_unsafe(row, col);
}
247 :
template <typename ValueType_, sint32 Size_, bool IsLower_, bool IsRowMajor_>
inline auto TriangularMatrix<ValueType_, Size_, IsLower_, IsRowMajor_>::at_unsafe(sint32 row, sint32 col) -> ValueType_&
{
  // Unchecked mutable access to element (row, col).
  // The debug-build assert keeps callers from writing outside the stored triangle; in release builds
  // (NDEBUG) no triangle check is performed. Range checking, if any, is left to BaseSquareMatrix::at_unsafe.
  assert((IsLower_ ? row >= col : row <= col) && "accessing off-triangular elements");
  return BaseSquareMatrix::at_unsafe(row, col);
}
254 :
255 : template <typename ValueType_, sint32 Size_, bool IsLower_, bool IsRowMajor_>
256 : inline auto TriangularMatrix<ValueType_, Size_, IsLower_, IsRowMajor_>::transpose() const -> const transpose_type&
257 : {
258 : return reinterpret_cast<const transpose_type&>(*this);
259 : }
260 :
261 : template <typename ValueType_, sint32 Size_, bool IsLower_, bool IsRowMajor_>
262 : inline auto TriangularMatrix<ValueType_, Size_, IsLower_, IsRowMajor_>::transpose() -> transpose_type&
263 : {
264 : return reinterpret_cast<transpose_type&>(*this);
265 : }
266 :
267 : template <typename ValueType_, sint32 Size_, bool IsLower_, bool IsRowMajor_>
268 : template <sint32 Cols_, bool IsRowMajor2_>
269 228 : inline auto TriangularMatrix<ValueType_, Size_, IsLower_, IsRowMajor_>::solve(
270 : const Matrix<ValueType_, Size_, Cols_, IsRowMajor2_>& b) const -> Matrix<ValueType_, Size_, Cols_, IsRowMajor2_>
271 : {
272 228 : Matrix<ValueType_, Size_, Cols_, IsRowMajor2_> x{};
273 : if constexpr (IsLower_)
274 : {
275 255 : for (auto k = 0; k < Cols_; ++k)
276 : {
277 956 : for (auto row = 0; row < Size_; ++row)
278 : {
279 : ValueType_ sum{};
280 2042 : for (auto col = 0; col < row; ++col)
281 : {
282 1280 : sum += this->at_unsafe(row, col) * x.at_unsafe(col, k);
283 : }
284 762 : x.at_unsafe(row, k) = (b.at_unsafe(row, k) - sum) / this->at_unsafe(row, row);
285 : }
286 : }
287 : }
288 : else
289 : {
290 768 : for (auto k = 0; k < Cols_; ++k)
291 : {
292 3478 : for (auto row = Size_ - 1; row >= 0; --row)
293 : {
294 : ValueType_ sum{};
295 8756 : for (auto col = Size_ - 1; col > row; --col)
296 : {
297 5879 : sum += this->at_unsafe(row, col) * x.at_unsafe(col, k);
298 : }
299 2877 : x.at_unsafe(row, k) = (b.at_unsafe(row, k) - sum) / this->at_unsafe(row, row);
300 : }
301 : }
302 : }
303 228 : return x;
304 : }
305 :
306 : template <typename ValueType_, sint32 Size_, bool IsLower_, bool IsRowMajor_>
307 26 : inline auto TriangularMatrix<ValueType_, Size_, IsLower_, IsRowMajor_>::inverse() const -> TriangularMatrix
308 : {
309 26 : return TriangularMatrix{this->solve(BaseSquareMatrix::Identity())};
310 : }
311 :
312 : template <typename ValueType_, sint32 Size_, bool IsLower_, bool IsRowMajor_>
313 10 : inline auto TriangularMatrix<ValueType_, Size_, IsLower_, IsRowMajor_>::determinant() const -> ValueType_
314 : {
315 10 : ValueType_ det{1.0};
316 38 : for (auto idx = 0; idx < Size_; ++idx)
317 : {
318 28 : det *= this->at_unsafe(idx, idx);
319 : }
320 10 : return det;
321 : }
322 :
323 : template <typename ValueType_, sint32 Size_, bool IsLower_, bool IsRowMajor_>
324 262 : inline auto TriangularMatrix<ValueType_, Size_, IsLower_, IsRowMajor_>::isUnitUpperTriangular() const -> bool
325 : {
326 262 : auto isValid = true;
327 1375 : for (auto idx = 0; idx < Size_; ++idx)
328 : {
329 1114 : isValid = isValid && (static_cast<ValueType_>(1.0) == this->at_unsafe(idx, idx));
330 : }
331 262 : return isValid;
332 : }
333 :
334 :
335 : } // namespace math
336 : } // namespace tracking
337 :
338 : #endif // E4D3E13A_DB2E_427E_BA99_4F251275B082
|