lq.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. // Copyright ©2013 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mat
  5. import (
  6. "math"
  7. "gonum.org/v1/gonum/blas"
  8. "gonum.org/v1/gonum/blas/blas64"
  9. "gonum.org/v1/gonum/lapack"
  10. "gonum.org/v1/gonum/lapack/lapack64"
  11. )
  // badLQ is the panic message used by LQ methods when the receiver does not
// hold a factorization.
const (
	badLQ = "mat: invalid LQ factorization"
)
  13. // LQ is a type for creating and using the LQ factorization of a matrix.
  14. type LQ struct {
  15. lq *Dense
  16. tau []float64
  17. cond float64
  18. }
  19. func (lq *LQ) updateCond(norm lapack.MatrixNorm) {
  20. // Since A = L*Q, and Q is orthogonal, we get for the condition number κ
  21. // κ(A) := |A| |A^-1| = |L*Q| |(L*Q)^-1| = |L| |Qᵀ * L^-1|
  22. // = |L| |L^-1| = κ(L),
  23. // where we used that fact that Q^-1 = Qᵀ. However, this assumes that
  24. // the matrix norm is invariant under orthogonal transformations which
  25. // is not the case for CondNorm. Hopefully the error is negligible: κ
  26. // is only a qualitative measure anyway.
  27. m := lq.lq.mat.Rows
  28. work := getFloats(3*m, false)
  29. iwork := getInts(m, false)
  30. l := lq.lq.asTriDense(m, blas.NonUnit, blas.Lower)
  31. v := lapack64.Trcon(norm, l.mat, work, iwork)
  32. lq.cond = 1 / v
  33. putFloats(work)
  34. putInts(iwork)
  35. }
  36. // Factorize computes the LQ factorization of an m×n matrix a where m <= n. The LQ
  37. // factorization always exists even if A is singular.
  38. //
  39. // The LQ decomposition is a factorization of the matrix A such that A = L * Q.
  40. // The matrix Q is an orthonormal n×n matrix, and L is an m×n lower triangular matrix.
  41. // L and Q can be extracted using the LTo and QTo methods.
  42. func (lq *LQ) Factorize(a Matrix) {
  43. lq.factorize(a, CondNorm)
  44. }
  45. func (lq *LQ) factorize(a Matrix, norm lapack.MatrixNorm) {
  46. m, n := a.Dims()
  47. if m > n {
  48. panic(ErrShape)
  49. }
  50. k := min(m, n)
  51. if lq.lq == nil {
  52. lq.lq = &Dense{}
  53. }
  54. lq.lq.CloneFrom(a)
  55. work := []float64{0}
  56. lq.tau = make([]float64, k)
  57. lapack64.Gelqf(lq.lq.mat, lq.tau, work, -1)
  58. work = getFloats(int(work[0]), false)
  59. lapack64.Gelqf(lq.lq.mat, lq.tau, work, len(work))
  60. putFloats(work)
  61. lq.updateCond(norm)
  62. }
  63. // isValid returns whether the receiver contains a factorization.
  64. func (lq *LQ) isValid() bool {
  65. return lq.lq != nil && !lq.lq.IsEmpty()
  66. }
  67. // Cond returns the condition number for the factorized matrix.
  68. // Cond will panic if the receiver does not contain a factorization.
  69. func (lq *LQ) Cond() float64 {
  70. if !lq.isValid() {
  71. panic(badLQ)
  72. }
  73. return lq.cond
  74. }
  // TODO(btracey): Add in the "Reduced" forms for extracting the m×m lower
// triangular matrix L and the m×n matrix Q with orthonormal rows.
  77. // LTo extracts the m×n lower trapezoidal matrix from a LQ decomposition.
  78. //
  79. // If dst is empty, LTo will resize dst to be r×c. When dst is
  80. // non-empty, LTo will panic if dst is not r×c. LTo will also panic
  81. // if the receiver does not contain a successful factorization.
  82. func (lq *LQ) LTo(dst *Dense) {
  83. if !lq.isValid() {
  84. panic(badLQ)
  85. }
  86. r, c := lq.lq.Dims()
  87. if dst.IsEmpty() {
  88. dst.ReuseAs(r, c)
  89. } else {
  90. r2, c2 := dst.Dims()
  91. if r != r2 || c != c2 {
  92. panic(ErrShape)
  93. }
  94. }
  95. // Disguise the LQ as a lower triangular.
  96. t := &TriDense{
  97. mat: blas64.Triangular{
  98. N: r,
  99. Stride: lq.lq.mat.Stride,
  100. Data: lq.lq.mat.Data,
  101. Uplo: blas.Lower,
  102. Diag: blas.NonUnit,
  103. },
  104. cap: lq.lq.capCols,
  105. }
  106. dst.Copy(t)
  107. if r == c {
  108. return
  109. }
  110. // Zero right of the triangular.
  111. for i := 0; i < r; i++ {
  112. zero(dst.mat.Data[i*dst.mat.Stride+r : i*dst.mat.Stride+c])
  113. }
  114. }
  115. // QTo extracts the n×n orthonormal matrix Q from an LQ decomposition.
  116. //
  117. // If dst is empty, QTo will resize dst to be c×c. When dst is
  118. // non-empty, QTo will panic if dst is not c×c. QTo will also panic
  119. // if the receiver does not contain a successful factorization.
  120. func (lq *LQ) QTo(dst *Dense) {
  121. if !lq.isValid() {
  122. panic(badLQ)
  123. }
  124. _, c := lq.lq.Dims()
  125. if dst.IsEmpty() {
  126. dst.ReuseAs(c, c)
  127. } else {
  128. r2, c2 := dst.Dims()
  129. if c != r2 || c != c2 {
  130. panic(ErrShape)
  131. }
  132. dst.Zero()
  133. }
  134. q := dst.mat
  135. // Set Q = I.
  136. ldq := q.Stride
  137. for i := 0; i < c; i++ {
  138. q.Data[i*ldq+i] = 1
  139. }
  140. // Construct Q from the elementary reflectors.
  141. work := []float64{0}
  142. lapack64.Ormlq(blas.Left, blas.NoTrans, lq.lq.mat, lq.tau, q, work, -1)
  143. work = getFloats(int(work[0]), false)
  144. lapack64.Ormlq(blas.Left, blas.NoTrans, lq.lq.mat, lq.tau, q, work, len(work))
  145. putFloats(work)
  146. }
  147. // SolveTo finds a minimum-norm solution to a system of linear equations defined
  148. // by the matrices A and b, where A is an m×n matrix represented in its LQ factorized
  149. // form. If A is singular or near-singular a Condition error is returned.
  150. // See the documentation for Condition for more information.
  151. //
  152. // The minimization problem solved depends on the input parameters.
  153. // If trans == false, find the minimum norm solution of A * X = B.
  154. // If trans == true, find X such that ||A*X - B||_2 is minimized.
  155. // The solution matrix, X, is stored in place into dst.
  156. // SolveTo will panic if the receiver does not contain a factorization.
  157. func (lq *LQ) SolveTo(dst *Dense, trans bool, b Matrix) error {
  158. if !lq.isValid() {
  159. panic(badLQ)
  160. }
  161. r, c := lq.lq.Dims()
  162. br, bc := b.Dims()
  163. // The LQ solve algorithm stores the result in-place into the right hand side.
  164. // The storage for the answer must be large enough to hold both b and x.
  165. // However, this method's receiver must be the size of x. Copy b, and then
  166. // copy the result into x at the end.
  167. if trans {
  168. if c != br {
  169. panic(ErrShape)
  170. }
  171. dst.reuseAsNonZeroed(r, bc)
  172. } else {
  173. if r != br {
  174. panic(ErrShape)
  175. }
  176. dst.reuseAsNonZeroed(c, bc)
  177. }
  178. // Do not need to worry about overlap between x and b because w has its own
  179. // independent storage.
  180. w := getWorkspace(max(r, c), bc, false)
  181. w.Copy(b)
  182. t := lq.lq.asTriDense(lq.lq.mat.Rows, blas.NonUnit, blas.Lower).mat
  183. if trans {
  184. work := []float64{0}
  185. lapack64.Ormlq(blas.Left, blas.NoTrans, lq.lq.mat, lq.tau, w.mat, work, -1)
  186. work = getFloats(int(work[0]), false)
  187. lapack64.Ormlq(blas.Left, blas.NoTrans, lq.lq.mat, lq.tau, w.mat, work, len(work))
  188. putFloats(work)
  189. ok := lapack64.Trtrs(blas.Trans, t, w.mat)
  190. if !ok {
  191. return Condition(math.Inf(1))
  192. }
  193. } else {
  194. ok := lapack64.Trtrs(blas.NoTrans, t, w.mat)
  195. if !ok {
  196. return Condition(math.Inf(1))
  197. }
  198. for i := r; i < c; i++ {
  199. zero(w.mat.Data[i*w.mat.Stride : i*w.mat.Stride+bc])
  200. }
  201. work := []float64{0}
  202. lapack64.Ormlq(blas.Left, blas.Trans, lq.lq.mat, lq.tau, w.mat, work, -1)
  203. work = getFloats(int(work[0]), false)
  204. lapack64.Ormlq(blas.Left, blas.Trans, lq.lq.mat, lq.tau, w.mat, work, len(work))
  205. putFloats(work)
  206. }
  207. // x was set above to be the correct size for the result.
  208. dst.Copy(w)
  209. putWorkspace(w)
  210. if lq.cond > ConditionTolerance {
  211. return Condition(lq.cond)
  212. }
  213. return nil
  214. }
  215. // SolveVecTo finds a minimum-norm solution to a system of linear equations.
  216. // See LQ.SolveTo for the full documentation.
  217. // SolveToVec will panic if the receiver does not contain a factorization.
  218. func (lq *LQ) SolveVecTo(dst *VecDense, trans bool, b Vector) error {
  219. if !lq.isValid() {
  220. panic(badLQ)
  221. }
  222. r, c := lq.lq.Dims()
  223. if _, bc := b.Dims(); bc != 1 {
  224. panic(ErrShape)
  225. }
  226. // The Solve implementation is non-trivial, so rather than duplicate the code,
  227. // instead recast the VecDenses as Dense and call the matrix code.
  228. bm := Matrix(b)
  229. if rv, ok := b.(RawVectorer); ok {
  230. bmat := rv.RawVector()
  231. if dst != b {
  232. dst.checkOverlap(bmat)
  233. }
  234. b := VecDense{mat: bmat}
  235. bm = b.asDense()
  236. }
  237. if trans {
  238. dst.reuseAsNonZeroed(r)
  239. } else {
  240. dst.reuseAsNonZeroed(c)
  241. }
  242. return lq.SolveTo(dst.asDense(), trans, bm)
  243. }