solve.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. // Copyright ©2015 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mat
  5. import (
  6. "gonum.org/v1/gonum/blas"
  7. "gonum.org/v1/gonum/blas/blas64"
  8. "gonum.org/v1/gonum/lapack/lapack64"
  9. )
  10. // Solve solves the linear least squares problem
  11. // minimize over x |b - A*x|_2
  12. // where A is an m×n matrix A, b is a given m element vector and x is n element
  13. // solution vector. Solve assumes that A has full rank, that is
  14. // rank(A) = min(m,n)
  15. //
  16. // If m >= n, Solve finds the unique least squares solution of an overdetermined
  17. // system.
  18. //
  19. // If m < n, there is an infinite number of solutions that satisfy b-A*x=0. In
  20. // this case Solve finds the unique solution of an underdetermined system that
  21. // minimizes |x|_2.
  22. //
  23. // Several right-hand side vectors b and solution vectors x can be handled in a
  24. // single call. Vectors b are stored in the columns of the m×k matrix B. Vectors
  25. // x will be stored in-place into the n×k receiver.
  26. //
  27. // If A does not have full rank, a Condition error is returned. See the
  28. // documentation for Condition for more information.
  29. func (m *Dense) Solve(a, b Matrix) error {
  30. ar, ac := a.Dims()
  31. br, bc := b.Dims()
  32. if ar != br {
  33. panic(ErrShape)
  34. }
  35. m.reuseAsNonZeroed(ac, bc)
  36. // TODO(btracey): Add special cases for SymDense, etc.
  37. aU, aTrans := untranspose(a)
  38. bU, bTrans := untranspose(b)
  39. switch rma := aU.(type) {
  40. case RawTriangular:
  41. side := blas.Left
  42. tA := blas.NoTrans
  43. if aTrans {
  44. tA = blas.Trans
  45. }
  46. switch rm := bU.(type) {
  47. case RawMatrixer:
  48. if m != bU || bTrans {
  49. if m == bU || m.checkOverlap(rm.RawMatrix()) {
  50. tmp := getWorkspace(br, bc, false)
  51. tmp.Copy(b)
  52. m.Copy(tmp)
  53. putWorkspace(tmp)
  54. break
  55. }
  56. m.Copy(b)
  57. }
  58. default:
  59. if m != bU {
  60. m.Copy(b)
  61. } else if bTrans {
  62. // m and b share data so Copy cannot be used directly.
  63. tmp := getWorkspace(br, bc, false)
  64. tmp.Copy(b)
  65. m.Copy(tmp)
  66. putWorkspace(tmp)
  67. }
  68. }
  69. rm := rma.RawTriangular()
  70. blas64.Trsm(side, tA, 1, rm, m.mat)
  71. work := getFloats(3*rm.N, false)
  72. iwork := getInts(rm.N, false)
  73. cond := lapack64.Trcon(CondNorm, rm, work, iwork)
  74. putFloats(work)
  75. putInts(iwork)
  76. if cond > ConditionTolerance {
  77. return Condition(cond)
  78. }
  79. return nil
  80. }
  81. switch {
  82. case ar == ac:
  83. if a == b {
  84. // x = I.
  85. if ar == 1 {
  86. m.mat.Data[0] = 1
  87. return nil
  88. }
  89. for i := 0; i < ar; i++ {
  90. v := m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+ac]
  91. zero(v)
  92. v[i] = 1
  93. }
  94. return nil
  95. }
  96. var lu LU
  97. lu.Factorize(a)
  98. return lu.SolveTo(m, false, b)
  99. case ar > ac:
  100. var qr QR
  101. qr.Factorize(a)
  102. return qr.SolveTo(m, false, b)
  103. default:
  104. var lq LQ
  105. lq.Factorize(a)
  106. return lq.SolveTo(m, false, b)
  107. }
  108. }
  109. // SolveVec solves the linear least squares problem
  110. // minimize over x |b - A*x|_2
  111. // where A is an m×n matrix A, b is a given m element vector and x is n element
  112. // solution vector. Solve assumes that A has full rank, that is
  113. // rank(A) = min(m,n)
  114. //
  115. // If m >= n, Solve finds the unique least squares solution of an overdetermined
  116. // system.
  117. //
  118. // If m < n, there is an infinite number of solutions that satisfy b-A*x=0. In
  119. // this case Solve finds the unique solution of an underdetermined system that
  120. // minimizes |x|_2.
  121. //
  122. // The solution vector x will be stored in-place into the receiver.
  123. //
  124. // If A does not have full rank, a Condition error is returned. See the
  125. // documentation for Condition for more information.
  126. func (v *VecDense) SolveVec(a Matrix, b Vector) error {
  127. if _, bc := b.Dims(); bc != 1 {
  128. panic(ErrShape)
  129. }
  130. _, c := a.Dims()
  131. // The Solve implementation is non-trivial, so rather than duplicate the code,
  132. // instead recast the VecDenses as Dense and call the matrix code.
  133. if rv, ok := b.(RawVectorer); ok {
  134. bmat := rv.RawVector()
  135. if v != b {
  136. v.checkOverlap(bmat)
  137. }
  138. v.reuseAsNonZeroed(c)
  139. m := v.asDense()
  140. // We conditionally create bm as m when b and v are identical
  141. // to prevent the overlap detection code from identifying m
  142. // and bm as overlapping but not identical.
  143. bm := m
  144. if v != b {
  145. b := VecDense{mat: bmat}
  146. bm = b.asDense()
  147. }
  148. return m.Solve(a, bm)
  149. }
  150. v.reuseAsNonZeroed(c)
  151. m := v.asDense()
  152. return m.Solve(a, b)
  153. }