matrix.go 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008
  1. // Copyright ©2013 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mat
  5. import (
  6. "math"
  7. "gonum.org/v1/gonum/blas"
  8. "gonum.org/v1/gonum/blas/blas64"
  9. "gonum.org/v1/gonum/floats/scalar"
  10. "gonum.org/v1/gonum/lapack"
  11. "gonum.org/v1/gonum/lapack/lapack64"
  12. )
  13. // Matrix is the basic matrix interface type.
  14. type Matrix interface {
  15. // Dims returns the dimensions of a Matrix.
  16. Dims() (r, c int)
  17. // At returns the value of a matrix element at row i, column j.
  18. // It will panic if i or j are out of bounds for the matrix.
  19. At(i, j int) float64
  20. // T returns the transpose of the Matrix. Whether T returns a copy of the
  21. // underlying data is implementation dependent.
  22. // This method may be implemented using the Transpose type, which
  23. // provides an implicit matrix transpose.
  24. T() Matrix
  25. }
  26. // allMatrix represents the extra set of methods that all mat Matrix types
  27. // should satisfy. This is used to enforce compile-time consistency between the
  28. // Dense types, especially helpful when adding new features.
  29. type allMatrix interface {
  30. Reseter
  31. IsEmpty() bool
  32. Zero()
  33. }
  34. // denseMatrix represents the extra set of methods that all Dense Matrix types
  35. // should satisfy. This is used to enforce compile-time consistency between the
  36. // Dense types, especially helpful when adding new features.
  37. type denseMatrix interface {
  38. DiagView() Diagonal
  39. Tracer
  40. }
  41. var (
  42. _ Matrix = Transpose{}
  43. _ Untransposer = Transpose{}
  44. )
  45. // Transpose is a type for performing an implicit matrix transpose. It implements
  46. // the Matrix interface, returning values from the transpose of the matrix within.
  47. type Transpose struct {
  48. Matrix Matrix
  49. }
  50. // At returns the value of the element at row i and column j of the transposed
  51. // matrix, that is, row j and column i of the Matrix field.
  52. func (t Transpose) At(i, j int) float64 {
  53. return t.Matrix.At(j, i)
  54. }
  55. // Dims returns the dimensions of the transposed matrix. The number of rows returned
  56. // is the number of columns in the Matrix field, and the number of columns is
  57. // the number of rows in the Matrix field.
  58. func (t Transpose) Dims() (r, c int) {
  59. c, r = t.Matrix.Dims()
  60. return r, c
  61. }
  62. // T performs an implicit transpose by returning the Matrix field.
  63. func (t Transpose) T() Matrix {
  64. return t.Matrix
  65. }
  66. // Untranspose returns the Matrix field.
  67. func (t Transpose) Untranspose() Matrix {
  68. return t.Matrix
  69. }
  70. // Untransposer is a type that can undo an implicit transpose.
  71. type Untransposer interface {
  72. // Note: This interface is needed to unify all of the Transpose types. In
  73. // the mat methods, we need to test if the Matrix has been implicitly
  74. // transposed. If this is checked by testing for the specific Transpose type
  75. // then the behavior will be different if the user uses T() or TTri() for a
  76. // triangular matrix.
  77. // Untranspose returns the underlying Matrix stored for the implicit transpose.
  78. Untranspose() Matrix
  79. }
  80. // UntransposeBander is a type that can undo an implicit band transpose.
  81. type UntransposeBander interface {
  82. // Untranspose returns the underlying Banded stored for the implicit transpose.
  83. UntransposeBand() Banded
  84. }
  85. // UntransposeTrier is a type that can undo an implicit triangular transpose.
  86. type UntransposeTrier interface {
  87. // Untranspose returns the underlying Triangular stored for the implicit transpose.
  88. UntransposeTri() Triangular
  89. }
  90. // UntransposeTriBander is a type that can undo an implicit triangular banded
  91. // transpose.
  92. type UntransposeTriBander interface {
  93. // Untranspose returns the underlying Triangular stored for the implicit transpose.
  94. UntransposeTriBand() TriBanded
  95. }
  96. // Mutable is a matrix interface type that allows elements to be altered.
  97. type Mutable interface {
  98. // Set alters the matrix element at row i, column j to v.
  99. // It will panic if i or j are out of bounds for the matrix.
  100. Set(i, j int, v float64)
  101. Matrix
  102. }
  103. // A RowViewer can return a Vector reflecting a row that is backed by the matrix
  104. // data. The Vector returned will have length equal to the number of columns.
  105. type RowViewer interface {
  106. RowView(i int) Vector
  107. }
  108. // A RawRowViewer can return a slice of float64 reflecting a row that is backed by the matrix
  109. // data.
  110. type RawRowViewer interface {
  111. RawRowView(i int) []float64
  112. }
  113. // A ColViewer can return a Vector reflecting a column that is backed by the matrix
  114. // data. The Vector returned will have length equal to the number of rows.
  115. type ColViewer interface {
  116. ColView(j int) Vector
  117. }
  118. // A RawColViewer can return a slice of float64 reflecting a column that is backed by the matrix
  119. // data.
  120. type RawColViewer interface {
  121. RawColView(j int) []float64
  122. }
  123. // A ClonerFrom can make a copy of a into the receiver, overwriting the previous value of the
  124. // receiver. The clone operation does not make any restriction on shape and will not cause
  125. // shadowing.
  126. type ClonerFrom interface {
  127. CloneFrom(a Matrix)
  128. }
  129. // A Reseter can reset the matrix so that it can be reused as the receiver of a dimensionally
  130. // restricted operation. This is commonly used when the matrix is being used as a workspace
  131. // or temporary matrix.
  132. //
  133. // If the matrix is a view, using Reset may result in data corruption in elements outside
  134. // the view. Similarly, if the matrix shares backing data with another variable, using
  135. // Reset may lead to unexpected changes in data values.
  136. type Reseter interface {
  137. Reset()
  138. }
  139. // A Copier can make a copy of elements of a into the receiver. The submatrix copied
  140. // starts at row and column 0 and has dimensions equal to the minimum dimensions of
  141. // the two matrices. The number of row and columns copied is returned.
  142. // Copy will copy from a source that aliases the receiver unless the source is transposed;
  143. // an aliasing transpose copy will panic with the exception for a special case when
  144. // the source data has a unitary increment or stride.
  145. type Copier interface {
  146. Copy(a Matrix) (r, c int)
  147. }
  148. // A Grower can grow the size of the represented matrix by the given number of rows and columns.
  149. // Growing beyond the size given by the Caps method will result in the allocation of a new
  150. // matrix and copying of the elements. If Grow is called with negative increments it will
  151. // panic with ErrIndexOutOfRange.
  152. type Grower interface {
  153. Caps() (r, c int)
  154. Grow(r, c int) Matrix
  155. }
  156. // A BandWidther represents a banded matrix and can return the left and right half-bandwidths, k1 and
  157. // k2.
  158. type BandWidther interface {
  159. BandWidth() (k1, k2 int)
  160. }
  161. // A RawMatrixSetter can set the underlying blas64.General used by the receiver. There is no restriction
  162. // on the shape of the receiver. Changes to the receiver's elements will be reflected in the blas64.General.Data.
  163. type RawMatrixSetter interface {
  164. SetRawMatrix(a blas64.General)
  165. }
  166. // A RawMatrixer can return a blas64.General representation of the receiver. Changes to the blas64.General.Data
  167. // slice will be reflected in the original matrix, changes to the Rows, Cols and Stride fields will not.
  168. type RawMatrixer interface {
  169. RawMatrix() blas64.General
  170. }
  171. // A RawVectorer can return a blas64.Vector representation of the receiver. Changes to the blas64.Vector.Data
  172. // slice will be reflected in the original matrix, changes to the Inc field will not.
  173. type RawVectorer interface {
  174. RawVector() blas64.Vector
  175. }
  176. // A NonZeroDoer can call a function for each non-zero element of the receiver.
  177. // The parameters of the function are the element indices and its value.
  178. type NonZeroDoer interface {
  179. DoNonZero(func(i, j int, v float64))
  180. }
  181. // A RowNonZeroDoer can call a function for each non-zero element of a row of the receiver.
  182. // The parameters of the function are the element indices and its value.
  183. type RowNonZeroDoer interface {
  184. DoRowNonZero(i int, fn func(i, j int, v float64))
  185. }
  186. // A ColNonZeroDoer can call a function for each non-zero element of a column of the receiver.
  187. // The parameters of the function are the element indices and its value.
  188. type ColNonZeroDoer interface {
  189. DoColNonZero(j int, fn func(i, j int, v float64))
  190. }
  191. // untranspose untransposes a matrix if applicable. If a is an Untransposer, then
  192. // untranspose returns the underlying matrix and true. If it is not, then it returns
  193. // the input matrix and false.
  194. func untranspose(a Matrix) (Matrix, bool) {
  195. if ut, ok := a.(Untransposer); ok {
  196. return ut.Untranspose(), true
  197. }
  198. return a, false
  199. }
  200. // untransposeExtract returns an untransposed matrix in a built-in matrix type.
  201. //
  202. // The untransposed matrix is returned unaltered if it is a built-in matrix type.
  203. // Otherwise, if it implements a Raw method, an appropriate built-in type value
  204. // is returned holding the raw matrix value of the input. If neither of these
  205. // is possible, the untransposed matrix is returned.
  206. func untransposeExtract(a Matrix) (Matrix, bool) {
  207. ut, trans := untranspose(a)
  208. switch m := ut.(type) {
  209. case *DiagDense, *SymBandDense, *TriBandDense, *BandDense, *TriDense, *SymDense, *Dense, *VecDense:
  210. return m, trans
  211. // TODO(btracey): Add here if we ever have an equivalent of RawDiagDense.
  212. case RawSymBander:
  213. rsb := m.RawSymBand()
  214. if rsb.Uplo != blas.Upper {
  215. return ut, trans
  216. }
  217. var sb SymBandDense
  218. sb.SetRawSymBand(rsb)
  219. return &sb, trans
  220. case RawTriBander:
  221. rtb := m.RawTriBand()
  222. if rtb.Diag == blas.Unit {
  223. return ut, trans
  224. }
  225. var tb TriBandDense
  226. tb.SetRawTriBand(rtb)
  227. return &tb, trans
  228. case RawBander:
  229. var b BandDense
  230. b.SetRawBand(m.RawBand())
  231. return &b, trans
  232. case RawTriangular:
  233. rt := m.RawTriangular()
  234. if rt.Diag == blas.Unit {
  235. return ut, trans
  236. }
  237. var t TriDense
  238. t.SetRawTriangular(rt)
  239. return &t, trans
  240. case RawSymmetricer:
  241. rs := m.RawSymmetric()
  242. if rs.Uplo != blas.Upper {
  243. return ut, trans
  244. }
  245. var s SymDense
  246. s.SetRawSymmetric(rs)
  247. return &s, trans
  248. case RawMatrixer:
  249. var d Dense
  250. d.SetRawMatrix(m.RawMatrix())
  251. return &d, trans
  252. case RawVectorer:
  253. var v VecDense
  254. v.SetRawVector(m.RawVector())
  255. return &v, trans
  256. default:
  257. return ut, trans
  258. }
  259. }
  260. // TODO(btracey): Consider adding CopyCol/CopyRow if the behavior seems useful.
  261. // TODO(btracey): Add in fast paths to Row/Col for the other concrete types
  262. // (TriDense, etc.) as well as relevant interfaces (RowColer, RawRowViewer, etc.)
  263. // Col copies the elements in the jth column of the matrix into the slice dst.
  264. // The length of the provided slice must equal the number of rows, unless the
  265. // slice is nil in which case a new slice is first allocated.
  266. func Col(dst []float64, j int, a Matrix) []float64 {
  267. r, c := a.Dims()
  268. if j < 0 || j >= c {
  269. panic(ErrColAccess)
  270. }
  271. if dst == nil {
  272. dst = make([]float64, r)
  273. } else {
  274. if len(dst) != r {
  275. panic(ErrColLength)
  276. }
  277. }
  278. aU, aTrans := untranspose(a)
  279. if rm, ok := aU.(RawMatrixer); ok {
  280. m := rm.RawMatrix()
  281. if aTrans {
  282. copy(dst, m.Data[j*m.Stride:j*m.Stride+m.Cols])
  283. return dst
  284. }
  285. blas64.Copy(blas64.Vector{N: r, Inc: m.Stride, Data: m.Data[j:]},
  286. blas64.Vector{N: r, Inc: 1, Data: dst},
  287. )
  288. return dst
  289. }
  290. for i := 0; i < r; i++ {
  291. dst[i] = a.At(i, j)
  292. }
  293. return dst
  294. }
  295. // Row copies the elements in the ith row of the matrix into the slice dst.
  296. // The length of the provided slice must equal the number of columns, unless the
  297. // slice is nil in which case a new slice is first allocated.
  298. func Row(dst []float64, i int, a Matrix) []float64 {
  299. r, c := a.Dims()
  300. if i < 0 || i >= r {
  301. panic(ErrColAccess)
  302. }
  303. if dst == nil {
  304. dst = make([]float64, c)
  305. } else {
  306. if len(dst) != c {
  307. panic(ErrRowLength)
  308. }
  309. }
  310. aU, aTrans := untranspose(a)
  311. if rm, ok := aU.(RawMatrixer); ok {
  312. m := rm.RawMatrix()
  313. if aTrans {
  314. blas64.Copy(blas64.Vector{N: c, Inc: m.Stride, Data: m.Data[i:]},
  315. blas64.Vector{N: c, Inc: 1, Data: dst},
  316. )
  317. return dst
  318. }
  319. copy(dst, m.Data[i*m.Stride:i*m.Stride+m.Cols])
  320. return dst
  321. }
  322. for j := 0; j < c; j++ {
  323. dst[j] = a.At(i, j)
  324. }
  325. return dst
  326. }
  327. // Cond returns the condition number of the given matrix under the given norm.
  328. // The condition number must be based on the 1-norm, 2-norm or ∞-norm.
  329. // Cond will panic with matrix.ErrShape if the matrix has zero size.
  330. //
  331. // BUG(btracey): The computation of the 1-norm and ∞-norm for non-square matrices
  332. // is inaccurate, although is typically the right order of magnitude. See
  333. // https://github.com/xianyi/OpenBLAS/issues/636. While the value returned will
  334. // change with the resolution of this bug, the result from Cond will match the
  335. // condition number used internally.
  336. func Cond(a Matrix, norm float64) float64 {
  337. m, n := a.Dims()
  338. if m == 0 || n == 0 {
  339. panic(ErrShape)
  340. }
  341. var lnorm lapack.MatrixNorm
  342. switch norm {
  343. default:
  344. panic("mat: bad norm value")
  345. case 1:
  346. lnorm = lapack.MaxColumnSum
  347. case 2:
  348. var svd SVD
  349. ok := svd.Factorize(a, SVDNone)
  350. if !ok {
  351. return math.Inf(1)
  352. }
  353. return svd.Cond()
  354. case math.Inf(1):
  355. lnorm = lapack.MaxRowSum
  356. }
  357. if m == n {
  358. // Use the LU decomposition to compute the condition number.
  359. var lu LU
  360. lu.factorize(a, lnorm)
  361. return lu.Cond()
  362. }
  363. if m > n {
  364. // Use the QR factorization to compute the condition number.
  365. var qr QR
  366. qr.factorize(a, lnorm)
  367. return qr.Cond()
  368. }
  369. // Use the LQ factorization to compute the condition number.
  370. var lq LQ
  371. lq.factorize(a, lnorm)
  372. return lq.Cond()
  373. }
  374. // Det returns the determinant of the matrix a. In many expressions using LogDet
  375. // will be more numerically stable.
  376. func Det(a Matrix) float64 {
  377. det, sign := LogDet(a)
  378. return math.Exp(det) * sign
  379. }
  380. // Dot returns the sum of the element-wise product of a and b.
  381. // Dot panics if the matrix sizes are unequal.
  382. func Dot(a, b Vector) float64 {
  383. la := a.Len()
  384. lb := b.Len()
  385. if la != lb {
  386. panic(ErrShape)
  387. }
  388. if arv, ok := a.(RawVectorer); ok {
  389. if brv, ok := b.(RawVectorer); ok {
  390. return blas64.Dot(arv.RawVector(), brv.RawVector())
  391. }
  392. }
  393. var sum float64
  394. for i := 0; i < la; i++ {
  395. sum += a.At(i, 0) * b.At(i, 0)
  396. }
  397. return sum
  398. }
  399. // Equal returns whether the matrices a and b have the same size
  400. // and are element-wise equal.
  401. func Equal(a, b Matrix) bool {
  402. ar, ac := a.Dims()
  403. br, bc := b.Dims()
  404. if ar != br || ac != bc {
  405. return false
  406. }
  407. aU, aTrans := untranspose(a)
  408. bU, bTrans := untranspose(b)
  409. if rma, ok := aU.(RawMatrixer); ok {
  410. if rmb, ok := bU.(RawMatrixer); ok {
  411. ra := rma.RawMatrix()
  412. rb := rmb.RawMatrix()
  413. if aTrans == bTrans {
  414. for i := 0; i < ra.Rows; i++ {
  415. for j := 0; j < ra.Cols; j++ {
  416. if ra.Data[i*ra.Stride+j] != rb.Data[i*rb.Stride+j] {
  417. return false
  418. }
  419. }
  420. }
  421. return true
  422. }
  423. for i := 0; i < ra.Rows; i++ {
  424. for j := 0; j < ra.Cols; j++ {
  425. if ra.Data[i*ra.Stride+j] != rb.Data[j*rb.Stride+i] {
  426. return false
  427. }
  428. }
  429. }
  430. return true
  431. }
  432. }
  433. if rma, ok := aU.(RawSymmetricer); ok {
  434. if rmb, ok := bU.(RawSymmetricer); ok {
  435. ra := rma.RawSymmetric()
  436. rb := rmb.RawSymmetric()
  437. // Symmetric matrices are always upper and equal to their transpose.
  438. for i := 0; i < ra.N; i++ {
  439. for j := i; j < ra.N; j++ {
  440. if ra.Data[i*ra.Stride+j] != rb.Data[i*rb.Stride+j] {
  441. return false
  442. }
  443. }
  444. }
  445. return true
  446. }
  447. }
  448. if ra, ok := aU.(*VecDense); ok {
  449. if rb, ok := bU.(*VecDense); ok {
  450. // If the raw vectors are the same length they must either both be
  451. // transposed or both not transposed (or have length 1).
  452. for i := 0; i < ra.mat.N; i++ {
  453. if ra.mat.Data[i*ra.mat.Inc] != rb.mat.Data[i*rb.mat.Inc] {
  454. return false
  455. }
  456. }
  457. return true
  458. }
  459. }
  460. for i := 0; i < ar; i++ {
  461. for j := 0; j < ac; j++ {
  462. if a.At(i, j) != b.At(i, j) {
  463. return false
  464. }
  465. }
  466. }
  467. return true
  468. }
  469. // EqualApprox returns whether the matrices a and b have the same size and contain all equal
  470. // elements with tolerance for element-wise equality specified by epsilon. Matrices
  471. // with non-equal shapes are not equal.
  472. func EqualApprox(a, b Matrix, epsilon float64) bool {
  473. ar, ac := a.Dims()
  474. br, bc := b.Dims()
  475. if ar != br || ac != bc {
  476. return false
  477. }
  478. aU, aTrans := untranspose(a)
  479. bU, bTrans := untranspose(b)
  480. if rma, ok := aU.(RawMatrixer); ok {
  481. if rmb, ok := bU.(RawMatrixer); ok {
  482. ra := rma.RawMatrix()
  483. rb := rmb.RawMatrix()
  484. if aTrans == bTrans {
  485. for i := 0; i < ra.Rows; i++ {
  486. for j := 0; j < ra.Cols; j++ {
  487. if !scalar.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[i*rb.Stride+j], epsilon, epsilon) {
  488. return false
  489. }
  490. }
  491. }
  492. return true
  493. }
  494. for i := 0; i < ra.Rows; i++ {
  495. for j := 0; j < ra.Cols; j++ {
  496. if !scalar.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[j*rb.Stride+i], epsilon, epsilon) {
  497. return false
  498. }
  499. }
  500. }
  501. return true
  502. }
  503. }
  504. if rma, ok := aU.(RawSymmetricer); ok {
  505. if rmb, ok := bU.(RawSymmetricer); ok {
  506. ra := rma.RawSymmetric()
  507. rb := rmb.RawSymmetric()
  508. // Symmetric matrices are always upper and equal to their transpose.
  509. for i := 0; i < ra.N; i++ {
  510. for j := i; j < ra.N; j++ {
  511. if !scalar.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[i*rb.Stride+j], epsilon, epsilon) {
  512. return false
  513. }
  514. }
  515. }
  516. return true
  517. }
  518. }
  519. if ra, ok := aU.(*VecDense); ok {
  520. if rb, ok := bU.(*VecDense); ok {
  521. // If the raw vectors are the same length they must either both be
  522. // transposed or both not transposed (or have length 1).
  523. for i := 0; i < ra.mat.N; i++ {
  524. if !scalar.EqualWithinAbsOrRel(ra.mat.Data[i*ra.mat.Inc], rb.mat.Data[i*rb.mat.Inc], epsilon, epsilon) {
  525. return false
  526. }
  527. }
  528. return true
  529. }
  530. }
  531. for i := 0; i < ar; i++ {
  532. for j := 0; j < ac; j++ {
  533. if !scalar.EqualWithinAbsOrRel(a.At(i, j), b.At(i, j), epsilon, epsilon) {
  534. return false
  535. }
  536. }
  537. }
  538. return true
  539. }
  540. // LogDet returns the log of the determinant and the sign of the determinant
  541. // for the matrix that has been factorized. Numerical stability in product and
  542. // division expressions is generally improved by working in log space.
  543. func LogDet(a Matrix) (det float64, sign float64) {
  544. // TODO(btracey): Add specialized routines for TriDense, etc.
  545. var lu LU
  546. lu.Factorize(a)
  547. return lu.LogDet()
  548. }
  549. // Max returns the largest element value of the matrix A.
  550. // Max will panic with matrix.ErrShape if the matrix has zero size.
  551. func Max(a Matrix) float64 {
  552. r, c := a.Dims()
  553. if r == 0 || c == 0 {
  554. panic(ErrShape)
  555. }
  556. // Max(A) = Max(Aᵀ)
  557. aU, _ := untranspose(a)
  558. switch m := aU.(type) {
  559. case RawMatrixer:
  560. rm := m.RawMatrix()
  561. max := math.Inf(-1)
  562. for i := 0; i < rm.Rows; i++ {
  563. for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
  564. if v > max {
  565. max = v
  566. }
  567. }
  568. }
  569. return max
  570. case RawTriangular:
  571. rm := m.RawTriangular()
  572. // The max of a triangular is at least 0 unless the size is 1.
  573. if rm.N == 1 {
  574. return rm.Data[0]
  575. }
  576. max := 0.0
  577. if rm.Uplo == blas.Upper {
  578. for i := 0; i < rm.N; i++ {
  579. for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
  580. if v > max {
  581. max = v
  582. }
  583. }
  584. }
  585. return max
  586. }
  587. for i := 0; i < rm.N; i++ {
  588. for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+i+1] {
  589. if v > max {
  590. max = v
  591. }
  592. }
  593. }
  594. return max
  595. case RawSymmetricer:
  596. rm := m.RawSymmetric()
  597. if rm.Uplo != blas.Upper {
  598. panic(badSymTriangle)
  599. }
  600. max := math.Inf(-1)
  601. for i := 0; i < rm.N; i++ {
  602. for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
  603. if v > max {
  604. max = v
  605. }
  606. }
  607. }
  608. return max
  609. default:
  610. r, c := aU.Dims()
  611. max := math.Inf(-1)
  612. for i := 0; i < r; i++ {
  613. for j := 0; j < c; j++ {
  614. v := aU.At(i, j)
  615. if v > max {
  616. max = v
  617. }
  618. }
  619. }
  620. return max
  621. }
  622. }
  623. // Min returns the smallest element value of the matrix A.
  624. // Min will panic with matrix.ErrShape if the matrix has zero size.
  625. func Min(a Matrix) float64 {
  626. r, c := a.Dims()
  627. if r == 0 || c == 0 {
  628. panic(ErrShape)
  629. }
  630. // Min(A) = Min(Aᵀ)
  631. aU, _ := untranspose(a)
  632. switch m := aU.(type) {
  633. case RawMatrixer:
  634. rm := m.RawMatrix()
  635. min := math.Inf(1)
  636. for i := 0; i < rm.Rows; i++ {
  637. for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
  638. if v < min {
  639. min = v
  640. }
  641. }
  642. }
  643. return min
  644. case RawTriangular:
  645. rm := m.RawTriangular()
  646. // The min of a triangular is at most 0 unless the size is 1.
  647. if rm.N == 1 {
  648. return rm.Data[0]
  649. }
  650. min := 0.0
  651. if rm.Uplo == blas.Upper {
  652. for i := 0; i < rm.N; i++ {
  653. for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
  654. if v < min {
  655. min = v
  656. }
  657. }
  658. }
  659. return min
  660. }
  661. for i := 0; i < rm.N; i++ {
  662. for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+i+1] {
  663. if v < min {
  664. min = v
  665. }
  666. }
  667. }
  668. return min
  669. case RawSymmetricer:
  670. rm := m.RawSymmetric()
  671. if rm.Uplo != blas.Upper {
  672. panic(badSymTriangle)
  673. }
  674. min := math.Inf(1)
  675. for i := 0; i < rm.N; i++ {
  676. for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
  677. if v < min {
  678. min = v
  679. }
  680. }
  681. }
  682. return min
  683. default:
  684. r, c := aU.Dims()
  685. min := math.Inf(1)
  686. for i := 0; i < r; i++ {
  687. for j := 0; j < c; j++ {
  688. v := aU.At(i, j)
  689. if v < min {
  690. min = v
  691. }
  692. }
  693. }
  694. return min
  695. }
  696. }
  697. // Norm returns the specified (induced) norm of the matrix a. See
  698. // https://en.wikipedia.org/wiki/Matrix_norm for the definition of an induced norm.
  699. //
  700. // Valid norms are:
  701. // 1 - The maximum absolute column sum
  702. // 2 - Frobenius norm, the square root of the sum of the squares of the elements.
  703. // Inf - The maximum absolute row sum.
  704. // Norm will panic with ErrNormOrder if an illegal norm order is specified and
  705. // with matrix.ErrShape if the matrix has zero size.
  706. func Norm(a Matrix, norm float64) float64 {
  707. r, c := a.Dims()
  708. if r == 0 || c == 0 {
  709. panic(ErrShape)
  710. }
  711. aU, aTrans := untranspose(a)
  712. var work []float64
  713. switch rma := aU.(type) {
  714. case RawMatrixer:
  715. rm := rma.RawMatrix()
  716. n := normLapack(norm, aTrans)
  717. if n == lapack.MaxColumnSum {
  718. work = getFloats(rm.Cols, false)
  719. defer putFloats(work)
  720. }
  721. return lapack64.Lange(n, rm, work)
  722. case RawTriangular:
  723. rm := rma.RawTriangular()
  724. n := normLapack(norm, aTrans)
  725. if n == lapack.MaxRowSum || n == lapack.MaxColumnSum {
  726. work = getFloats(rm.N, false)
  727. defer putFloats(work)
  728. }
  729. return lapack64.Lantr(n, rm, work)
  730. case RawSymmetricer:
  731. rm := rma.RawSymmetric()
  732. n := normLapack(norm, aTrans)
  733. if n == lapack.MaxRowSum || n == lapack.MaxColumnSum {
  734. work = getFloats(rm.N, false)
  735. defer putFloats(work)
  736. }
  737. return lapack64.Lansy(n, rm, work)
  738. case *VecDense:
  739. rv := rma.RawVector()
  740. switch norm {
  741. default:
  742. panic(ErrNormOrder)
  743. case 1:
  744. if aTrans {
  745. imax := blas64.Iamax(rv)
  746. return math.Abs(rma.At(imax, 0))
  747. }
  748. return blas64.Asum(rv)
  749. case 2:
  750. return blas64.Nrm2(rv)
  751. case math.Inf(1):
  752. if aTrans {
  753. return blas64.Asum(rv)
  754. }
  755. imax := blas64.Iamax(rv)
  756. return math.Abs(rma.At(imax, 0))
  757. }
  758. }
  759. switch norm {
  760. default:
  761. panic(ErrNormOrder)
  762. case 1:
  763. var max float64
  764. for j := 0; j < c; j++ {
  765. var sum float64
  766. for i := 0; i < r; i++ {
  767. sum += math.Abs(a.At(i, j))
  768. }
  769. if sum > max {
  770. max = sum
  771. }
  772. }
  773. return max
  774. case 2:
  775. var sum float64
  776. for i := 0; i < r; i++ {
  777. for j := 0; j < c; j++ {
  778. v := a.At(i, j)
  779. sum += v * v
  780. }
  781. }
  782. return math.Sqrt(sum)
  783. case math.Inf(1):
  784. var max float64
  785. for i := 0; i < r; i++ {
  786. var sum float64
  787. for j := 0; j < c; j++ {
  788. sum += math.Abs(a.At(i, j))
  789. }
  790. if sum > max {
  791. max = sum
  792. }
  793. }
  794. return max
  795. }
  796. }
  797. // normLapack converts the float64 norm input in Norm to a lapack.MatrixNorm.
  798. func normLapack(norm float64, aTrans bool) lapack.MatrixNorm {
  799. switch norm {
  800. case 1:
  801. n := lapack.MaxColumnSum
  802. if aTrans {
  803. n = lapack.MaxRowSum
  804. }
  805. return n
  806. case 2:
  807. return lapack.Frobenius
  808. case math.Inf(1):
  809. n := lapack.MaxRowSum
  810. if aTrans {
  811. n = lapack.MaxColumnSum
  812. }
  813. return n
  814. default:
  815. panic(ErrNormOrder)
  816. }
  817. }
  818. // Sum returns the sum of the elements of the matrix.
  819. func Sum(a Matrix) float64 {
  820. var sum float64
  821. aU, _ := untranspose(a)
  822. switch rma := aU.(type) {
  823. case RawSymmetricer:
  824. rm := rma.RawSymmetric()
  825. for i := 0; i < rm.N; i++ {
  826. // Diagonals count once while off-diagonals count twice.
  827. sum += rm.Data[i*rm.Stride+i]
  828. var s float64
  829. for _, v := range rm.Data[i*rm.Stride+i+1 : i*rm.Stride+rm.N] {
  830. s += v
  831. }
  832. sum += 2 * s
  833. }
  834. return sum
  835. case RawTriangular:
  836. rm := rma.RawTriangular()
  837. var startIdx, endIdx int
  838. for i := 0; i < rm.N; i++ {
  839. // Start and end index for this triangle-row.
  840. switch rm.Uplo {
  841. case blas.Upper:
  842. startIdx = i
  843. endIdx = rm.N
  844. case blas.Lower:
  845. startIdx = 0
  846. endIdx = i + 1
  847. default:
  848. panic(badTriangle)
  849. }
  850. for _, v := range rm.Data[i*rm.Stride+startIdx : i*rm.Stride+endIdx] {
  851. sum += v
  852. }
  853. }
  854. return sum
  855. case RawMatrixer:
  856. rm := rma.RawMatrix()
  857. for i := 0; i < rm.Rows; i++ {
  858. for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
  859. sum += v
  860. }
  861. }
  862. return sum
  863. case *VecDense:
  864. rm := rma.RawVector()
  865. for i := 0; i < rm.N; i++ {
  866. sum += rm.Data[i*rm.Inc]
  867. }
  868. return sum
  869. default:
  870. r, c := a.Dims()
  871. for i := 0; i < r; i++ {
  872. for j := 0; j < c; j++ {
  873. sum += a.At(i, j)
  874. }
  875. }
  876. return sum
  877. }
  878. }
  879. // A Tracer can compute the trace of the matrix. Trace must panic if the
  880. // matrix is not square.
  881. type Tracer interface {
  882. Trace() float64
  883. }
  884. // Trace returns the trace of the matrix. Trace will panic if the
  885. // matrix is not square. If a is a Tracer, its Trace method will be
  886. // used to calculate the matrix trace.
  887. func Trace(a Matrix) float64 {
  888. m, _ := untransposeExtract(a)
  889. if t, ok := m.(Tracer); ok {
  890. return t.Trace()
  891. }
  892. r, c := a.Dims()
  893. if r != c {
  894. panic(ErrSquare)
  895. }
  896. var v float64
  897. for i := 0; i < r; i++ {
  898. v += a.At(i, i)
  899. }
  900. return v
  901. }
  902. func min(a, b int) int {
  903. if a < b {
  904. return a
  905. }
  906. return b
  907. }
  908. func max(a, b int) int {
  909. if a > b {
  910. return a
  911. }
  912. return b
  913. }
  914. // use returns a float64 slice with l elements, using f if it
  915. // has the necessary capacity, otherwise creating a new slice.
  916. func use(f []float64, l int) []float64 {
  917. if l <= cap(f) {
  918. return f[:l]
  919. }
  920. return make([]float64, l)
  921. }
  922. // useZeroed returns a float64 slice with l elements, using f if it
  923. // has the necessary capacity, otherwise creating a new slice. The
  924. // elements of the returned slice are guaranteed to be zero.
  925. func useZeroed(f []float64, l int) []float64 {
  926. if l <= cap(f) {
  927. f = f[:l]
  928. zero(f)
  929. return f
  930. }
  931. return make([]float64, l)
  932. }
  933. // zero zeros the given slice's elements.
  934. func zero(f []float64) {
  935. for i := range f {
  936. f[i] = 0
  937. }
  938. }
  939. // useInt returns an int slice with l elements, using i if it
  940. // has the necessary capacity, otherwise creating a new slice.
  941. func useInt(i []int, l int) []int {
  942. if l <= cap(i) {
  943. return i[:l]
  944. }
  945. return make([]int, l)
  946. }