Line data Source code
1 : #ifndef BA34774A_D1AA_44D4_BAD2_1845716F4E58
2 : #define BA34774A_D1AA_44D4_BAD2_1845716F4E58
3 :
4 : #include "math/linalg/diagonal_matrix.h"
5 :
6 : #include "math/linalg/matrix.hpp" // IWYU pragma: keep
7 : #include "math/linalg/triangular_matrix.hpp" // IWYU pragma: keep
8 : #include "math/linalg/vector.hpp" // IWYU pragma: keep
9 :
10 : namespace tracking
11 : {
12 : namespace math
13 : {
14 :
15 : // Forward declaration to prevent cyclic includes
16 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
17 : class SquareMatrix;
18 :
19 : template <typename ValueType_, sint32 Size_>
20 657 : inline auto DiagonalMatrix<ValueType_, Size_>::Identity() -> DiagonalMatrix
21 : {
22 657 : DiagonalMatrix diag{};
23 657 : diag.setIdentity();
24 318 : return diag;
25 : }
26 :
27 : template <typename ValueType_, sint32 Size_>
28 661 : inline void DiagonalMatrix<ValueType_, Size_>::setIdentity()
29 : {
30 1001 : _data.setOnes();
31 : }
32 :
33 : template <typename ValueType_, sint32 Size_>
34 84 : inline auto DiagonalMatrix<ValueType_, Size_>::FromList(const std::initializer_list<ValueType_>& list) -> DiagonalMatrix
35 : {
36 84 : assert((list.size() == Size_) && "Mismatching size of intializer list");
37 :
38 84 : DiagonalMatrix diag{};
39 : // fill diagonal elements
40 84 : sint32 idx = 0;
41 342 : for (auto val : list)
42 : {
43 258 : diag.at_unsafe(idx++) = val;
44 : }
45 84 : return diag;
46 : }
47 :
48 : template <typename ValueType_, sint32 Size_>
49 50 : inline auto DiagonalMatrix<ValueType_, Size_>::FromList(const std::initializer_list<std::initializer_list<ValueType_>>& list)
50 : -> DiagonalMatrix
51 : {
52 50 : assert(list.size() == Size_);
53 50 : assert(list.begin()->size() == Size_);
54 :
55 50 : DiagonalMatrix diag{};
56 : // copy diagonal elements from list
57 50 : sint32 idx = 0;
58 197 : for (const auto& rowList : list)
59 : {
60 147 : assert((rowList.size() == Size_) && "Mismatching size of intializer list");
61 147 : diag.at_unsafe(idx) = *(rowList.begin() + idx);
62 147 : ++idx;
63 : }
64 50 : return diag;
65 : }
66 :
67 : template <typename ValueType_, sint32 Size_>
68 : template <sint32 SrcSize_, sint32 SrcCount_, sint32 SrcIdxBeg_, sint32 DstIdxBeg_>
69 84 : inline void DiagonalMatrix<ValueType_, Size_>::setBlock(const DiagonalMatrix<ValueType_, SrcSize_>& block)
70 : {
71 : static_assert(SrcCount_ > 1, "use scalar access operator for block copy size == 1");
72 : static_assert(SrcIdxBeg_ + SrcCount_ <= SrcSize_, "copy to many rows from src");
73 :
74 : static_assert(DstIdxBeg_ + SrcCount_ <= Size_, "copy to many rows to dst");
75 :
76 84 : sint32 dstIdx = DstIdxBeg_;
77 296 : for (auto srcIdx = SrcIdxBeg_; srcIdx < SrcIdxBeg_ + SrcCount_; ++srcIdx)
78 : {
79 212 : _data.at_unsafe(dstIdx++) = block.at_unsafe(srcIdx);
80 : }
81 84 : }
82 :
83 : template <typename ValueType_, sint32 Size_>
84 : template <sint32 Cols_, bool IsRowMajor_>
85 4 : inline auto DiagonalMatrix<ValueType_, Size_>::operator*(const Matrix<ValueType_, Size_, Cols_, IsRowMajor_>& mat) const
86 : -> Matrix<ValueType_, Size_, Cols_, IsRowMajor_>
87 : {
88 : // each row is multiplied by the corresponding diagonal row element
89 4 : auto result{mat};
90 18 : for (auto row = 0; row < Size_; ++row)
91 : {
92 14 : const ValueType_ val = _data.at_unsafe(row);
93 37 : for (auto col = 0; col < Cols_; ++col)
94 : {
95 23 : result.at_unsafe(row, col) *= val;
96 : }
97 : }
98 4 : return result;
99 : }
100 :
101 : template <typename ValueType_, sint32 Size_>
102 : template <bool IsLower_, bool IsRowMajor_>
103 20 : inline auto DiagonalMatrix<ValueType_, Size_>::operator*(const TriangularMatrix<ValueType_, Size_, IsLower_, IsRowMajor_>& mat)
104 : const -> TriangularMatrix<ValueType_, Size_, IsLower_, IsRowMajor_>
105 : {
106 : // each row is multiplied by the corresponding diagonal row element
107 20 : auto result{mat};
108 : if (IsLower_)
109 : {
110 88 : for (auto row = 0; row < Size_; ++row)
111 : {
112 69 : const ValueType_ val = _data.at_unsafe(row);
113 249 : for (auto col = 0; col <= row; ++col)
114 : {
115 180 : result.at_unsafe(row, col) *= val;
116 : }
117 : }
118 : }
119 : else
120 : {
121 4 : for (auto row = 0; row < Size_; ++row)
122 : {
123 3 : const ValueType_ val = _data.at_unsafe(row);
124 9 : for (auto col = row; col < Size_; ++col)
125 : {
126 6 : result.at_unsafe(row, col) *= val;
127 : }
128 : }
129 : }
130 20 : return result;
131 : }
132 :
133 : template <typename ValueType_, sint32 Size_>
134 1 : inline auto DiagonalMatrix<ValueType_, Size_>::operator*(const DiagonalMatrix& mat) const -> DiagonalMatrix
135 : {
136 : // copy *this
137 1 : auto result{*this};
138 1 : result *= mat;
139 : return result;
140 : }
141 :
142 : template <typename ValueType_, sint32 Size_>
143 1 : inline auto DiagonalMatrix<ValueType_, Size_>::operator*(const ValueType_ scalar) const -> DiagonalMatrix
144 : {
145 : // copy *this
146 1 : auto result{*this};
147 1 : result *= scalar;
148 : return result;
149 : }
150 :
151 : template <typename ValueType_, sint32 Size_>
152 2 : inline void DiagonalMatrix<ValueType_, Size_>::operator*=(const DiagonalMatrix& mat)
153 : {
154 : // element-wise multiplication of the elements on both diagonals
155 8 : for (auto idx = 0; idx < Size_; ++idx)
156 : {
157 6 : _data.at_unsafe(idx) *= mat.at_unsafe(idx);
158 : }
159 2 : }
160 :
161 : template <typename ValueType_, sint32 Size_>
162 2 : inline void DiagonalMatrix<ValueType_, Size_>::operator*=(const ValueType_ scalar)
163 : {
164 : // element-wise multiplication of the elements on both diagonals
165 8 : for (auto idx = 0; idx < Size_; ++idx)
166 : {
167 6 : _data.at_unsafe(idx) *= scalar;
168 : }
169 2 : }
170 :
171 : template <typename ValueType_, sint32 Size_>
172 19 : inline auto DiagonalMatrix<ValueType_, Size_>::inverse() const -> DiagonalMatrix
173 : {
174 19 : DiagonalMatrix tmp{*this};
175 19 : tmp.inverse();
176 : return tmp;
177 : }
178 :
179 : template <typename ValueType_, sint32 Size_>
180 20 : inline void DiagonalMatrix<ValueType_, Size_>::inverse()
181 : {
182 100 : for (sint32 idx = 0; idx < Size_; ++idx)
183 : {
184 80 : assert((static_cast<ValueType_>(0) < this->at_unsafe(idx)) && "inverse not possible");
185 80 : _data.at_unsafe(idx) = static_cast<ValueType_>(1) / _data.at_unsafe(idx);
186 : }
187 20 : }
188 :
189 : template <typename ValueType_, sint32 Size_>
190 7 : inline auto DiagonalMatrix<ValueType_, Size_>::trace() const -> ValueType_
191 : {
192 7 : ValueType_ sum{static_cast<ValueType_>(0)};
193 26 : for (auto idx = 0; idx < Size_; ++idx)
194 : {
195 19 : sum += _data.at_unsafe(idx);
196 : }
197 7 : return sum;
198 : }
199 :
200 : template <typename ValueType_, sint32 Size_>
201 17 : inline auto DiagonalMatrix<ValueType_, Size_>::determinant() const -> ValueType_
202 : {
203 17 : ValueType_ det{static_cast<ValueType_>(1)};
204 85 : for (auto idx = 0; idx < Size_; ++idx)
205 : {
206 68 : det *= _data.at_unsafe(idx);
207 : }
208 17 : return det;
209 : }
210 :
211 : template <typename ValueType_, sint32 Size_>
212 286 : inline auto DiagonalMatrix<ValueType_, Size_>::isPositiveDefinite() const -> bool
213 : {
214 286 : auto isValid = true;
215 1471 : for (auto idx = 0; idx < Size_; ++idx)
216 : {
217 1186 : isValid = isValid && (static_cast<ValueType_>(0) < _data.at_unsafe(idx));
218 : }
219 286 : return isValid;
220 : }
221 :
222 : template <typename ValueType_, sint32 Size_>
223 10 : inline auto DiagonalMatrix<ValueType_, Size_>::isPositiveSemiDefinite() const -> bool
224 : {
225 10 : auto isValid = true;
226 40 : for (auto idx = 0; idx < Size_; ++idx)
227 : {
228 32 : isValid = isValid && (static_cast<ValueType_>(0) <= _data.at_unsafe(idx));
229 : }
230 10 : return isValid;
231 : }
232 :
233 :
234 : // ------ non-member functions ---------------------------------------------------------------------------------------------------
235 :
236 : template <typename ValueType_, sint32 Rows_, sint32 Cols_, bool IsRowMajor_>
237 44 : auto operator*(const Matrix<ValueType_, Rows_, Cols_, IsRowMajor_>& mat,
238 : const DiagonalMatrix<ValueType_, Cols_>& diag) -> Matrix<ValueType_, Rows_, Cols_, IsRowMajor_>
239 : {
240 : // each column is multiplied by the corresponding diagonal column element
241 44 : auto result{mat};
242 133 : for (auto col = 0; col < Cols_; ++col)
243 : {
244 526 : for (auto row = 0; row < Rows_; ++row)
245 : {
246 437 : result.at_unsafe(row, col) *= diag.at_unsafe(col);
247 : }
248 : }
249 44 : return result;
250 : }
251 :
252 : } // namespace math
253 : } // namespace tracking
254 :
255 : #endif // BA34774A_D1AA_44D4_BAD2_1845716F4E58
|