Line data Source code
1 : #ifndef C5DC5CCE_5C5B_4EAF_813E_3F6FEDDF09FA
2 : #define C5DC5CCE_5C5B_4EAF_813E_3F6FEDDF09FA
3 :
4 : #include "math/analysis/functions.h"
5 : #include "math/linalg/diagonal_matrix.hpp"
6 : #include "math/linalg/matrix_column_view.hpp"
7 : #include "math/linalg/matrix_row_view.hpp" // IWYU pragma: keep
8 : #include "math/linalg/triangular_matrix.hpp"
9 :
10 : namespace tracking
11 : {
12 : namespace math
13 : {
14 :
15 : // Forward declaration to prevent cyclic includes
16 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
17 : class SquareMatrix;
18 :
19 : // Householder QR decomposition
20 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
21 128 : inline auto SquareMatrix<ValueType_, Size_, IsRowMajor_>::householderQR() const
22 : -> std::pair<SquareMatrix, TriangularMatrix<ValueType_, Size_, false, IsRowMajor_>>
23 : {
24 : // implementation based on https://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf
25 :
26 : // Initially, Q is an identity matrix because no orthogonal transformations have been applied yet.
27 128 : SquareMatrix Q{SquareMatrix::Identity()};
28 : // Initializes the upper triangular matrix R as a copy of the input matrix. This is the matrix
29 : // that will be transformed to become upper triangular.
30 128 : SquareMatrix R{*this};
31 :
32 : // scale to reduce numerical issues
33 128 : const auto [min, max] = R.minmax();
34 128 : const auto scaleFactor = std::abs(min) > std::abs(max) ? min : max;
35 128 : R /= scaleFactor;
36 :
37 : using ColumnVector = Vector<ValueType_, Size_>;
38 128 : static ColumnVector w{};
39 753 : for (auto j = 0; j < Size_; ++j)
40 : {
41 : // Extract Size_-j rows of the j-th column as a Vector starting in row j.
42 625 : w.setBlock(Size_ - j, 1, j, j, j, 0, R);
43 1912 : for (auto k = 0; k < j; ++k)
44 : {
45 : // set unused values to zero
46 1287 : w.at_unsafe(k) = static_cast<ValueType_>(0);
47 : }
48 :
49 625 : const ValueType_ normx = w.norm();
50 : // Determines the sign of the j-th diagonal element of R.
51 625 : const ValueType_ sign =
52 625 : (R.at_unsafe(j, j) < static_cast<ValueType_>(0)) ? static_cast<ValueType_>(1) : static_cast<ValueType_>(-1);
53 625 : const ValueType_ u1 = R.at_unsafe(j, j) - sign * normx;
54 625 : const ValueType_ tau = -sign * u1 / normx; // Computes the parameter tau for the Householder transformation.
55 :
56 625 : w /= u1; // Computes the Householder vector w.
57 625 : w.at_unsafe(j) = static_cast<ValueType_>(1); // Sets the j-th row of w to 1 for convenience.
58 625 : const auto wView = MatrixColumnView<ValueType_, Size_, 1, true>(w, 0, j); // create view starting in j-th row
59 :
60 : // Update R using the Householder transformation
61 2537 : for (auto i = j; i < Size_; ++i) // cols
62 : {
63 : // R(j:end, i) = R(j:end, i) - tau * w * (w' * R(j:end, i));
64 1912 : const auto tau_dotRw = tau * (MatrixColumnView<ValueType_, Size_, Size_, IsRowMajor_>(R, i, j) * wView);
65 9300 : for (auto k = j; k < Size_; ++k) // rows
66 : {
67 7388 : R.at_unsafe(k, i) -= tau_dotRw * wView.at_unsafe(k - j);
68 : }
69 : }
70 :
71 : // Update Q using the Householder transformation
72 3824 : for (auto i = 0; i < Size_; ++i) // rows
73 : {
74 : // Q(i,j:end) = Q(i,j:end) - tau * (Q(i,j:end) * w) * w';
75 3199 : const auto tau_dotQw = tau * (MatrixRowView<ValueType_, Size_, Size_, IsRowMajor_>(Q, i, j) * wView);
76 13325 : for (auto k = j; k < Size_; ++k) // cols
77 : {
78 10126 : Q.at_unsafe(i, k) -= tau_dotQw * wView.at_unsafe(k - j);
79 : }
80 : }
81 : }
82 128 : auto triuR = TriangularMatrix<ValueType_, Size_, false, IsRowMajor_>{std::move(R)};
83 128 : triuR *= scaleFactor;
84 128 : return std::make_pair(std::move(Q), std::move(triuR));
85 128 : }
86 :
87 : // LLT decomposition
88 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
89 96 : inline auto SquareMatrix<ValueType_, Size_, IsRowMajor_>::decomposeLLT() const
90 : -> tl::expected<TriangularMatrix<ValueType_, Size_, true, IsRowMajor_>, Errors>
91 : {
92 : // TODO(matthias): optimization - for small sizes (Size_ <= 10), leverage compile-time features
93 : // like constexpr computations where possible; add condition number checks for diagnostics
94 96 : if (hasStrictlyPositiveDiagonalElems()) // fail fast
95 : {
96 90 : if (isSymmetric()) // fail fast
97 : {
98 87 : TriangularMatrix<ValueType_, Size_, true, IsRowMajor_> L{};
99 444 : for (auto j = 0; j < Size_; ++j)
100 : {
101 357 : ValueType_ sum = this->at_unsafe(j, j);
102 972 : for (auto k = 0; k < j; ++k)
103 : {
104 1230 : sum -= math::pow<2>(L.at_unsafe(j, k));
105 : }
106 357 : L.at_unsafe(j, j) = std::sqrt(sum);
107 :
108 972 : for (auto i = j + 1; i < Size_; ++i)
109 : {
110 615 : sum = this->at_unsafe(i, j);
111 1210 : for (auto k = 0; k < j; ++k)
112 : {
113 595 : sum -= L.at_unsafe(i, k) * L.at_unsafe(j, k);
114 : }
115 615 : L.at_unsafe(i, j) = sum / L.at_unsafe(j, j);
116 : }
117 : }
118 87 : return std::move(L);
119 87 : }
120 3 : return tl::unexpected<Errors>{Errors::matrix_not_symmetric};
121 : }
122 6 : return tl::unexpected<Errors>{Errors::matrix_not_positive_definite};
123 : }
124 :
125 : // LDLT decomposition
126 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
127 21 : inline auto SquareMatrix<ValueType_, Size_, IsRowMajor_>::decomposeLDLT() const
128 : -> tl::expected<std::pair<TriangularMatrix<ValueType_, Size_, true, IsRowMajor_>, DiagonalMatrix<ValueType_, Size_>>, Errors>
129 : {
130 21 : if (isSymmetric())
131 : {
132 19 : if (hasStrictlyPositiveDiagonalElems())
133 : {
134 13 : TriangularMatrix<ValueType_, Size_, true, IsRowMajor_> L{};
135 13 : DiagonalMatrix<ValueType_, Size_> D{};
136 71 : for (auto j = 0; j < Size_; ++j)
137 : {
138 58 : ValueType_ sum = this->at_unsafe(j, j);
139 169 : for (auto k = 0; k < j; ++k)
140 : {
141 111 : sum -= D.at_unsafe(k) * L.at_unsafe(j, k) * L.at_unsafe(j, k);
142 : }
143 58 : D.at_unsafe(j) = sum;
144 58 : L.at_unsafe(j, j) = static_cast<ValueType_>(1);
145 :
146 169 : for (auto i = j + 1; i < Size_; ++i)
147 : {
148 111 : sum = this->at_unsafe(i, j);
149 231 : for (auto k = 0; k < j; ++k)
150 : {
151 120 : sum -= D.at_unsafe(k) * L.at_unsafe(i, k) * L.at_unsafe(j, k);
152 : }
153 111 : L.at_unsafe(i, j) = sum / D.at_unsafe(j);
154 : }
155 : }
156 13 : return std::make_pair(std::move(L), std::move(D));
157 13 : }
158 6 : return tl::unexpected<Errors>{Errors::matrix_not_positive_definite};
159 : }
160 2 : return tl::unexpected<Errors>{Errors::matrix_not_symmetric};
161 : }
162 :
163 : // UDUT decomposition
164 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
165 52 : inline auto SquareMatrix<ValueType_, Size_, IsRowMajor_>::decomposeUDUT() const
166 : -> tl::expected<std::pair<TriangularMatrix<ValueType_, Size_, false, IsRowMajor_>, DiagonalMatrix<ValueType_, Size_>>, Errors>
167 : {
168 52 : if (isSymmetric())
169 : {
170 50 : const auto& P = *this;
171 50 : TriangularMatrix<ValueType_, Size_, false, IsRowMajor_> U{};
172 50 : DiagonalMatrix<ValueType_, Size_> D{};
173 252 : for (sint32 j = Size_ - 1; j >= 0; --j)
174 : {
175 743 : for (sint32 i = j; i >= 0; --i)
176 : {
177 541 : auto sigma = P.at_unsafe(i, j);
178 1196 : for (sint32 k = j + 1; k < Size_; ++k)
179 : {
180 655 : sigma -= U.at_unsafe(i, k) * D.at_unsafe(k) * U.at_unsafe(j, k);
181 : }
182 541 : if (i == j)
183 : {
184 207 : D.at_unsafe(j) = std::max(sigma, std::numeric_limits<ValueType_>::epsilon());
185 202 : U.at_unsafe(j, j) = static_cast<ValueType_>(1.0);
186 : }
187 : else
188 : {
189 339 : U.at_unsafe(i, j) = sigma / D.at_unsafe(j);
190 : }
191 : }
192 : }
193 50 : return std::make_pair(std::move(U), std::move(D));
194 50 : }
195 2 : return tl::unexpected<Errors>{Errors::matrix_not_symmetric};
196 : }
197 :
198 : } // namespace math
199 : } // namespace tracking
200 :
201 : #endif // C5DC5CCE_5C5B_4EAF_813E_3F6FEDDF09FA
|