qr.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. // Copyright ©2013 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mat
  5. import (
  6. "math"
  7. "gonum.org/v1/gonum/blas"
  8. "gonum.org/v1/gonum/blas/blas64"
  9. "gonum.org/v1/gonum/lapack"
  10. "gonum.org/v1/gonum/lapack/lapack64"
  11. )
  // badQR is the panic value used by QR methods that are called on a receiver
  // which does not hold a factorization.
  const (
  	badQR = "mat: invalid QR factorization"
  )
  // QR is a type for creating and using the QR factorization of a matrix.
  //
  // The zero value holds no factorization: call Factorize first, otherwise
  // methods such as Cond, RTo, QTo and SolveTo panic.
  type QR struct {
  	// qr holds the matrix factorized in place by Geqrf. Its leading upper
  	// triangle is R; the remaining entries presumably hold the elementary
  	// reflectors in LAPACK's compact form (used together with tau to build Q).
  	qr *Dense
  	// tau holds the scalar factors of the elementary reflectors from Geqrf.
  	tau []float64
  	// cond is the condition number estimate computed by updateCond.
  	cond float64
  }
  19. func (qr *QR) updateCond(norm lapack.MatrixNorm) {
  20. // Since A = Q*R, and Q is orthogonal, we get for the condition number κ
  21. // κ(A) := |A| |A^-1| = |Q*R| |(Q*R)^-1| = |R| |R^-1 * Qᵀ|
  22. // = |R| |R^-1| = κ(R),
  23. // where we used that fact that Q^-1 = Qᵀ. However, this assumes that
  24. // the matrix norm is invariant under orthogonal transformations which
  25. // is not the case for CondNorm. Hopefully the error is negligible: κ
  26. // is only a qualitative measure anyway.
  27. n := qr.qr.mat.Cols
  28. work := getFloats(3*n, false)
  29. iwork := getInts(n, false)
  30. r := qr.qr.asTriDense(n, blas.NonUnit, blas.Upper)
  31. v := lapack64.Trcon(norm, r.mat, work, iwork)
  32. putFloats(work)
  33. putInts(iwork)
  34. qr.cond = 1 / v
  35. }
  36. // Factorize computes the QR factorization of an m×n matrix a where m >= n. The QR
  37. // factorization always exists even if A is singular.
  38. //
  39. // The QR decomposition is a factorization of the matrix A such that A = Q * R.
  40. // The matrix Q is an orthonormal m×m matrix, and R is an m×n upper triangular matrix.
  41. // Q and R can be extracted using the QTo and RTo methods.
  42. func (qr *QR) Factorize(a Matrix) {
  43. qr.factorize(a, CondNorm)
  44. }
  45. func (qr *QR) factorize(a Matrix, norm lapack.MatrixNorm) {
  46. m, n := a.Dims()
  47. if m < n {
  48. panic(ErrShape)
  49. }
  50. k := min(m, n)
  51. if qr.qr == nil {
  52. qr.qr = &Dense{}
  53. }
  54. qr.qr.CloneFrom(a)
  55. work := []float64{0}
  56. qr.tau = make([]float64, k)
  57. lapack64.Geqrf(qr.qr.mat, qr.tau, work, -1)
  58. work = getFloats(int(work[0]), false)
  59. lapack64.Geqrf(qr.qr.mat, qr.tau, work, len(work))
  60. putFloats(work)
  61. qr.updateCond(norm)
  62. }
  63. // isValid returns whether the receiver contains a factorization.
  64. func (qr *QR) isValid() bool {
  65. return qr.qr != nil && !qr.qr.IsEmpty()
  66. }
  67. // Cond returns the condition number for the factorized matrix.
  68. // Cond will panic if the receiver does not contain a factorization.
  69. func (qr *QR) Cond() float64 {
  70. if !qr.isValid() {
  71. panic(badQR)
  72. }
  73. return qr.cond
  74. }
  75. // TODO(btracey): Add in the "Reduced" forms for extracting the n×n orthogonal
  76. // and upper triangular matrices.
  77. // RTo extracts the m×n upper trapezoidal matrix from a QR decomposition.
  78. //
  79. // If dst is empty, RTo will resize dst to be r×c. When dst is non-empty,
  80. // RTo will panic if dst is not r×c. RTo will also panic if the receiver
  81. // does not contain a successful factorization.
  82. func (qr *QR) RTo(dst *Dense) {
  83. if !qr.isValid() {
  84. panic(badQR)
  85. }
  86. r, c := qr.qr.Dims()
  87. if dst.IsEmpty() {
  88. dst.ReuseAs(r, c)
  89. } else {
  90. r2, c2 := dst.Dims()
  91. if c != r2 || c != c2 {
  92. panic(ErrShape)
  93. }
  94. }
  95. // Disguise the QR as an upper triangular
  96. t := &TriDense{
  97. mat: blas64.Triangular{
  98. N: c,
  99. Stride: qr.qr.mat.Stride,
  100. Data: qr.qr.mat.Data,
  101. Uplo: blas.Upper,
  102. Diag: blas.NonUnit,
  103. },
  104. cap: qr.qr.capCols,
  105. }
  106. dst.Copy(t)
  107. // Zero below the triangular.
  108. for i := r; i < c; i++ {
  109. zero(dst.mat.Data[i*dst.mat.Stride : i*dst.mat.Stride+c])
  110. }
  111. }
  112. // QTo extracts the r×r orthonormal matrix Q from a QR decomposition.
  113. //
  114. // If dst is empty, QTo will resize dst to be r×r. When dst is non-empty,
  115. // QTo will panic if dst is not r×r. QTo will also panic if the receiver
  116. // does not contain a successful factorization.
  117. func (qr *QR) QTo(dst *Dense) {
  118. if !qr.isValid() {
  119. panic(badQR)
  120. }
  121. r, _ := qr.qr.Dims()
  122. if dst.IsEmpty() {
  123. dst.ReuseAs(r, r)
  124. } else {
  125. r2, c2 := dst.Dims()
  126. if r != r2 || r != c2 {
  127. panic(ErrShape)
  128. }
  129. dst.Zero()
  130. }
  131. // Set Q = I.
  132. for i := 0; i < r*r; i += r + 1 {
  133. dst.mat.Data[i] = 1
  134. }
  135. // Construct Q from the elementary reflectors.
  136. work := []float64{0}
  137. lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, dst.mat, work, -1)
  138. work = getFloats(int(work[0]), false)
  139. lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, dst.mat, work, len(work))
  140. putFloats(work)
  141. }
  142. // SolveTo finds a minimum-norm solution to a system of linear equations defined
  143. // by the matrices A and b, where A is an m×n matrix represented in its QR factorized
  144. // form. If A is singular or near-singular a Condition error is returned.
  145. // See the documentation for Condition for more information.
  146. //
  147. // The minimization problem solved depends on the input parameters.
  148. // If trans == false, find X such that ||A*X - B||_2 is minimized.
  149. // If trans == true, find the minimum norm solution of Aᵀ * X = B.
  150. // The solution matrix, X, is stored in place into dst.
  151. // SolveTo will panic if the receiver does not contain a factorization.
  152. func (qr *QR) SolveTo(dst *Dense, trans bool, b Matrix) error {
  153. if !qr.isValid() {
  154. panic(badQR)
  155. }
  156. r, c := qr.qr.Dims()
  157. br, bc := b.Dims()
  158. // The QR solve algorithm stores the result in-place into the right hand side.
  159. // The storage for the answer must be large enough to hold both b and x.
  160. // However, this method's receiver must be the size of x. Copy b, and then
  161. // copy the result into m at the end.
  162. if trans {
  163. if c != br {
  164. panic(ErrShape)
  165. }
  166. dst.reuseAsNonZeroed(r, bc)
  167. } else {
  168. if r != br {
  169. panic(ErrShape)
  170. }
  171. dst.reuseAsNonZeroed(c, bc)
  172. }
  173. // Do not need to worry about overlap between m and b because x has its own
  174. // independent storage.
  175. w := getWorkspace(max(r, c), bc, false)
  176. w.Copy(b)
  177. t := qr.qr.asTriDense(qr.qr.mat.Cols, blas.NonUnit, blas.Upper).mat
  178. if trans {
  179. ok := lapack64.Trtrs(blas.Trans, t, w.mat)
  180. if !ok {
  181. return Condition(math.Inf(1))
  182. }
  183. for i := c; i < r; i++ {
  184. zero(w.mat.Data[i*w.mat.Stride : i*w.mat.Stride+bc])
  185. }
  186. work := []float64{0}
  187. lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, w.mat, work, -1)
  188. work = getFloats(int(work[0]), false)
  189. lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, w.mat, work, len(work))
  190. putFloats(work)
  191. } else {
  192. work := []float64{0}
  193. lapack64.Ormqr(blas.Left, blas.Trans, qr.qr.mat, qr.tau, w.mat, work, -1)
  194. work = getFloats(int(work[0]), false)
  195. lapack64.Ormqr(blas.Left, blas.Trans, qr.qr.mat, qr.tau, w.mat, work, len(work))
  196. putFloats(work)
  197. ok := lapack64.Trtrs(blas.NoTrans, t, w.mat)
  198. if !ok {
  199. return Condition(math.Inf(1))
  200. }
  201. }
  202. // X was set above to be the correct size for the result.
  203. dst.Copy(w)
  204. putWorkspace(w)
  205. if qr.cond > ConditionTolerance {
  206. return Condition(qr.cond)
  207. }
  208. return nil
  209. }
  210. // SolveVecTo finds a minimum-norm solution to a system of linear equations,
  211. // Ax = b.
  212. // See QR.SolveTo for the full documentation.
  213. // SolveVecTo will panic if the receiver does not contain a factorization.
  214. func (qr *QR) SolveVecTo(dst *VecDense, trans bool, b Vector) error {
  215. if !qr.isValid() {
  216. panic(badQR)
  217. }
  218. r, c := qr.qr.Dims()
  219. if _, bc := b.Dims(); bc != 1 {
  220. panic(ErrShape)
  221. }
  222. // The Solve implementation is non-trivial, so rather than duplicate the code,
  223. // instead recast the VecDenses as Dense and call the matrix code.
  224. bm := Matrix(b)
  225. if rv, ok := b.(RawVectorer); ok {
  226. bmat := rv.RawVector()
  227. if dst != b {
  228. dst.checkOverlap(bmat)
  229. }
  230. b := VecDense{mat: bmat}
  231. bm = b.asDense()
  232. }
  233. if trans {
  234. dst.reuseAsNonZeroed(r)
  235. } else {
  236. dst.reuseAsNonZeroed(c)
  237. }
  238. return qr.SolveTo(dst.asDense(), trans, bm)
  239. }