lu.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
  1. // Copyright ©2013 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mat
  5. import (
  6. "math"
  7. "gonum.org/v1/gonum/blas"
  8. "gonum.org/v1/gonum/blas/blas64"
  9. "gonum.org/v1/gonum/floats"
  10. "gonum.org/v1/gonum/lapack"
  11. "gonum.org/v1/gonum/lapack/lapack64"
  12. )
  // badSliceLength is the panic message used when a caller-supplied slice does
// not have the length required by the factorized matrix.
const badSliceLength = "mat: improper slice length"

// badLU is the panic message used when a method needs a factorization that
// the receiver does not hold.
const badLU = "mat: invalid LU factorization"
  17. // LU is a type for creating and using the LU factorization of a matrix.
  18. type LU struct {
  19. lu *Dense
  20. pivot []int
  21. cond float64
  22. }
  23. // updateCond updates the stored condition number of the matrix. anorm is the
  24. // norm of the original matrix. If anorm is negative it will be estimated.
  25. func (lu *LU) updateCond(anorm float64, norm lapack.MatrixNorm) {
  26. n := lu.lu.mat.Cols
  27. work := getFloats(4*n, false)
  28. defer putFloats(work)
  29. iwork := getInts(n, false)
  30. defer putInts(iwork)
  31. if anorm < 0 {
  32. // This is an approximation. By the definition of a norm,
  33. // |AB| <= |A| |B|.
  34. // Since A = L*U, we get for the condition number κ that
  35. // κ(A) := |A| |A^-1| = |L*U| |A^-1| <= |L| |U| |A^-1|,
  36. // so this will overestimate the condition number somewhat.
  37. // The norm of the original factorized matrix cannot be stored
  38. // because of update possibilities.
  39. u := lu.lu.asTriDense(n, blas.NonUnit, blas.Upper)
  40. l := lu.lu.asTriDense(n, blas.Unit, blas.Lower)
  41. unorm := lapack64.Lantr(norm, u.mat, work)
  42. lnorm := lapack64.Lantr(norm, l.mat, work)
  43. anorm = unorm * lnorm
  44. }
  45. v := lapack64.Gecon(norm, lu.lu.mat, anorm, work, iwork)
  46. lu.cond = 1 / v
  47. }
  48. // Factorize computes the LU factorization of the square matrix a and stores the
  49. // result. The LU decomposition will complete regardless of the singularity of a.
  50. //
  51. // The LU factorization is computed with pivoting, and so really the decomposition
  52. // is a PLU decomposition where P is a permutation matrix. The individual matrix
  53. // factors can be extracted from the factorization using the Permutation method
  54. // on Dense, and the LU.LTo and LU.UTo methods.
  55. func (lu *LU) Factorize(a Matrix) {
  56. lu.factorize(a, CondNorm)
  57. }
  58. func (lu *LU) factorize(a Matrix, norm lapack.MatrixNorm) {
  59. r, c := a.Dims()
  60. if r != c {
  61. panic(ErrSquare)
  62. }
  63. if lu.lu == nil {
  64. lu.lu = NewDense(r, r, nil)
  65. } else {
  66. lu.lu.Reset()
  67. lu.lu.reuseAsNonZeroed(r, r)
  68. }
  69. lu.lu.Copy(a)
  70. if cap(lu.pivot) < r {
  71. lu.pivot = make([]int, r)
  72. }
  73. lu.pivot = lu.pivot[:r]
  74. work := getFloats(r, false)
  75. anorm := lapack64.Lange(norm, lu.lu.mat, work)
  76. putFloats(work)
  77. lapack64.Getrf(lu.lu.mat, lu.pivot)
  78. lu.updateCond(anorm, norm)
  79. }
  80. // isValid returns whether the receiver contains a factorization.
  81. func (lu *LU) isValid() bool {
  82. return lu.lu != nil && !lu.lu.IsEmpty()
  83. }
  84. // Cond returns the condition number for the factorized matrix.
  85. // Cond will panic if the receiver does not contain a factorization.
  86. func (lu *LU) Cond() float64 {
  87. if !lu.isValid() {
  88. panic(badLU)
  89. }
  90. return lu.cond
  91. }
  92. // Reset resets the factorization so that it can be reused as the receiver of a
  93. // dimensionally restricted operation.
  94. func (lu *LU) Reset() {
  95. if lu.lu != nil {
  96. lu.lu.Reset()
  97. }
  98. lu.pivot = lu.pivot[:0]
  99. }
  100. func (lu *LU) isZero() bool {
  101. return len(lu.pivot) == 0
  102. }
  103. // Det returns the determinant of the matrix that has been factorized. In many
  104. // expressions, using LogDet will be more numerically stable.
  105. // Det will panic if the receiver does not contain a factorization.
  106. func (lu *LU) Det() float64 {
  107. det, sign := lu.LogDet()
  108. return math.Exp(det) * sign
  109. }
  110. // LogDet returns the log of the determinant and the sign of the determinant
  111. // for the matrix that has been factorized. Numerical stability in product and
  112. // division expressions is generally improved by working in log space.
  113. // LogDet will panic if the receiver does not contain a factorization.
  114. func (lu *LU) LogDet() (det float64, sign float64) {
  115. if !lu.isValid() {
  116. panic(badLU)
  117. }
  118. _, n := lu.lu.Dims()
  119. logDiag := getFloats(n, false)
  120. defer putFloats(logDiag)
  121. sign = 1.0
  122. for i := 0; i < n; i++ {
  123. v := lu.lu.at(i, i)
  124. if v < 0 {
  125. sign *= -1
  126. }
  127. if lu.pivot[i] != i {
  128. sign *= -1
  129. }
  130. logDiag[i] = math.Log(math.Abs(v))
  131. }
  132. return floats.Sum(logDiag), sign
  133. }
  134. // Pivot returns pivot indices that enable the construction of the permutation
  135. // matrix P (see Dense.Permutation). If swaps == nil, then new memory will be
  136. // allocated, otherwise the length of the input must be equal to the size of the
  137. // factorized matrix.
  138. // Pivot will panic if the receiver does not contain a factorization.
  139. func (lu *LU) Pivot(swaps []int) []int {
  140. if !lu.isValid() {
  141. panic(badLU)
  142. }
  143. _, n := lu.lu.Dims()
  144. if swaps == nil {
  145. swaps = make([]int, n)
  146. }
  147. if len(swaps) != n {
  148. panic(badSliceLength)
  149. }
  150. // Perform the inverse of the row swaps in order to find the final
  151. // row swap position.
  152. for i := range swaps {
  153. swaps[i] = i
  154. }
  155. for i := n - 1; i >= 0; i-- {
  156. v := lu.pivot[i]
  157. swaps[i], swaps[v] = swaps[v], swaps[i]
  158. }
  159. return swaps
  160. }
  161. // RankOne updates an LU factorization as if a rank-one update had been applied to
  162. // the original matrix A, storing the result into the receiver. That is, if in
  163. // the original LU decomposition P * L * U = A, in the updated decomposition
  164. // P * L * U = A + alpha * x * yᵀ.
  165. // RankOne will panic if orig does not contain a factorization.
  166. func (lu *LU) RankOne(orig *LU, alpha float64, x, y Vector) {
  167. if !orig.isValid() {
  168. panic(badLU)
  169. }
  170. // RankOne uses algorithm a1 on page 28 of "Multiple-Rank Updates to Matrix
  171. // Factorizations for Nonlinear Analysis and Circuit Design" by Linzhong Deng.
  172. // http://web.stanford.edu/group/SOL/dissertations/Linzhong-Deng-thesis.pdf
  173. _, n := orig.lu.Dims()
  174. if r, c := x.Dims(); r != n || c != 1 {
  175. panic(ErrShape)
  176. }
  177. if r, c := y.Dims(); r != n || c != 1 {
  178. panic(ErrShape)
  179. }
  180. if orig != lu {
  181. if lu.isZero() {
  182. if cap(lu.pivot) < n {
  183. lu.pivot = make([]int, n)
  184. }
  185. lu.pivot = lu.pivot[:n]
  186. if lu.lu == nil {
  187. lu.lu = NewDense(n, n, nil)
  188. } else {
  189. lu.lu.reuseAsNonZeroed(n, n)
  190. }
  191. } else if len(lu.pivot) != n {
  192. panic(ErrShape)
  193. }
  194. copy(lu.pivot, orig.pivot)
  195. lu.lu.Copy(orig.lu)
  196. }
  197. xs := getFloats(n, false)
  198. defer putFloats(xs)
  199. ys := getFloats(n, false)
  200. defer putFloats(ys)
  201. for i := 0; i < n; i++ {
  202. xs[i] = x.AtVec(i)
  203. ys[i] = y.AtVec(i)
  204. }
  205. // Adjust for the pivoting in the LU factorization
  206. for i, v := range lu.pivot {
  207. xs[i], xs[v] = xs[v], xs[i]
  208. }
  209. lum := lu.lu.mat
  210. omega := alpha
  211. for j := 0; j < n; j++ {
  212. ujj := lum.Data[j*lum.Stride+j]
  213. ys[j] /= ujj
  214. theta := 1 + xs[j]*ys[j]*omega
  215. beta := omega * ys[j] / theta
  216. gamma := omega * xs[j]
  217. omega -= beta * gamma
  218. lum.Data[j*lum.Stride+j] *= theta
  219. for i := j + 1; i < n; i++ {
  220. xs[i] -= lum.Data[i*lum.Stride+j] * xs[j]
  221. tmp := ys[i]
  222. ys[i] -= lum.Data[j*lum.Stride+i] * ys[j]
  223. lum.Data[i*lum.Stride+j] += beta * xs[i]
  224. lum.Data[j*lum.Stride+i] += gamma * tmp
  225. }
  226. }
  227. lu.updateCond(-1, CondNorm)
  228. }
  229. // LTo extracts the lower triangular matrix from an LU factorization.
  230. //
  231. // If dst is empty, LTo will resize dst to be a lower-triangular n×n matrix.
  232. // When dst is non-empty, LTo will panic if dst is not n×n or not Lower.
  233. // LTo will also panic if the receiver does not contain a successful
  234. // factorization.
  235. func (lu *LU) LTo(dst *TriDense) *TriDense {
  236. if !lu.isValid() {
  237. panic(badLU)
  238. }
  239. _, n := lu.lu.Dims()
  240. if dst.IsEmpty() {
  241. dst.ReuseAsTri(n, Lower)
  242. } else {
  243. n2, kind := dst.Triangle()
  244. if n != n2 {
  245. panic(ErrShape)
  246. }
  247. if kind != Lower {
  248. panic(ErrTriangle)
  249. }
  250. }
  251. // Extract the lower triangular elements.
  252. for i := 0; i < n; i++ {
  253. for j := 0; j < i; j++ {
  254. dst.mat.Data[i*dst.mat.Stride+j] = lu.lu.mat.Data[i*lu.lu.mat.Stride+j]
  255. }
  256. }
  257. // Set ones on the diagonal.
  258. for i := 0; i < n; i++ {
  259. dst.mat.Data[i*dst.mat.Stride+i] = 1
  260. }
  261. return dst
  262. }
  263. // UTo extracts the upper triangular matrix from an LU factorization.
  264. //
  265. // If dst is empty, UTo will resize dst to be an upper-triangular n×n matrix.
  266. // When dst is non-empty, UTo will panic if dst is not n×n or not Upper.
  267. // UTo will also panic if the receiver does not contain a successful
  268. // factorization.
  269. func (lu *LU) UTo(dst *TriDense) {
  270. if !lu.isValid() {
  271. panic(badLU)
  272. }
  273. _, n := lu.lu.Dims()
  274. if dst.IsEmpty() {
  275. dst.ReuseAsTri(n, Upper)
  276. } else {
  277. n2, kind := dst.Triangle()
  278. if n != n2 {
  279. panic(ErrShape)
  280. }
  281. if kind != Upper {
  282. panic(ErrTriangle)
  283. }
  284. }
  285. // Extract the upper triangular elements.
  286. for i := 0; i < n; i++ {
  287. for j := i; j < n; j++ {
  288. dst.mat.Data[i*dst.mat.Stride+j] = lu.lu.mat.Data[i*lu.lu.mat.Stride+j]
  289. }
  290. }
  291. }
  292. // Permutation constructs an r×r permutation matrix with the given row swaps.
  293. // A permutation matrix has exactly one element equal to one in each row and column
  294. // and all other elements equal to zero. swaps[i] specifies the row with which
  295. // i will be swapped, which is equivalent to the non-zero column of row i.
  296. func (m *Dense) Permutation(r int, swaps []int) {
  297. m.reuseAsNonZeroed(r, r)
  298. for i := 0; i < r; i++ {
  299. zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+r])
  300. v := swaps[i]
  301. if v < 0 || v >= r {
  302. panic(ErrRowAccess)
  303. }
  304. m.mat.Data[i*m.mat.Stride+v] = 1
  305. }
  306. }
  307. // SolveTo solves a system of linear equations using the LU decomposition of a matrix.
  308. // It computes
  309. // A * X = B if trans == false
  310. // Aᵀ * X = B if trans == true
  311. // In both cases, A is represented in LU factorized form, and the matrix X is
  312. // stored into dst.
  313. //
  314. // If A is singular or near-singular a Condition error is returned. See
  315. // the documentation for Condition for more information.
  316. // SolveTo will panic if the receiver does not contain a factorization.
  317. func (lu *LU) SolveTo(dst *Dense, trans bool, b Matrix) error {
  318. if !lu.isValid() {
  319. panic(badLU)
  320. }
  321. _, n := lu.lu.Dims()
  322. br, bc := b.Dims()
  323. if br != n {
  324. panic(ErrShape)
  325. }
  326. // TODO(btracey): Should test the condition number instead of testing that
  327. // the determinant is exactly zero.
  328. if lu.Det() == 0 {
  329. return Condition(math.Inf(1))
  330. }
  331. dst.reuseAsNonZeroed(n, bc)
  332. bU, _ := untranspose(b)
  333. var restore func()
  334. if dst == bU {
  335. dst, restore = dst.isolatedWorkspace(bU)
  336. defer restore()
  337. } else if rm, ok := bU.(RawMatrixer); ok {
  338. dst.checkOverlap(rm.RawMatrix())
  339. }
  340. dst.Copy(b)
  341. t := blas.NoTrans
  342. if trans {
  343. t = blas.Trans
  344. }
  345. lapack64.Getrs(t, lu.lu.mat, dst.mat, lu.pivot)
  346. if lu.cond > ConditionTolerance {
  347. return Condition(lu.cond)
  348. }
  349. return nil
  350. }
  351. // SolveVecTo solves a system of linear equations using the LU decomposition of a matrix.
  352. // It computes
  353. // A * x = b if trans == false
  354. // Aᵀ * x = b if trans == true
  355. // In both cases, A is represented in LU factorized form, and the vector x is
  356. // stored into dst.
  357. //
  358. // If A is singular or near-singular a Condition error is returned. See
  359. // the documentation for Condition for more information.
  360. // SolveVecTo will panic if the receiver does not contain a factorization.
  361. func (lu *LU) SolveVecTo(dst *VecDense, trans bool, b Vector) error {
  362. if !lu.isValid() {
  363. panic(badLU)
  364. }
  365. _, n := lu.lu.Dims()
  366. if br, bc := b.Dims(); br != n || bc != 1 {
  367. panic(ErrShape)
  368. }
  369. switch rv := b.(type) {
  370. default:
  371. dst.reuseAsNonZeroed(n)
  372. return lu.SolveTo(dst.asDense(), trans, b)
  373. case RawVectorer:
  374. if dst != b {
  375. dst.checkOverlap(rv.RawVector())
  376. }
  377. // TODO(btracey): Should test the condition number instead of testing that
  378. // the determinant is exactly zero.
  379. if lu.Det() == 0 {
  380. return Condition(math.Inf(1))
  381. }
  382. dst.reuseAsNonZeroed(n)
  383. var restore func()
  384. if dst == b {
  385. dst, restore = dst.isolatedWorkspace(b)
  386. defer restore()
  387. }
  388. dst.CopyVec(b)
  389. vMat := blas64.General{
  390. Rows: n,
  391. Cols: 1,
  392. Stride: dst.mat.Inc,
  393. Data: dst.mat.Data,
  394. }
  395. t := blas.NoTrans
  396. if trans {
  397. t = blas.Trans
  398. }
  399. lapack64.Getrs(t, lu.lu.mat, vMat, lu.pivot)
  400. if lu.cond > ConditionTolerance {
  401. return Condition(lu.cond)
  402. }
  403. return nil
  404. }
  405. }