svd.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. // Copyright ©2013 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mat
  5. import (
  6. "gonum.org/v1/gonum/blas/blas64"
  7. "gonum.org/v1/gonum/lapack"
  8. "gonum.org/v1/gonum/lapack/lapack64"
  9. )
  // badRcond is the panic message used when an invalid rcond value is
  // passed to a method of this package.
  const (
  	badRcond = "mat: invalid rcond value"
  )
  11. // SVD is a type for creating and using the Singular Value Decomposition
  12. // of a matrix.
  13. type SVD struct {
  14. kind SVDKind
  15. s []float64
  16. u blas64.General
  17. vt blas64.General
  18. }
  // SVDKind is a set of bit flags selecting which singular vectors are
  // computed during an SVD factorization.
  type SVDKind int

  // SVDNone requests that no singular vectors be computed during the
  // decomposition.
  const SVDNone SVDKind = 0

  // Individual flags selecting the form of U and V. Each flag occupies its
  // own bit so that a U flag and a V flag can be combined with bitwise OR.
  const (
  	// SVDThinU requests the thin decomposition for U.
  	SVDThinU SVDKind = 1 << iota
  	// SVDFullU requests the full decomposition for U.
  	SVDFullU
  	// SVDThinV requests the thin decomposition for V.
  	SVDThinV
  	// SVDFullV requests the full decomposition for V.
  	SVDFullV
  )

  // Convenience combinations of the individual flags.
  const (
  	// SVDThin requests both thin singular vector matrices.
  	SVDThin = SVDThinU | SVDThinV
  	// SVDFull requests both full singular vector matrices.
  	SVDFull = SVDFullU | SVDFullV
  )
  39. // succFact returns whether the receiver contains a successful factorization.
  40. func (svd *SVD) succFact() bool {
  41. return len(svd.s) != 0
  42. }
  43. // Factorize computes the singular value decomposition (SVD) of the input matrix A.
  44. // The singular values of A are computed in all cases, while the singular
  45. // vectors are optionally computed depending on the input kind.
  46. //
  47. // The full singular value decomposition (kind == SVDFull) is a factorization
  48. // of an m×n matrix A of the form
  49. // A = U * Σ * Vᵀ
  50. // where Σ is an m×n diagonal matrix, U is an m×m orthogonal matrix, and V is an
  51. // n×n orthogonal matrix. The diagonal elements of Σ are the singular values of A.
  52. // The first min(m,n) columns of U and V are, respectively, the left and right
  53. // singular vectors of A.
  54. //
  55. // Significant storage space can be saved by using the thin representation of
  56. // the SVD (kind == SVDThin) instead of the full SVD, especially if
  57. // m >> n or m << n. The thin SVD finds
  58. // A = U~ * Σ * V~ᵀ
  59. // where U~ is of size m×min(m,n), Σ is a diagonal matrix of size min(m,n)×min(m,n)
  60. // and V~ is of size n×min(m,n).
  61. //
  62. // Factorize returns whether the decomposition succeeded. If the decomposition
  63. // failed, routines that require a successful factorization will panic.
  64. func (svd *SVD) Factorize(a Matrix, kind SVDKind) (ok bool) {
  65. // kill previous factorization
  66. svd.s = svd.s[:0]
  67. svd.kind = kind
  68. m, n := a.Dims()
  69. var jobU, jobVT lapack.SVDJob
  70. // TODO(btracey): This code should be modified to have the smaller
  71. // matrix written in-place into aCopy when the lapack/native/dgesvd
  72. // implementation is complete.
  73. switch {
  74. case kind&SVDFullU != 0:
  75. jobU = lapack.SVDAll
  76. svd.u = blas64.General{
  77. Rows: m,
  78. Cols: m,
  79. Stride: m,
  80. Data: use(svd.u.Data, m*m),
  81. }
  82. case kind&SVDThinU != 0:
  83. jobU = lapack.SVDStore
  84. svd.u = blas64.General{
  85. Rows: m,
  86. Cols: min(m, n),
  87. Stride: min(m, n),
  88. Data: use(svd.u.Data, m*min(m, n)),
  89. }
  90. default:
  91. jobU = lapack.SVDNone
  92. }
  93. switch {
  94. case kind&SVDFullV != 0:
  95. svd.vt = blas64.General{
  96. Rows: n,
  97. Cols: n,
  98. Stride: n,
  99. Data: use(svd.vt.Data, n*n),
  100. }
  101. jobVT = lapack.SVDAll
  102. case kind&SVDThinV != 0:
  103. svd.vt = blas64.General{
  104. Rows: min(m, n),
  105. Cols: n,
  106. Stride: n,
  107. Data: use(svd.vt.Data, min(m, n)*n),
  108. }
  109. jobVT = lapack.SVDStore
  110. default:
  111. jobVT = lapack.SVDNone
  112. }
  113. // A is destroyed on call, so copy the matrix.
  114. aCopy := DenseCopyOf(a)
  115. svd.kind = kind
  116. svd.s = use(svd.s, min(m, n))
  117. work := []float64{0}
  118. lapack64.Gesvd(jobU, jobVT, aCopy.mat, svd.u, svd.vt, svd.s, work, -1)
  119. work = getFloats(int(work[0]), false)
  120. ok = lapack64.Gesvd(jobU, jobVT, aCopy.mat, svd.u, svd.vt, svd.s, work, len(work))
  121. putFloats(work)
  122. if !ok {
  123. svd.kind = 0
  124. }
  125. return ok
  126. }
  127. // Kind returns the SVDKind of the decomposition. If no decomposition has been
  128. // computed, Kind returns -1.
  129. func (svd *SVD) Kind() SVDKind {
  130. if !svd.succFact() {
  131. return -1
  132. }
  133. return svd.kind
  134. }
  135. // Rank returns the rank of A based on the count of singular values greater than
  136. // rcond scaled by the largest singular value.
  137. // Rank will panic if the receiver does not contain a successful factorization or
  138. // rcond is negative.
  139. func (svd *SVD) Rank(rcond float64) int {
  140. if rcond < 0 {
  141. panic(badRcond)
  142. }
  143. if !svd.succFact() {
  144. panic(badFact)
  145. }
  146. s0 := svd.s[0]
  147. for i, v := range svd.s {
  148. if v <= rcond*s0 {
  149. return i
  150. }
  151. }
  152. return len(svd.s)
  153. }
  154. // Cond returns the 2-norm condition number for the factorized matrix. Cond will
  155. // panic if the receiver does not contain a successful factorization.
  156. func (svd *SVD) Cond() float64 {
  157. if !svd.succFact() {
  158. panic(badFact)
  159. }
  160. return svd.s[0] / svd.s[len(svd.s)-1]
  161. }
  162. // Values returns the singular values of the factorized matrix in descending order.
  163. //
  164. // If the input slice is non-nil, the values will be stored in-place into
  165. // the slice. In this case, the slice must have length min(m,n), and Values will
  166. // panic with ErrSliceLengthMismatch otherwise. If the input slice is nil, a new
  167. // slice of the appropriate length will be allocated and returned.
  168. //
  169. // Values will panic if the receiver does not contain a successful factorization.
  170. func (svd *SVD) Values(s []float64) []float64 {
  171. if !svd.succFact() {
  172. panic(badFact)
  173. }
  174. if s == nil {
  175. s = make([]float64, len(svd.s))
  176. }
  177. if len(s) != len(svd.s) {
  178. panic(ErrSliceLengthMismatch)
  179. }
  180. copy(s, svd.s)
  181. return s
  182. }
  183. // UTo extracts the matrix U from the singular value decomposition. The first
  184. // min(m,n) columns are the left singular vectors and correspond to the singular
  185. // values as returned from SVD.Values.
  186. //
  187. // If dst is empty, UTo will resize dst to be m×m if the full U was computed
  188. // and size m×min(m,n) if the thin U was computed. When dst is non-empty, then
  189. // UTo will panic if dst is not the appropriate size. UTo will also panic if
  190. // the receiver does not contain a successful factorization, or if U was
  191. // not computed during factorization.
  192. func (svd *SVD) UTo(dst *Dense) {
  193. if !svd.succFact() {
  194. panic(badFact)
  195. }
  196. kind := svd.kind
  197. if kind&SVDThinU == 0 && kind&SVDFullU == 0 {
  198. panic("svd: u not computed during factorization")
  199. }
  200. r := svd.u.Rows
  201. c := svd.u.Cols
  202. if dst.IsEmpty() {
  203. dst.ReuseAs(r, c)
  204. } else {
  205. r2, c2 := dst.Dims()
  206. if r != r2 || c != c2 {
  207. panic(ErrShape)
  208. }
  209. }
  210. tmp := &Dense{
  211. mat: svd.u,
  212. capRows: r,
  213. capCols: c,
  214. }
  215. dst.Copy(tmp)
  216. }
  217. // VTo extracts the matrix V from the singular value decomposition. The first
  218. // min(m,n) columns are the right singular vectors and correspond to the singular
  219. // values as returned from SVD.Values.
  220. //
  221. // If dst is empty, VTo will resize dst to be n×n if the full V was computed
  222. // and size n×min(m,n) if the thin V was computed. When dst is non-empty, then
  223. // VTo will panic if dst is not the appropriate size. VTo will also panic if
  224. // the receiver does not contain a successful factorization, or if V was
  225. // not computed during factorization.
  226. func (svd *SVD) VTo(dst *Dense) {
  227. if !svd.succFact() {
  228. panic(badFact)
  229. }
  230. kind := svd.kind
  231. if kind&SVDThinV == 0 && kind&SVDFullV == 0 {
  232. panic("svd: v not computed during factorization")
  233. }
  234. r := svd.vt.Rows
  235. c := svd.vt.Cols
  236. if dst.IsEmpty() {
  237. dst.ReuseAs(c, r)
  238. } else {
  239. r2, c2 := dst.Dims()
  240. if c != r2 || r != c2 {
  241. panic(ErrShape)
  242. }
  243. }
  244. tmp := &Dense{
  245. mat: svd.vt,
  246. capRows: r,
  247. capCols: c,
  248. }
  249. dst.Copy(tmp.T())
  250. }
  251. // SolveTo calculates the minimum-norm solution to a linear least squares problem
  252. // minimize over n-element vectors x: |b - A*x|_2 and |x|_2
  253. // where b is a given m-element vector, using the SVD of m×n matrix A stored in
  254. // the receiver. A may be rank-deficient, that is, the given effective rank can be
  255. // rank ≤ min(m,n)
  256. // The rank can be computed using SVD.Rank.
  257. //
  258. // Several right-hand side vectors b and solution vectors x can be handled in a
  259. // single call. Vectors b are stored in the columns of the m×k matrix B and the
  260. // resulting vectors x will be stored in the columns of dst. dst must be either
  261. // empty or have the size equal to n×k.
  262. //
  263. // The decomposition must have been factorized computing both the U and V
  264. // singular vectors.
  265. //
  266. // SolveTo returns the residuals calculated from the complete SVD. For this
  267. // value to be valid the factorization must have been performed with at least
  268. // SVDFullU.
  269. func (svd *SVD) SolveTo(dst *Dense, b Matrix, rank int) []float64 {
  270. if !svd.succFact() {
  271. panic(badFact)
  272. }
  273. if rank < 1 || len(svd.s) < rank {
  274. panic("svd: rank out of range")
  275. }
  276. kind := svd.kind
  277. if kind&SVDThinU == 0 && kind&SVDFullU == 0 {
  278. panic("svd: u not computed during factorization")
  279. }
  280. if kind&SVDThinV == 0 && kind&SVDFullV == 0 {
  281. panic("svd: v not computed during factorization")
  282. }
  283. u := Dense{
  284. mat: svd.u,
  285. capRows: svd.u.Rows,
  286. capCols: svd.u.Cols,
  287. }
  288. vt := Dense{
  289. mat: svd.vt,
  290. capRows: svd.vt.Rows,
  291. capCols: svd.vt.Cols,
  292. }
  293. s := svd.s[:rank]
  294. _, bc := b.Dims()
  295. c := getWorkspace(svd.u.Cols, bc, false)
  296. defer putWorkspace(c)
  297. c.Mul(u.T(), b)
  298. y := getWorkspace(rank, bc, false)
  299. defer putWorkspace(y)
  300. y.DivElem(c.slice(0, rank, 0, bc), repVector{vec: s, cols: bc})
  301. dst.Mul(vt.slice(0, rank, 0, svd.vt.Cols).T(), y)
  302. res := make([]float64, bc)
  303. if rank < svd.u.Cols {
  304. c = c.slice(len(s), svd.u.Cols, 0, bc)
  305. for j := range res {
  306. col := c.ColView(j)
  307. res[j] = Dot(col, col)
  308. }
  309. }
  310. return res
  311. }
  312. type repVector struct {
  313. vec []float64
  314. cols int
  315. }
  316. func (m repVector) Dims() (r, c int) { return len(m.vec), m.cols }
  317. func (m repVector) At(i, j int) float64 {
  318. if i < 0 || len(m.vec) <= i || j < 0 || m.cols <= j {
  319. panic(ErrIndexOutOfRange.string) // Panic with string to prevent mat.Error recovery.
  320. }
  321. return m.vec[i]
  322. }
  323. func (m repVector) T() Matrix { return Transpose{m} }
  324. // SolveVecTo calculates the minimum-norm solution to a linear least squares problem
  325. // minimize over n-element vectors x: |b - A*x|_2 and |x|_2
  326. // where b is a given m-element vector, using the SVD of m×n matrix A stored in
  327. // the receiver. A may be rank-deficient, that is, the given effective rank can be
  328. // rank ≤ min(m,n)
  329. // The rank can be computed using SVD.Rank.
  330. //
  331. // The resulting vector x will be stored in dst. dst must be either empty or
  332. // have length equal to n.
  333. //
  334. // The decomposition must have been factorized computing both the U and V
  335. // singular vectors.
  336. //
  337. // SolveVecTo returns the residuals calculated from the complete SVD. For this
  338. // value to be valid the factorization must have been performed with at least
  339. // SVDFullU.
  340. func (svd *SVD) SolveVecTo(dst *VecDense, b Vector, rank int) float64 {
  341. if !svd.succFact() {
  342. panic(badFact)
  343. }
  344. if rank < 1 || len(svd.s) < rank {
  345. panic("svd: rank out of range")
  346. }
  347. kind := svd.kind
  348. if kind&SVDThinU == 0 && kind&SVDFullU == 0 {
  349. panic("svd: u not computed during factorization")
  350. }
  351. if kind&SVDThinV == 0 && kind&SVDFullV == 0 {
  352. panic("svd: v not computed during factorization")
  353. }
  354. u := Dense{
  355. mat: svd.u,
  356. capRows: svd.u.Rows,
  357. capCols: svd.u.Cols,
  358. }
  359. vt := Dense{
  360. mat: svd.vt,
  361. capRows: svd.vt.Rows,
  362. capCols: svd.vt.Cols,
  363. }
  364. s := svd.s[:rank]
  365. c := getWorkspaceVec(svd.u.Cols, false)
  366. defer putWorkspaceVec(c)
  367. c.MulVec(u.T(), b)
  368. y := getWorkspaceVec(rank, false)
  369. defer putWorkspaceVec(y)
  370. y.DivElemVec(c.sliceVec(0, rank), NewVecDense(rank, s))
  371. dst.MulVec(vt.slice(0, rank, 0, svd.vt.Cols).T(), y)
  372. var res float64
  373. if rank < c.Len() {
  374. c = c.sliceVec(rank, c.Len())
  375. res = Dot(c, c)
  376. }
  377. return res
  378. }