Line data Source code
1 : #ifndef B84C34EE_5B17_49CF_8DC1_3BCF45A59A20
2 : #define B84C34EE_5B17_49CF_8DC1_3BCF45A59A20
3 :
4 : #include "math/linalg/rank1_update.h"
5 :
6 : #include "math/analysis/functions.h"
7 : #include "math/linalg/diagonal_matrix.hpp" // IWYU pragma: keep
8 : #include "math/linalg/triangular_matrix.hpp" // IWYU pragma: keep
9 : #include "math/linalg/vector.hpp" // IWYU pragma: keep
10 : #include <cmath>
11 : #include <limits>
12 :
13 : namespace tracking
14 : {
15 : namespace math
16 : {
17 :
18 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
19 75 : inline void Rank1Update<ValueType_, Size_, IsRowMajor_>::run(TriangularMatrix<ValueType_, Size_, false, IsRowMajor_>& u,
20 : DiagonalMatrix<ValueType_, Size_>& d,
21 : ValueType_ c,
22 : Vector<ValueType_, Size_> x)
23 : {
24 : // based on
25 : // https://scicomp.stackexchange.com/questions/8323/computational-complexity-and-implementation-of-udu-modified-cholesky-rank-1-upda
26 :
27 75 : if (c == 0)
28 : return; // no update required
29 :
30 : // update or downdate ?
31 75 : const ValueType_ sign = (c > 0) ? +1 : -1;
32 75 : c = c * sign; // remove the sign
33 442 : for (int j = Size_ - 1; j >= 0; --j)
34 : {
35 : // Retrieve diagonal and update vector value
36 367 : ValueType_ djj = d.at_unsafe(j);
37 367 : ValueType_ ujx = x.at_unsafe(j);
38 :
39 : // Compute updated diagonal element with clipping to ensure PSD
40 367 : ValueType_ gamma = djj + sign * c * pow<2>(ujx);
41 367 : gamma = std::max(gamma, std::numeric_limits<ValueType_>::epsilon());
42 :
43 : // Update the diagonal
44 367 : d.at_unsafe(j) = gamma;
45 :
46 : // Compute scaling factors for U update
47 367 : ValueType_ beta = c / gamma;
48 367 : ValueType_ eta = beta * ujx;
49 :
50 : // Update the upper triangular matrix
51 1123 : for (int i = 0; i < j; ++i)
52 : {
53 756 : ValueType_ uij = u.at_unsafe(i, j);
54 756 : x.at_unsafe(i) -= uij * ujx; // Update x for future iterations
55 756 : u.at_unsafe(i, j) += sign * eta * x.at_unsafe(i); // Apply correction to U
56 : }
57 :
58 : // Update scaling factor for the next iteration
59 367 : c = beta * djj;
60 : }
61 : }
62 :
63 :
64 : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
65 3 : inline void Rank1Update<ValueType_, Size_, IsRowMajor_>::run(TriangularMatrix<ValueType_, Size_, true, IsRowMajor_>& l,
66 : DiagonalMatrix<ValueType_, Size_>& d,
67 : ValueType_ c,
68 : Vector<ValueType_, Size_> x)
69 : {
70 : // Methods for Modifying Matrix Factorizations in Mathematics of Computation
71 : // Gill, Golub, Murray and Saunders (1974)
72 : //
73 : // http://stanford.edu/group/SOL/papers/ggms74.pdf
74 :
75 3 : x *= sqrt(abs(c));
76 3 : c = (c > 0) ? static_cast<ValueType_>(1.0) : -static_cast<ValueType_>(1.0);
77 :
78 7 : ValueType_ dj_{};
79 2 : ValueType_ c_{};
80 2 : ValueType_ beta{};
81 :
82 : if (c > 0)
83 : {
84 5 : ValueType_ p{};
85 : c_ = static_cast<ValueType_>(1.0);
86 5 : for (auto j = 0; j < Size_; ++j)
87 : {
88 4 : p = x.at_unsafe(j);
89 4 : dj_ = d.at_unsafe(j);
90 4 : c = c_ + pow<2>(p) / dj_;
91 4 : d.at_unsafe(j) = dj_ * c / c_;
92 4 : beta = p / (dj_ * c);
93 10 : for (auto r = j + 1; r < Size_; ++r)
94 : {
95 6 : x.at_unsafe(r) -= p * l.at_unsafe(r, j);
96 6 : l.at_unsafe(r, j) += beta * x.at_unsafe(r);
97 : }
98 : c_ = c;
99 : }
100 : }
101 : else
102 : {
103 2 : const auto l_{l};
104 2 : decltype(x) p = Vector<ValueType_, Size_>{l.solve(x)};
105 2 : const auto dinv = static_cast<const DiagonalMatrix<ValueType_, Size_>&>(d).inverse();
106 : // ensure PSD
107 3 : c_ = std::max(1 - (p.transpose() * (dinv * p)).at_unsafe(0, 0), std::numeric_limits<ValueType_>::epsilon());
108 10 : for (auto j = Size_ - 1; j >= 0; --j)
109 : {
110 8 : dj_ = d.at_unsafe(j);
111 8 : c = c_ + pow<2>(p.at_unsafe(j)) / dj_;
112 8 : d.at_unsafe(j) = dj_ * c_ / c;
113 8 : beta = -p.at_unsafe(j) / (dj_ * c_);
114 8 : x.at_unsafe(j) = p.at_unsafe(j);
115 20 : for (auto r = j + 1; r < Size_; ++r)
116 : {
117 12 : l.at_unsafe(r, j) += beta * x.at_unsafe(r);
118 12 : x.at_unsafe(r) += p.at_unsafe(j) * l_.at_unsafe(r, j);
119 : }
120 8 : c_ = c;
121 : }
122 2 : }
123 3 : }
124 :
125 : } // namespace math
126 : } // namespace tracking
127 :
128 : #endif // B84C34EE_5B17_49CF_8DC1_3BCF45A59A20
|