LCOV - code coverage report
Current view: top level - math/linalg - rank1_update.hpp (source / functions) Coverage Total Hit
Test: lcov.info Lines: 100.0 % 49 49
Test Date: 2026-04-26 21:52:20 Functions: 100.0 % 6 6
Legend: Lines: hit not hit

            Line data    Source code
       1              : #ifndef B84C34EE_5B17_49CF_8DC1_3BCF45A59A20
       2              : #define B84C34EE_5B17_49CF_8DC1_3BCF45A59A20
       3              : 
#include "math/linalg/rank1_update.h"

#include "math/analysis/functions.h"
#include "math/linalg/diagonal_matrix.hpp"   // IWYU pragma: keep
#include "math/linalg/triangular_matrix.hpp" // IWYU pragma: keep
#include "math/linalg/vector.hpp"            // IWYU pragma: keep
#include <algorithm>
#include <cmath>
#include <limits>
      12              : 
      13              : namespace tracking
      14              : {
      15              : namespace math
      16              : {
      17              : 
      18              : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
      19           75 : inline void Rank1Update<ValueType_, Size_, IsRowMajor_>::run(TriangularMatrix<ValueType_, Size_, false, IsRowMajor_>& u,
      20              :                                                              DiagonalMatrix<ValueType_, Size_>&                       d,
      21              :                                                              ValueType_                                               c,
      22              :                                                              Vector<ValueType_, Size_>                                x)
      23              : {
      24              :   // based on
      25              :   // https://scicomp.stackexchange.com/questions/8323/computational-complexity-and-implementation-of-udu-modified-cholesky-rank-1-upda
      26              : 
      27           75 :   if (c == 0)
      28              :     return; // no update required
      29              : 
      30              :   // update or downdate ?
      31           75 :   const ValueType_ sign = (c > 0) ? +1 : -1;
      32           75 :   c                     = c * sign; // remove the sign
      33          442 :   for (int j = Size_ - 1; j >= 0; --j)
      34              :   {
      35              :     // Retrieve diagonal and update vector value
      36          367 :     ValueType_ djj = d.at_unsafe(j);
      37          367 :     ValueType_ ujx = x.at_unsafe(j);
      38              : 
      39              :     // Compute updated diagonal element with clipping to ensure PSD
      40          367 :     ValueType_ gamma = djj + sign * c * pow<2>(ujx);
      41          367 :     gamma            = std::max(gamma, std::numeric_limits<ValueType_>::epsilon());
      42              : 
      43              :     // Update the diagonal
      44          367 :     d.at_unsafe(j) = gamma;
      45              : 
      46              :     // Compute scaling factors for U update
      47          367 :     ValueType_ beta = c / gamma;
      48          367 :     ValueType_ eta  = beta * ujx;
      49              : 
      50              :     // Update the upper triangular matrix
      51         1123 :     for (int i = 0; i < j; ++i)
      52              :     {
      53          756 :       ValueType_ uij = u.at_unsafe(i, j);
      54          756 :       x.at_unsafe(i) -= uij * ujx;                      // Update x for future iterations
      55          756 :       u.at_unsafe(i, j) += sign * eta * x.at_unsafe(i); // Apply correction to U
      56              :     }
      57              : 
      58              :     // Update scaling factor for the next iteration
      59          367 :     c = beta * djj;
      60              :   }
      61              : }
      62              : 
      63              : 
      64              : template <typename ValueType_, sint32 Size_, bool IsRowMajor_>
      65            3 : inline void Rank1Update<ValueType_, Size_, IsRowMajor_>::run(TriangularMatrix<ValueType_, Size_, true, IsRowMajor_>& l,
      66              :                                                              DiagonalMatrix<ValueType_, Size_>&                      d,
      67              :                                                              ValueType_                                              c,
      68              :                                                              Vector<ValueType_, Size_>                               x)
      69              : {
      70              :   // Methods for Modifying Matrix Factorizations in Mathematics of Computation
      71              :   // Gill, Golub, Murray and Saunders (1974)
      72              :   //
      73              :   // http://stanford.edu/group/SOL/papers/ggms74.pdf
      74              : 
      75            3 :   x *= sqrt(abs(c));
      76            3 :   c = (c > 0) ? static_cast<ValueType_>(1.0) : -static_cast<ValueType_>(1.0);
      77              : 
      78            7 :   ValueType_ dj_{};
      79            2 :   ValueType_ c_{};
      80            2 :   ValueType_ beta{};
      81              : 
      82              :   if (c > 0)
      83              :   {
      84            5 :     ValueType_ p{};
      85              :     c_ = static_cast<ValueType_>(1.0);
      86            5 :     for (auto j = 0; j < Size_; ++j)
      87              :     {
      88            4 :       p              = x.at_unsafe(j);
      89            4 :       dj_            = d.at_unsafe(j);
      90            4 :       c              = c_ + pow<2>(p) / dj_;
      91            4 :       d.at_unsafe(j) = dj_ * c / c_;
      92            4 :       beta           = p / (dj_ * c);
      93           10 :       for (auto r = j + 1; r < Size_; ++r)
      94              :       {
      95            6 :         x.at_unsafe(r) -= p * l.at_unsafe(r, j);
      96            6 :         l.at_unsafe(r, j) += beta * x.at_unsafe(r);
      97              :       }
      98              :       c_ = c;
      99              :     }
     100              :   }
     101              :   else
     102              :   {
     103            2 :     const auto  l_{l};
     104            2 :     decltype(x) p    = Vector<ValueType_, Size_>{l.solve(x)};
     105            2 :     const auto  dinv = static_cast<const DiagonalMatrix<ValueType_, Size_>&>(d).inverse();
     106              :     // ensure PSD
     107            3 :     c_ = std::max(1 - (p.transpose() * (dinv * p)).at_unsafe(0, 0), std::numeric_limits<ValueType_>::epsilon());
     108           10 :     for (auto j = Size_ - 1; j >= 0; --j)
     109              :     {
     110            8 :       dj_            = d.at_unsafe(j);
     111            8 :       c              = c_ + pow<2>(p.at_unsafe(j)) / dj_;
     112            8 :       d.at_unsafe(j) = dj_ * c_ / c;
     113            8 :       beta           = -p.at_unsafe(j) / (dj_ * c_);
     114            8 :       x.at_unsafe(j) = p.at_unsafe(j);
     115           20 :       for (auto r = j + 1; r < Size_; ++r)
     116              :       {
     117           12 :         l.at_unsafe(r, j) += beta * x.at_unsafe(r);
     118           12 :         x.at_unsafe(r) += p.at_unsafe(j) * l_.at_unsafe(r, j);
     119              :       }
     120            8 :       c_ = c;
     121              :     }
     122            2 :   }
     123            3 : }
     124              : 
     125              : } // namespace math
     126              : } // namespace tracking
     127              : 
     128              : #endif // B84C34EE_5B17_49CF_8DC1_3BCF45A59A20
        

Generated by: LCOV version 2.0-1