blas64.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. // Copyright ©2015 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package blas64
  5. import (
  6. "gonum.org/v1/gonum/blas"
  7. "gonum.org/v1/gonum/blas/gonum"
  8. )
  9. var blas64 blas.Float64 = gonum.Implementation{}
  10. // Use sets the BLAS float64 implementation to be used by subsequent BLAS calls.
  11. // The default implementation is
  12. // gonum.org/v1/gonum/blas/gonum.Implementation.
  13. func Use(b blas.Float64) {
  14. blas64 = b
  15. }
  16. // Implementation returns the current BLAS float64 implementation.
  17. //
  18. // Implementation allows direct calls to the current the BLAS float64 implementation
  19. // giving finer control of parameters.
  20. func Implementation() blas.Float64 {
  21. return blas64
  22. }
  23. // Vector represents a vector with an associated element increment.
  24. type Vector struct {
  25. N int
  26. Data []float64
  27. Inc int
  28. }
  29. // General represents a matrix using the conventional storage scheme.
  30. type General struct {
  31. Rows, Cols int
  32. Data []float64
  33. Stride int
  34. }
  35. // Band represents a band matrix using the band storage scheme.
  36. type Band struct {
  37. Rows, Cols int
  38. KL, KU int
  39. Data []float64
  40. Stride int
  41. }
  42. // Triangular represents a triangular matrix using the conventional storage scheme.
  43. type Triangular struct {
  44. Uplo blas.Uplo
  45. Diag blas.Diag
  46. N int
  47. Data []float64
  48. Stride int
  49. }
  50. // TriangularBand represents a triangular matrix using the band storage scheme.
  51. type TriangularBand struct {
  52. Uplo blas.Uplo
  53. Diag blas.Diag
  54. N, K int
  55. Data []float64
  56. Stride int
  57. }
  58. // TriangularPacked represents a triangular matrix using the packed storage scheme.
  59. type TriangularPacked struct {
  60. Uplo blas.Uplo
  61. Diag blas.Diag
  62. N int
  63. Data []float64
  64. }
  65. // Symmetric represents a symmetric matrix using the conventional storage scheme.
  66. type Symmetric struct {
  67. Uplo blas.Uplo
  68. N int
  69. Data []float64
  70. Stride int
  71. }
  72. // SymmetricBand represents a symmetric matrix using the band storage scheme.
  73. type SymmetricBand struct {
  74. Uplo blas.Uplo
  75. N, K int
  76. Data []float64
  77. Stride int
  78. }
  79. // SymmetricPacked represents a symmetric matrix using the packed storage scheme.
  80. type SymmetricPacked struct {
  81. Uplo blas.Uplo
  82. N int
  83. Data []float64
  84. }
  85. // Level 1
  86. const (
  87. negInc = "blas64: negative vector increment"
  88. badLength = "blas64: vector length mismatch"
  89. )
  90. // Dot computes the dot product of the two vectors:
  91. // \sum_i x[i]*y[i].
  92. // Dot will panic if the lengths of x and y do not match.
  93. func Dot(x, y Vector) float64 {
  94. if x.N != y.N {
  95. panic(badLength)
  96. }
  97. return blas64.Ddot(x.N, x.Data, x.Inc, y.Data, y.Inc)
  98. }
  99. // Nrm2 computes the Euclidean norm of the vector x:
  100. // sqrt(\sum_i x[i]*x[i]).
  101. //
  102. // Nrm2 will panic if the vector increment is negative.
  103. func Nrm2(x Vector) float64 {
  104. if x.Inc < 0 {
  105. panic(negInc)
  106. }
  107. return blas64.Dnrm2(x.N, x.Data, x.Inc)
  108. }
  109. // Asum computes the sum of the absolute values of the elements of x:
  110. // \sum_i |x[i]|.
  111. //
  112. // Asum will panic if the vector increment is negative.
  113. func Asum(x Vector) float64 {
  114. if x.Inc < 0 {
  115. panic(negInc)
  116. }
  117. return blas64.Dasum(x.N, x.Data, x.Inc)
  118. }
  119. // Iamax returns the index of an element of x with the largest absolute value.
  120. // If there are multiple such indices the earliest is returned.
  121. // Iamax returns -1 if n == 0.
  122. //
  123. // Iamax will panic if the vector increment is negative.
  124. func Iamax(x Vector) int {
  125. if x.Inc < 0 {
  126. panic(negInc)
  127. }
  128. return blas64.Idamax(x.N, x.Data, x.Inc)
  129. }
  130. // Swap exchanges the elements of the two vectors:
  131. // x[i], y[i] = y[i], x[i] for all i.
  132. // Swap will panic if the lengths of x and y do not match.
  133. func Swap(x, y Vector) {
  134. if x.N != y.N {
  135. panic(badLength)
  136. }
  137. blas64.Dswap(x.N, x.Data, x.Inc, y.Data, y.Inc)
  138. }
  139. // Copy copies the elements of x into the elements of y:
  140. // y[i] = x[i] for all i.
  141. // Copy will panic if the lengths of x and y do not match.
  142. func Copy(x, y Vector) {
  143. if x.N != y.N {
  144. panic(badLength)
  145. }
  146. blas64.Dcopy(x.N, x.Data, x.Inc, y.Data, y.Inc)
  147. }
  148. // Axpy adds x scaled by alpha to y:
  149. // y[i] += alpha*x[i] for all i.
  150. // Axpy will panic if the lengths of x and y do not match.
  151. func Axpy(alpha float64, x, y Vector) {
  152. if x.N != y.N {
  153. panic(badLength)
  154. }
  155. blas64.Daxpy(x.N, alpha, x.Data, x.Inc, y.Data, y.Inc)
  156. }
  157. // Rotg computes the parameters of a Givens plane rotation so that
  158. // ⎡ c s⎤ ⎡a⎤ ⎡r⎤
  159. // ⎣-s c⎦ * ⎣b⎦ = ⎣0⎦
  160. // where a and b are the Cartesian coordinates of a given point.
  161. // c, s, and r are defined as
  162. // r = ±Sqrt(a^2 + b^2),
  163. // c = a/r, the cosine of the rotation angle,
  164. // s = a/r, the sine of the rotation angle,
  165. // and z is defined such that
  166. // if |a| > |b|, z = s,
  167. // otherwise if c != 0, z = 1/c,
  168. // otherwise z = 1.
  169. func Rotg(a, b float64) (c, s, r, z float64) {
  170. return blas64.Drotg(a, b)
  171. }
  172. // Rotmg computes the modified Givens rotation. See
  173. // http://www.netlib.org/lapack/explore-html/df/deb/drotmg_8f.html
  174. // for more details.
  175. func Rotmg(d1, d2, b1, b2 float64) (p blas.DrotmParams, rd1, rd2, rb1 float64) {
  176. return blas64.Drotmg(d1, d2, b1, b2)
  177. }
  178. // Rot applies a plane transformation to n points represented by the vectors x
  179. // and y:
  180. // x[i] = c*x[i] + s*y[i],
  181. // y[i] = -s*x[i] + c*y[i], for all i.
  182. func Rot(x, y Vector, c, s float64) {
  183. if x.N != y.N {
  184. panic(badLength)
  185. }
  186. blas64.Drot(x.N, x.Data, x.Inc, y.Data, y.Inc, c, s)
  187. }
  188. // Rotm applies the modified Givens rotation to n points represented by the
  189. // vectors x and y.
  190. func Rotm(x, y Vector, p blas.DrotmParams) {
  191. if x.N != y.N {
  192. panic(badLength)
  193. }
  194. blas64.Drotm(x.N, x.Data, x.Inc, y.Data, y.Inc, p)
  195. }
  196. // Scal scales the vector x by alpha:
  197. // x[i] *= alpha for all i.
  198. //
  199. // Scal will panic if the vector increment is negative.
  200. func Scal(alpha float64, x Vector) {
  201. if x.Inc < 0 {
  202. panic(negInc)
  203. }
  204. blas64.Dscal(x.N, alpha, x.Data, x.Inc)
  205. }
  206. // Level 2
  207. // Gemv computes
  208. // y = alpha * A * x + beta * y if t == blas.NoTrans,
  209. // y = alpha * Aᵀ * x + beta * y if t == blas.Trans or blas.ConjTrans,
  210. // where A is an m×n dense matrix, x and y are vectors, and alpha and beta are scalars.
  211. func Gemv(t blas.Transpose, alpha float64, a General, x Vector, beta float64, y Vector) {
  212. blas64.Dgemv(t, a.Rows, a.Cols, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
  213. }
  214. // Gbmv computes
  215. // y = alpha * A * x + beta * y if t == blas.NoTrans,
  216. // y = alpha * Aᵀ * x + beta * y if t == blas.Trans or blas.ConjTrans,
  217. // where A is an m×n band matrix, x and y are vectors, and alpha and beta are scalars.
  218. func Gbmv(t blas.Transpose, alpha float64, a Band, x Vector, beta float64, y Vector) {
  219. blas64.Dgbmv(t, a.Rows, a.Cols, a.KL, a.KU, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
  220. }
  221. // Trmv computes
  222. // x = A * x if t == blas.NoTrans,
  223. // x = Aᵀ * x if t == blas.Trans or blas.ConjTrans,
  224. // where A is an n×n triangular matrix, and x is a vector.
  225. func Trmv(t blas.Transpose, a Triangular, x Vector) {
  226. blas64.Dtrmv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc)
  227. }
  228. // Tbmv computes
  229. // x = A * x if t == blas.NoTrans,
  230. // x = Aᵀ * x if t == blas.Trans or blas.ConjTrans,
  231. // where A is an n×n triangular band matrix, and x is a vector.
  232. func Tbmv(t blas.Transpose, a TriangularBand, x Vector) {
  233. blas64.Dtbmv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc)
  234. }
  235. // Tpmv computes
  236. // x = A * x if t == blas.NoTrans,
  237. // x = Aᵀ * x if t == blas.Trans or blas.ConjTrans,
  238. // where A is an n×n triangular matrix in packed format, and x is a vector.
  239. func Tpmv(t blas.Transpose, a TriangularPacked, x Vector) {
  240. blas64.Dtpmv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc)
  241. }
  242. // Trsv solves
  243. // A * x = b if t == blas.NoTrans,
  244. // Aᵀ * x = b if t == blas.Trans or blas.ConjTrans,
  245. // where A is an n×n triangular matrix, and x and b are vectors.
  246. //
  247. // At entry to the function, x contains the values of b, and the result is
  248. // stored in-place into x.
  249. //
  250. // No test for singularity or near-singularity is included in this
  251. // routine. Such tests must be performed before calling this routine.
  252. func Trsv(t blas.Transpose, a Triangular, x Vector) {
  253. blas64.Dtrsv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc)
  254. }
  255. // Tbsv solves
  256. // A * x = b if t == blas.NoTrans,
  257. // Aᵀ * x = b if t == blas.Trans or blas.ConjTrans,
  258. // where A is an n×n triangular band matrix, and x and b are vectors.
  259. //
  260. // At entry to the function, x contains the values of b, and the result is
  261. // stored in place into x.
  262. //
  263. // No test for singularity or near-singularity is included in this
  264. // routine. Such tests must be performed before calling this routine.
  265. func Tbsv(t blas.Transpose, a TriangularBand, x Vector) {
  266. blas64.Dtbsv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc)
  267. }
  268. // Tpsv solves
  269. // A * x = b if t == blas.NoTrans,
  270. // Aᵀ * x = b if t == blas.Trans or blas.ConjTrans,
  271. // where A is an n×n triangular matrix in packed format, and x and b are
  272. // vectors.
  273. //
  274. // At entry to the function, x contains the values of b, and the result is
  275. // stored in place into x.
  276. //
  277. // No test for singularity or near-singularity is included in this
  278. // routine. Such tests must be performed before calling this routine.
  279. func Tpsv(t blas.Transpose, a TriangularPacked, x Vector) {
  280. blas64.Dtpsv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc)
  281. }
  282. // Symv computes
  283. // y = alpha * A * x + beta * y,
  284. // where A is an n×n symmetric matrix, x and y are vectors, and alpha and
  285. // beta are scalars.
  286. func Symv(alpha float64, a Symmetric, x Vector, beta float64, y Vector) {
  287. blas64.Dsymv(a.Uplo, a.N, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
  288. }
  289. // Sbmv performs
  290. // y = alpha * A * x + beta * y,
  291. // where A is an n×n symmetric band matrix, x and y are vectors, and alpha
  292. // and beta are scalars.
  293. func Sbmv(alpha float64, a SymmetricBand, x Vector, beta float64, y Vector) {
  294. blas64.Dsbmv(a.Uplo, a.N, a.K, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
  295. }
  296. // Spmv performs
  297. // y = alpha * A * x + beta * y,
  298. // where A is an n×n symmetric matrix in packed format, x and y are vectors,
  299. // and alpha and beta are scalars.
  300. func Spmv(alpha float64, a SymmetricPacked, x Vector, beta float64, y Vector) {
  301. blas64.Dspmv(a.Uplo, a.N, alpha, a.Data, x.Data, x.Inc, beta, y.Data, y.Inc)
  302. }
  303. // Ger performs a rank-1 update
  304. // A += alpha * x * yᵀ,
  305. // where A is an m×n dense matrix, x and y are vectors, and alpha is a scalar.
  306. func Ger(alpha float64, x, y Vector, a General) {
  307. blas64.Dger(a.Rows, a.Cols, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
  308. }
  309. // Syr performs a rank-1 update
  310. // A += alpha * x * xᵀ,
  311. // where A is an n×n symmetric matrix, x is a vector, and alpha is a scalar.
  312. func Syr(alpha float64, x Vector, a Symmetric) {
  313. blas64.Dsyr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data, a.Stride)
  314. }
  315. // Spr performs the rank-1 update
  316. // A += alpha * x * xᵀ,
  317. // where A is an n×n symmetric matrix in packed format, x is a vector, and
  318. // alpha is a scalar.
  319. func Spr(alpha float64, x Vector, a SymmetricPacked) {
  320. blas64.Dspr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data)
  321. }
  322. // Syr2 performs a rank-2 update
  323. // A += alpha * x * yᵀ + alpha * y * xᵀ,
  324. // where A is a symmetric n×n matrix, x and y are vectors, and alpha is a scalar.
  325. func Syr2(alpha float64, x, y Vector, a Symmetric) {
  326. blas64.Dsyr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
  327. }
  328. // Spr2 performs a rank-2 update
  329. // A += alpha * x * yᵀ + alpha * y * xᵀ,
  330. // where A is an n×n symmetric matrix in packed format, x and y are vectors,
  331. // and alpha is a scalar.
  332. func Spr2(alpha float64, x, y Vector, a SymmetricPacked) {
  333. blas64.Dspr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data)
  334. }
  335. // Level 3
  336. // Gemm computes
  337. // C = alpha * A * B + beta * C,
  338. // where A, B, and C are dense matrices, and alpha and beta are scalars.
  339. // tA and tB specify whether A or B are transposed.
  340. func Gemm(tA, tB blas.Transpose, alpha float64, a, b General, beta float64, c General) {
  341. var m, n, k int
  342. if tA == blas.NoTrans {
  343. m, k = a.Rows, a.Cols
  344. } else {
  345. m, k = a.Cols, a.Rows
  346. }
  347. if tB == blas.NoTrans {
  348. n = b.Cols
  349. } else {
  350. n = b.Rows
  351. }
  352. blas64.Dgemm(tA, tB, m, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
  353. }
  354. // Symm performs
  355. // C = alpha * A * B + beta * C if s == blas.Left,
  356. // C = alpha * B * A + beta * C if s == blas.Right,
  357. // where A is an n×n or m×m symmetric matrix, B and C are m×n matrices, and
  358. // alpha is a scalar.
  359. func Symm(s blas.Side, alpha float64, a Symmetric, b General, beta float64, c General) {
  360. var m, n int
  361. if s == blas.Left {
  362. m, n = a.N, b.Cols
  363. } else {
  364. m, n = b.Rows, a.N
  365. }
  366. blas64.Dsymm(s, a.Uplo, m, n, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
  367. }
  368. // Syrk performs a symmetric rank-k update
  369. // C = alpha * A * Aᵀ + beta * C if t == blas.NoTrans,
  370. // C = alpha * Aᵀ * A + beta * C if t == blas.Trans or blas.ConjTrans,
  371. // where C is an n×n symmetric matrix, A is an n×k matrix if t == blas.NoTrans and
  372. // a k×n matrix otherwise, and alpha and beta are scalars.
  373. func Syrk(t blas.Transpose, alpha float64, a General, beta float64, c Symmetric) {
  374. var n, k int
  375. if t == blas.NoTrans {
  376. n, k = a.Rows, a.Cols
  377. } else {
  378. n, k = a.Cols, a.Rows
  379. }
  380. blas64.Dsyrk(c.Uplo, t, n, k, alpha, a.Data, a.Stride, beta, c.Data, c.Stride)
  381. }
  382. // Syr2k performs a symmetric rank-2k update
  383. // C = alpha * A * Bᵀ + alpha * B * Aᵀ + beta * C if t == blas.NoTrans,
  384. // C = alpha * Aᵀ * B + alpha * Bᵀ * A + beta * C if t == blas.Trans or blas.ConjTrans,
  385. // where C is an n×n symmetric matrix, A and B are n×k matrices if t == NoTrans
  386. // and k×n matrices otherwise, and alpha and beta are scalars.
  387. func Syr2k(t blas.Transpose, alpha float64, a, b General, beta float64, c Symmetric) {
  388. var n, k int
  389. if t == blas.NoTrans {
  390. n, k = a.Rows, a.Cols
  391. } else {
  392. n, k = a.Cols, a.Rows
  393. }
  394. blas64.Dsyr2k(c.Uplo, t, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
  395. }
  396. // Trmm performs
  397. // B = alpha * A * B if tA == blas.NoTrans and s == blas.Left,
  398. // B = alpha * Aᵀ * B if tA == blas.Trans or blas.ConjTrans, and s == blas.Left,
  399. // B = alpha * B * A if tA == blas.NoTrans and s == blas.Right,
  400. // B = alpha * B * Aᵀ if tA == blas.Trans or blas.ConjTrans, and s == blas.Right,
  401. // where A is an n×n or m×m triangular matrix, B is an m×n matrix, and alpha is
  402. // a scalar.
  403. func Trmm(s blas.Side, tA blas.Transpose, alpha float64, a Triangular, b General) {
  404. blas64.Dtrmm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride)
  405. }
  406. // Trsm solves
  407. // A * X = alpha * B if tA == blas.NoTrans and s == blas.Left,
  408. // Aᵀ * X = alpha * B if tA == blas.Trans or blas.ConjTrans, and s == blas.Left,
  409. // X * A = alpha * B if tA == blas.NoTrans and s == blas.Right,
  410. // X * Aᵀ = alpha * B if tA == blas.Trans or blas.ConjTrans, and s == blas.Right,
  411. // where A is an n×n or m×m triangular matrix, X and B are m×n matrices, and
  412. // alpha is a scalar.
  413. //
  414. // At entry to the function, X contains the values of B, and the result is
  415. // stored in-place into X.
  416. //
  417. // No check is made that A is invertible.
  418. func Trsm(s blas.Side, tA blas.Transpose, alpha float64, a Triangular, b General) {
  419. blas64.Dtrsm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride)
  420. }