dense_arithmetic.go 19 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867
  1. // Copyright ©2013 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mat
  5. import (
  6. "math"
  7. "gonum.org/v1/gonum/blas"
  8. "gonum.org/v1/gonum/blas/blas64"
  9. "gonum.org/v1/gonum/lapack/lapack64"
  10. )
  11. // Add adds a and b element-wise, placing the result in the receiver. Add
  12. // will panic if the two matrices do not have the same shape.
  13. func (m *Dense) Add(a, b Matrix) {
  14. ar, ac := a.Dims()
  15. br, bc := b.Dims()
  16. if ar != br || ac != bc {
  17. panic(ErrShape)
  18. }
  19. aU, _ := untransposeExtract(a)
  20. bU, _ := untransposeExtract(b)
  21. m.reuseAsNonZeroed(ar, ac)
  22. if arm, ok := a.(*Dense); ok {
  23. if brm, ok := b.(*Dense); ok {
  24. amat, bmat := arm.mat, brm.mat
  25. if m != aU {
  26. m.checkOverlap(amat)
  27. }
  28. if m != bU {
  29. m.checkOverlap(bmat)
  30. }
  31. for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride {
  32. for i, v := range amat.Data[ja : ja+ac] {
  33. m.mat.Data[i+jm] = v + bmat.Data[i+jb]
  34. }
  35. }
  36. return
  37. }
  38. }
  39. m.checkOverlapMatrix(aU)
  40. m.checkOverlapMatrix(bU)
  41. var restore func()
  42. if m == aU {
  43. m, restore = m.isolatedWorkspace(aU)
  44. defer restore()
  45. } else if m == bU {
  46. m, restore = m.isolatedWorkspace(bU)
  47. defer restore()
  48. }
  49. for r := 0; r < ar; r++ {
  50. for c := 0; c < ac; c++ {
  51. m.set(r, c, a.At(r, c)+b.At(r, c))
  52. }
  53. }
  54. }
  55. // Sub subtracts the matrix b from a, placing the result in the receiver. Sub
  56. // will panic if the two matrices do not have the same shape.
  57. func (m *Dense) Sub(a, b Matrix) {
  58. ar, ac := a.Dims()
  59. br, bc := b.Dims()
  60. if ar != br || ac != bc {
  61. panic(ErrShape)
  62. }
  63. aU, _ := untransposeExtract(a)
  64. bU, _ := untransposeExtract(b)
  65. m.reuseAsNonZeroed(ar, ac)
  66. if arm, ok := a.(*Dense); ok {
  67. if brm, ok := b.(*Dense); ok {
  68. amat, bmat := arm.mat, brm.mat
  69. if m != aU {
  70. m.checkOverlap(amat)
  71. }
  72. if m != bU {
  73. m.checkOverlap(bmat)
  74. }
  75. for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride {
  76. for i, v := range amat.Data[ja : ja+ac] {
  77. m.mat.Data[i+jm] = v - bmat.Data[i+jb]
  78. }
  79. }
  80. return
  81. }
  82. }
  83. m.checkOverlapMatrix(aU)
  84. m.checkOverlapMatrix(bU)
  85. var restore func()
  86. if m == aU {
  87. m, restore = m.isolatedWorkspace(aU)
  88. defer restore()
  89. } else if m == bU {
  90. m, restore = m.isolatedWorkspace(bU)
  91. defer restore()
  92. }
  93. for r := 0; r < ar; r++ {
  94. for c := 0; c < ac; c++ {
  95. m.set(r, c, a.At(r, c)-b.At(r, c))
  96. }
  97. }
  98. }
  99. // MulElem performs element-wise multiplication of a and b, placing the result
  100. // in the receiver. MulElem will panic if the two matrices do not have the same
  101. // shape.
  102. func (m *Dense) MulElem(a, b Matrix) {
  103. ar, ac := a.Dims()
  104. br, bc := b.Dims()
  105. if ar != br || ac != bc {
  106. panic(ErrShape)
  107. }
  108. aU, _ := untransposeExtract(a)
  109. bU, _ := untransposeExtract(b)
  110. m.reuseAsNonZeroed(ar, ac)
  111. if arm, ok := a.(*Dense); ok {
  112. if brm, ok := b.(*Dense); ok {
  113. amat, bmat := arm.mat, brm.mat
  114. if m != aU {
  115. m.checkOverlap(amat)
  116. }
  117. if m != bU {
  118. m.checkOverlap(bmat)
  119. }
  120. for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride {
  121. for i, v := range amat.Data[ja : ja+ac] {
  122. m.mat.Data[i+jm] = v * bmat.Data[i+jb]
  123. }
  124. }
  125. return
  126. }
  127. }
  128. m.checkOverlapMatrix(aU)
  129. m.checkOverlapMatrix(bU)
  130. var restore func()
  131. if m == aU {
  132. m, restore = m.isolatedWorkspace(aU)
  133. defer restore()
  134. } else if m == bU {
  135. m, restore = m.isolatedWorkspace(bU)
  136. defer restore()
  137. }
  138. for r := 0; r < ar; r++ {
  139. for c := 0; c < ac; c++ {
  140. m.set(r, c, a.At(r, c)*b.At(r, c))
  141. }
  142. }
  143. }
  144. // DivElem performs element-wise division of a by b, placing the result
  145. // in the receiver. DivElem will panic if the two matrices do not have the same
  146. // shape.
  147. func (m *Dense) DivElem(a, b Matrix) {
  148. ar, ac := a.Dims()
  149. br, bc := b.Dims()
  150. if ar != br || ac != bc {
  151. panic(ErrShape)
  152. }
  153. aU, _ := untransposeExtract(a)
  154. bU, _ := untransposeExtract(b)
  155. m.reuseAsNonZeroed(ar, ac)
  156. if arm, ok := a.(*Dense); ok {
  157. if brm, ok := b.(*Dense); ok {
  158. amat, bmat := arm.mat, brm.mat
  159. if m != aU {
  160. m.checkOverlap(amat)
  161. }
  162. if m != bU {
  163. m.checkOverlap(bmat)
  164. }
  165. for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride {
  166. for i, v := range amat.Data[ja : ja+ac] {
  167. m.mat.Data[i+jm] = v / bmat.Data[i+jb]
  168. }
  169. }
  170. return
  171. }
  172. }
  173. m.checkOverlapMatrix(aU)
  174. m.checkOverlapMatrix(bU)
  175. var restore func()
  176. if m == aU {
  177. m, restore = m.isolatedWorkspace(aU)
  178. defer restore()
  179. } else if m == bU {
  180. m, restore = m.isolatedWorkspace(bU)
  181. defer restore()
  182. }
  183. for r := 0; r < ar; r++ {
  184. for c := 0; c < ac; c++ {
  185. m.set(r, c, a.At(r, c)/b.At(r, c))
  186. }
  187. }
  188. }
// Inverse computes the inverse of the matrix a, storing the result into the
// receiver. If a is ill-conditioned, a Condition error will be returned.
// Note that matrix inversion is numerically unstable, and should generally
// be avoided where possible, for example by using the Solve routines.
//
// Inverse panics with ErrSquare if a is not square. A Condition error with
// value +Inf is returned when the factorization fails or the estimated
// reciprocal condition number is zero; a finite Condition is returned when
// the estimated condition number exceeds ConditionTolerance.
func (m *Dense) Inverse(a Matrix) error {
	// TODO(btracey): Special case for RawTriangular, etc.
	r, c := a.Dims()
	if r != c {
		panic(ErrSquare)
	}

	m.reuseAsNonZeroed(a.Dims())

	// Load a into the receiver; the factorization below works in place on m.
	aU, aTrans := untransposeExtract(a)
	switch rm := aU.(type) {
	case *Dense:
		// When m is aU and a is not transposed, m already holds a and
		// no copy is needed.
		if m != aU || aTrans {
			// m is the matrix underlying a transposed a, or
			// checkOverlap reports overlapping storage: go through
			// a temporary so the copy does not read data it has
			// already overwritten.
			if m == aU || m.checkOverlap(rm.mat) {
				tmp := getWorkspace(r, c, false)
				tmp.Copy(a)
				m.Copy(tmp)
				putWorkspace(tmp)
				// Leaves the switch, skipping the direct copy below.
				break
			}
			m.Copy(a)
		}
	default:
		m.Copy(a)
	}

	// Factorize m in place; ipiv receives the pivot indices.
	ipiv := getInts(r, false)
	defer putInts(ipiv)
	ok := lapack64.Getrf(m.mat, ipiv)
	if !ok {
		return Condition(math.Inf(1))
	}

	// Workspace size query (lwork == -1): Getri reports its preferred
	// workspace length in work[0]. The buffer is only replaced when that
	// length exceeds the 4*r already needed for the condition estimate.
	work := getFloats(4*r, false) // must be at least 4*r for cond.
	lapack64.Getri(m.mat, ipiv, work, -1)
	if int(work[0]) > 4*r {
		l := int(work[0])
		putFloats(work)
		work = getFloats(l, false)
	} else {
		work = work[:4*r]
	}
	defer putFloats(work)

	// Form the inverse in place from the factorization.
	lapack64.Getri(m.mat, ipiv, work, len(work))

	// Estimate the condition number from the CondNorm norm of m and the
	// reciprocal condition number reported by Gecon.
	norm := lapack64.Lange(CondNorm, m.mat, work)
	rcond := lapack64.Gecon(CondNorm, m.mat, norm, work, ipiv) // reuse ipiv
	if rcond == 0 {
		return Condition(math.Inf(1))
	}
	cond := 1 / rcond
	if cond > ConditionTolerance {
		return Condition(cond)
	}
	return nil
}
// Mul takes the matrix product of a and b, placing the result in the receiver.
// If the number of columns in a does not equal the number of rows in b, Mul will panic.
//
// The panic value for a shape mismatch is ErrShape. Mul dispatches to BLAS
// routines when the untransposed operands are *Dense, *SymDense, *TriDense
// or *VecDense in the supported combinations, and otherwise falls back to a
// generic triple loop using At.
func (m *Dense) Mul(a, b Matrix) {
	ar, ac := a.Dims()
	br, bc := b.Dims()

	if ac != br {
		panic(ErrShape)
	}

	aU, aTrans := untransposeExtract(a)
	bU, bTrans := untransposeExtract(b)
	m.reuseAsNonZeroed(ar, bc)

	// If the receiver is the matrix underlying one of the operands, compute
	// into the workspace returned by isolatedWorkspace; restore runs on
	// return. A non-nil restore is used below as the "m was replaced" flag
	// that suppresses the Dense overlap checks.
	var restore func()
	if m == aU {
		m, restore = m.isolatedWorkspace(aU)
		defer restore()
	} else if m == bU {
		m, restore = m.isolatedWorkspace(bU)
		defer restore()
	}

	// Translate the transpose flags into BLAS transpose options.
	aT := blas.NoTrans
	if aTrans {
		aT = blas.Trans
	}
	bT := blas.NoTrans
	if bTrans {
		bT = blas.Trans
	}

	// Some of the cases do not have a transpose option, so create
	// temporary memory.
	// C = Aᵀ * B = (Bᵀ * A)ᵀ
	// Cᵀ = Bᵀ * A.

	// Cases where the left operand is a *Dense.
	if aU, ok := aU.(*Dense); ok {
		if restore == nil {
			m.checkOverlap(aU.mat)
		}
		switch bU := bU.(type) {
		case *Dense:
			// Dense * Dense: a single Gemm handles all transpose
			// combinations.
			if restore == nil {
				m.checkOverlap(bU.mat)
			}
			blas64.Gemm(aT, bT, 1, aU.mat, bU.mat, 0, m.mat)
			return

		case *SymDense:
			if aTrans {
				// Symm has no transpose option for the general
				// operand, so compute Cᵀ = S * A into an ac×ar
				// workspace and copy its transpose into m.
				c := getWorkspace(ac, ar, false)
				blas64.Symm(blas.Left, 1, bU.mat, aU.mat, 0, c.mat)
				strictCopy(m, c.T())
				putWorkspace(c)
				return
			}
			blas64.Symm(blas.Right, 1, bU.mat, aU.mat, 0, m.mat)
			return

		case *TriDense:
			// Trmm updates in place, so copy aU first.
			if aTrans {
				// Compute Cᵀ = op(B)ᵀ * A in a workspace holding
				// a copy of aU, then copy the transpose into m.
				c := getWorkspace(ac, ar, false)
				var tmp Dense
				tmp.SetRawMatrix(aU.mat)
				c.Copy(&tmp)
				// The triangular factor is transposed relative
				// to the original product, so flip its flag.
				bT := blas.Trans
				if bTrans {
					bT = blas.NoTrans
				}
				blas64.Trmm(blas.Left, bT, 1, bU.mat, c.mat)
				strictCopy(m, c.T())
				putWorkspace(c)
				return
			}
			m.Copy(a)
			blas64.Trmm(blas.Right, bT, 1, bU.mat, m.mat)
			return

		case *VecDense:
			m.checkOverlap(bU.asGeneral())
			bvec := bU.RawVector()
			if bTrans {
				// {ar,1} x {1,bc}, which is not a vector.
				// Instead, construct B as a General.
				bmat := blas64.General{
					Rows:   bc,
					Cols:   1,
					Stride: bvec.Inc,
					Data:   bvec.Data,
				}
				blas64.Gemm(aT, bT, 1, aU.mat, bmat, 0, m.mat)
				return
			}
			// Matrix-vector product: view m's data as a vector
			// with increment equal to m's stride.
			cvec := blas64.Vector{
				Inc:  m.mat.Stride,
				Data: m.mat.Data,
			}
			blas64.Gemv(aT, 1, aU.mat, bvec, 0, cvec)
			return
		}
	}

	// Cases where the right operand is a *Dense and the left operand is a
	// specialised type.
	if bU, ok := bU.(*Dense); ok {
		if restore == nil {
			m.checkOverlap(bU.mat)
		}
		switch aU := aU.(type) {
		case *SymDense:
			if bTrans {
				// Compute Cᵀ = B * S into a bc×br workspace and
				// copy its transpose into m.
				c := getWorkspace(bc, br, false)
				blas64.Symm(blas.Right, 1, aU.mat, bU.mat, 0, c.mat)
				strictCopy(m, c.T())
				putWorkspace(c)
				return
			}
			blas64.Symm(blas.Left, 1, aU.mat, bU.mat, 0, m.mat)
			return

		case *TriDense:
			// Trmm updates in place, so copy bU first.
			if bTrans {
				// Compute Cᵀ = B * op(A)ᵀ in a workspace holding
				// a copy of bU, then copy the transpose into m.
				c := getWorkspace(bc, br, false)
				var tmp Dense
				tmp.SetRawMatrix(bU.mat)
				c.Copy(&tmp)
				// The triangular factor is transposed relative
				// to the original product, so flip its flag.
				aT := blas.Trans
				if aTrans {
					aT = blas.NoTrans
				}
				blas64.Trmm(blas.Right, aT, 1, aU.mat, c.mat)
				strictCopy(m, c.T())
				putWorkspace(c)
				return
			}
			m.Copy(b)
			blas64.Trmm(blas.Left, aT, 1, aU.mat, m.mat)
			return

		case *VecDense:
			m.checkOverlap(aU.asGeneral())
			avec := aU.RawVector()
			if aTrans {
				// {1,ac} x {ac, bc}
				// Transpose B so that the vector is on the right.
				cvec := blas64.Vector{
					Inc:  1,
					Data: m.mat.Data,
				}
				bT := blas.Trans
				if bTrans {
					bT = blas.NoTrans
				}
				blas64.Gemv(bT, 1, bU.mat, avec, 0, cvec)
				return
			}
			// {ar,1} x {1,bc} which is not a vector result.
			// Instead, construct A as a General.
			amat := blas64.General{
				Rows:   ar,
				Cols:   1,
				Stride: avec.Inc,
				Data:   avec.Data,
			}
			blas64.Gemm(aT, bT, 1, amat, bU.mat, 0, m.mat)
			return
		}
	}

	// Generic fallback: no BLAS routine applies, so form each element as a
	// dot product using At. Each row of a is cached in row so that it is
	// read through the interface only once per output row.
	m.checkOverlapMatrix(aU)
	m.checkOverlapMatrix(bU)
	row := getFloats(ac, false)
	defer putFloats(row)
	for r := 0; r < ar; r++ {
		for i := range row {
			row[i] = a.At(r, i)
		}
		for c := 0; c < bc; c++ {
			var v float64
			for i, e := range row {
				v += e * b.At(i, c)
			}
			m.mat.Data[r*m.mat.Stride+c] = v
		}
	}
}
  418. // strictCopy copies a into m panicking if the shape of a and m differ.
  419. func strictCopy(m *Dense, a Matrix) {
  420. r, c := m.Copy(a)
  421. if r != m.mat.Rows || c != m.mat.Cols {
  422. // Panic with a string since this
  423. // is not a user-facing panic.
  424. panic(ErrShape.Error())
  425. }
  426. }
// Exp calculates the exponential of the matrix a, e^a, placing the result
// in the receiver. Exp will panic with ErrShape if a is not square.
//
// A Padé approximant is selected from the 1-norm of a: the first table entry
// whose theta bounds the norm is used directly; otherwise the final (14
// coefficient) entry is used together with scaling and squaring.
func (m *Dense) Exp(a Matrix) {
	// The implementation used here is from Functions of Matrices: Theory and Computation
	// Chapter 10, Algorithm 10.20. https://doi.org/10.1137/1.9780898717778.ch10

	r, c := a.Dims()
	if r != c {
		panic(ErrShape)
	}

	m.reuseAsNonZeroed(r, r)
	if r == 1 {
		// Scalar case: use the ordinary exponential.
		m.mat.Data[0] = math.Exp(a.At(0, 0))
		return
	}

	// Padé coefficient table. Each entry applies when the 1-norm of a is
	// at most theta; b holds the coefficients, with even-indexed values
	// feeding V and odd-indexed values feeding U below.
	pade := []struct {
		theta float64
		b     []float64
	}{
		{theta: 0.015, b: []float64{
			120, 60, 12, 1,
		}},
		{theta: 0.25, b: []float64{
			30240, 15120, 3360, 420, 30, 1,
		}},
		{theta: 0.95, b: []float64{
			17297280, 8648640, 1995840, 277200, 25200, 1512, 56, 1,
		}},
		{theta: 2.1, b: []float64{
			17643225600, 8821612800, 2075673600, 302702400, 30270240, 2162160, 110880, 3960, 90, 1,
		}},
	}

	// a1 aliases the receiver: m holds the working copy of a until the
	// final Solve overwrites it with the result.
	a1 := m
	a1.Copy(a)

	// v and u accumulate the two polynomials; the blas64.Vector values
	// view each r×r workspace's data as a flat vector of n elements so
	// that Axpy can be applied to whole matrices.
	v := getWorkspace(r, r, true)
	vraw := v.RawMatrix()
	n := r * r
	vvec := blas64.Vector{N: n, Inc: 1, Data: vraw.Data}
	defer putWorkspace(v)

	u := getWorkspace(r, r, true)
	uraw := u.RawMatrix()
	uvec := blas64.Vector{N: n, Inc: 1, Data: uraw.Data}
	defer putWorkspace(u)

	a2 := getWorkspace(r, r, false)
	defer putWorkspace(a2)

	n1 := Norm(a, 1)
	for i, t := range pade {
		if n1 > t.theta {
			continue
		}

		// This loop only executes once, so
		// this is not as horrible as it looks.
		p := getWorkspace(r, r, true)
		praw := p.RawMatrix()
		pvec := blas64.Vector{N: n, Inc: 1, Data: praw.Data}
		defer putWorkspace(p)

		// Start with p = I, v = b[0]*I and u = b[1]*I.
		for k := 0; k < r; k++ {
			p.set(k, k, 1)
			v.set(k, k, t.b[0])
			u.set(k, k, t.b[1])
		}

		// Accumulate successive even powers of a1 (p = a2^(j+1)) into
		// v and u with the matching coefficients.
		a2.Mul(a1, a1)
		for j := 0; j <= i; j++ {
			p.Mul(p, a2)
			blas64.Axpy(t.b[2*j+2], pvec, vvec)
			blas64.Axpy(t.b[2*j+3], pvec, uvec)
		}
		u.Mul(a1, u)

		// Use p as a workspace here and
		// rename u for the second call's
		// receiver.
		vmu, vpu := u, p
		vpu.Add(v, u)
		vmu.Sub(v, u)

		// Solve (V-U) X = (V+U) for the result. The error returned
		// by Solve is deliberately discarded here.
		_ = m.Solve(vmu, vpu)
		return
	}

	// Remaining Padé table line.
	const theta13 = 5.4
	b := [...]float64{
		64764752532480000, 32382376266240000, 7771770303897600, 1187353796428800,
		129060195264000, 10559470521600, 670442572800, 33522128640,
		1323241920, 40840800, 960960, 16380, 182, 1,
	}

	// Scaling: when the norm exceeds theta13, divide a1 by 2^s; the
	// result is squared s times at the end to undo the scaling.
	s := math.Log2(n1 / theta13)
	if s >= 0 {
		s = math.Ceil(s)
		a1.Scale(1/math.Pow(2, s), a1)
	}
	a2.Mul(a1, a1)

	// i is an identity matrix, later reused as scratch for V+U.
	i := getWorkspace(r, r, true)
	for j := 0; j < r; j++ {
		i.set(j, j, 1)
	}
	iraw := i.RawMatrix()
	ivec := blas64.Vector{N: n, Inc: 1, Data: iraw.Data}
	defer putWorkspace(i)

	a2raw := a2.RawMatrix()
	a2vec := blas64.Vector{N: n, Inc: 1, Data: a2raw.Data}

	a4 := getWorkspace(r, r, false)
	a4raw := a4.RawMatrix()
	a4vec := blas64.Vector{N: n, Inc: 1, Data: a4raw.Data}
	defer putWorkspace(a4)
	a4.Mul(a2, a2)

	a6 := getWorkspace(r, r, false)
	a6raw := a6.RawMatrix()
	a6vec := blas64.Vector{N: n, Inc: 1, Data: a6raw.Data}
	defer putWorkspace(a6)
	a6.Mul(a2, a4)

	// V = A_6(b_12*A_6 + b_10*A_4 + b_8*A_2) + b_6*A_6 + b_4*A_4 + b_2*A_2 +b_0*I
	blas64.Axpy(b[12], a6vec, vvec)
	blas64.Axpy(b[10], a4vec, vvec)
	blas64.Axpy(b[8], a2vec, vvec)
	v.Mul(v, a6)
	blas64.Axpy(b[6], a6vec, vvec)
	blas64.Axpy(b[4], a4vec, vvec)
	blas64.Axpy(b[2], a2vec, vvec)
	blas64.Axpy(b[0], ivec, vvec)

	// U = A(A_6(b_13*A_6 + b_11*A_4 + b_9*A_2) + b_7*A_6 + b_5*A_4 + b_3*A_2 +b_1*I)
	blas64.Axpy(b[13], a6vec, uvec)
	blas64.Axpy(b[11], a4vec, uvec)
	blas64.Axpy(b[9], a2vec, uvec)
	u.Mul(u, a6)
	blas64.Axpy(b[7], a6vec, uvec)
	blas64.Axpy(b[5], a4vec, uvec)
	blas64.Axpy(b[3], a2vec, uvec)
	blas64.Axpy(b[1], ivec, uvec)
	u.Mul(u, a1)

	// Use i as a workspace here and
	// rename u for the second call's
	// receiver.
	vmu, vpu := u, i
	vpu.Add(v, u)
	vmu.Sub(v, u)

	// Solve (V-U) X = (V+U) for the result. The error returned by Solve
	// is deliberately discarded here.
	_ = m.Solve(vmu, vpu)

	// Squaring: undo the earlier scaling by squaring the result s times.
	// The loop does not run when no scaling was applied (s <= 0).
	for ; s > 0; s-- {
		m.Mul(m, m)
	}
}
  565. // Pow calculates the integral power of the matrix a to n, placing the result
  566. // in the receiver. Pow will panic if n is negative or if a is not square.
  567. func (m *Dense) Pow(a Matrix, n int) {
  568. if n < 0 {
  569. panic("mat: illegal power")
  570. }
  571. r, c := a.Dims()
  572. if r != c {
  573. panic(ErrShape)
  574. }
  575. m.reuseAsNonZeroed(r, c)
  576. // Take possible fast paths.
  577. switch n {
  578. case 0:
  579. for i := 0; i < r; i++ {
  580. zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+c])
  581. m.mat.Data[i*m.mat.Stride+i] = 1
  582. }
  583. return
  584. case 1:
  585. m.Copy(a)
  586. return
  587. case 2:
  588. m.Mul(a, a)
  589. return
  590. }
  591. // Perform iterative exponentiation by squaring in work space.
  592. w := getWorkspace(r, r, false)
  593. w.Copy(a)
  594. s := getWorkspace(r, r, false)
  595. s.Copy(a)
  596. x := getWorkspace(r, r, false)
  597. for n--; n > 0; n >>= 1 {
  598. if n&1 != 0 {
  599. x.Mul(w, s)
  600. w, x = x, w
  601. }
  602. if n != 1 {
  603. x.Mul(s, s)
  604. s, x = x, s
  605. }
  606. }
  607. m.Copy(w)
  608. putWorkspace(w)
  609. putWorkspace(s)
  610. putWorkspace(x)
  611. }
  612. // Kronecker calculates the Kronecker product of a and b, placing the result in
  613. // the receiver.
  614. func (m *Dense) Kronecker(a, b Matrix) {
  615. ra, ca := a.Dims()
  616. rb, cb := b.Dims()
  617. m.reuseAsNonZeroed(ra*rb, ca*cb)
  618. for i := 0; i < ra; i++ {
  619. for j := 0; j < ca; j++ {
  620. m.slice(i*rb, (i+1)*rb, j*cb, (j+1)*cb).Scale(a.At(i, j), b)
  621. }
  622. }
  623. }
  624. // Scale multiplies the elements of a by f, placing the result in the receiver.
  625. //
  626. // See the Scaler interface for more information.
  627. func (m *Dense) Scale(f float64, a Matrix) {
  628. ar, ac := a.Dims()
  629. m.reuseAsNonZeroed(ar, ac)
  630. aU, aTrans := untransposeExtract(a)
  631. if rm, ok := aU.(*Dense); ok {
  632. amat := rm.mat
  633. if m == aU || m.checkOverlap(amat) {
  634. var restore func()
  635. m, restore = m.isolatedWorkspace(a)
  636. defer restore()
  637. }
  638. if !aTrans {
  639. for ja, jm := 0, 0; ja < ar*amat.Stride; ja, jm = ja+amat.Stride, jm+m.mat.Stride {
  640. for i, v := range amat.Data[ja : ja+ac] {
  641. m.mat.Data[i+jm] = v * f
  642. }
  643. }
  644. } else {
  645. for ja, jm := 0, 0; ja < ac*amat.Stride; ja, jm = ja+amat.Stride, jm+1 {
  646. for i, v := range amat.Data[ja : ja+ar] {
  647. m.mat.Data[i*m.mat.Stride+jm] = v * f
  648. }
  649. }
  650. }
  651. return
  652. }
  653. m.checkOverlapMatrix(a)
  654. for r := 0; r < ar; r++ {
  655. for c := 0; c < ac; c++ {
  656. m.set(r, c, f*a.At(r, c))
  657. }
  658. }
  659. }
  660. // Apply applies the function fn to each of the elements of a, placing the
  661. // resulting matrix in the receiver. The function fn takes a row/column
  662. // index and element value and returns some function of that tuple.
  663. func (m *Dense) Apply(fn func(i, j int, v float64) float64, a Matrix) {
  664. ar, ac := a.Dims()
  665. m.reuseAsNonZeroed(ar, ac)
  666. aU, aTrans := untransposeExtract(a)
  667. if rm, ok := aU.(*Dense); ok {
  668. amat := rm.mat
  669. if m == aU || m.checkOverlap(amat) {
  670. var restore func()
  671. m, restore = m.isolatedWorkspace(a)
  672. defer restore()
  673. }
  674. if !aTrans {
  675. for j, ja, jm := 0, 0, 0; ja < ar*amat.Stride; j, ja, jm = j+1, ja+amat.Stride, jm+m.mat.Stride {
  676. for i, v := range amat.Data[ja : ja+ac] {
  677. m.mat.Data[i+jm] = fn(j, i, v)
  678. }
  679. }
  680. } else {
  681. for j, ja, jm := 0, 0, 0; ja < ac*amat.Stride; j, ja, jm = j+1, ja+amat.Stride, jm+1 {
  682. for i, v := range amat.Data[ja : ja+ar] {
  683. m.mat.Data[i*m.mat.Stride+jm] = fn(i, j, v)
  684. }
  685. }
  686. }
  687. return
  688. }
  689. m.checkOverlapMatrix(a)
  690. for r := 0; r < ar; r++ {
  691. for c := 0; c < ac; c++ {
  692. m.set(r, c, fn(r, c, a.At(r, c)))
  693. }
  694. }
  695. }
  696. // RankOne performs a rank-one update to the matrix a with the vectors x and
  697. // y, where x and y are treated as column vectors. The result is stored in the
  698. // receiver. The Outer method can be used instead of RankOne if a is not needed.
  699. // m = a + alpha * x * yᵀ
  700. func (m *Dense) RankOne(a Matrix, alpha float64, x, y Vector) {
  701. ar, ac := a.Dims()
  702. if x.Len() != ar {
  703. panic(ErrShape)
  704. }
  705. if y.Len() != ac {
  706. panic(ErrShape)
  707. }
  708. if a != m {
  709. aU, _ := untransposeExtract(a)
  710. if rm, ok := aU.(*Dense); ok {
  711. m.checkOverlap(rm.RawMatrix())
  712. }
  713. }
  714. var xmat, ymat blas64.Vector
  715. fast := true
  716. xU, _ := untransposeExtract(x)
  717. if rv, ok := xU.(*VecDense); ok {
  718. r, c := xU.Dims()
  719. xmat = rv.mat
  720. m.checkOverlap(generalFromVector(xmat, r, c))
  721. } else {
  722. fast = false
  723. }
  724. yU, _ := untransposeExtract(y)
  725. if rv, ok := yU.(*VecDense); ok {
  726. r, c := yU.Dims()
  727. ymat = rv.mat
  728. m.checkOverlap(generalFromVector(ymat, r, c))
  729. } else {
  730. fast = false
  731. }
  732. if fast {
  733. if m != a {
  734. m.reuseAsNonZeroed(ar, ac)
  735. m.Copy(a)
  736. }
  737. blas64.Ger(alpha, xmat, ymat, m.mat)
  738. return
  739. }
  740. m.reuseAsNonZeroed(ar, ac)
  741. for i := 0; i < ar; i++ {
  742. for j := 0; j < ac; j++ {
  743. m.set(i, j, a.At(i, j)+alpha*x.AtVec(i)*y.AtVec(j))
  744. }
  745. }
  746. }
  747. // Outer calculates the outer product of the vectors x and y, where x and y
  748. // are treated as column vectors, and stores the result in the receiver.
  749. // m = alpha * x * yᵀ
  750. // In order to update an existing matrix, see RankOne.
  751. func (m *Dense) Outer(alpha float64, x, y Vector) {
  752. r, c := x.Len(), y.Len()
  753. m.reuseAsZeroed(r, c)
  754. var xmat, ymat blas64.Vector
  755. fast := true
  756. xU, _ := untransposeExtract(x)
  757. if rv, ok := xU.(*VecDense); ok {
  758. r, c := xU.Dims()
  759. xmat = rv.mat
  760. m.checkOverlap(generalFromVector(xmat, r, c))
  761. } else {
  762. fast = false
  763. }
  764. yU, _ := untransposeExtract(y)
  765. if rv, ok := yU.(*VecDense); ok {
  766. r, c := yU.Dims()
  767. ymat = rv.mat
  768. m.checkOverlap(generalFromVector(ymat, r, c))
  769. } else {
  770. fast = false
  771. }
  772. if fast {
  773. for i := 0; i < r; i++ {
  774. zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+c])
  775. }
  776. blas64.Ger(alpha, xmat, ymat, m.mat)
  777. return
  778. }
  779. for i := 0; i < r; i++ {
  780. for j := 0; j < c; j++ {
  781. m.set(i, j, alpha*x.AtVec(i)*y.AtVec(j))
  782. }
  783. }
  784. }