conv.go 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. // Copyright ©2015 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package blas64
  5. import "gonum.org/v1/gonum/blas"
  6. // GeneralCols represents a matrix using the conventional column-major storage scheme.
  7. type GeneralCols General
  8. // From fills the receiver with elements from a. The receiver
  9. // must have the same dimensions as a and have adequate backing
  10. // data storage.
  11. func (t GeneralCols) From(a General) {
  12. if t.Rows != a.Rows || t.Cols != a.Cols {
  13. panic("blas64: mismatched dimension")
  14. }
  15. if len(t.Data) < (t.Cols-1)*t.Stride+t.Rows {
  16. panic("blas64: short data slice")
  17. }
  18. for i := 0; i < a.Rows; i++ {
  19. for j, v := range a.Data[i*a.Stride : i*a.Stride+a.Cols] {
  20. t.Data[i+j*t.Stride] = v
  21. }
  22. }
  23. }
  24. // From fills the receiver with elements from a. The receiver
  25. // must have the same dimensions as a and have adequate backing
  26. // data storage.
  27. func (t General) From(a GeneralCols) {
  28. if t.Rows != a.Rows || t.Cols != a.Cols {
  29. panic("blas64: mismatched dimension")
  30. }
  31. if len(t.Data) < (t.Rows-1)*t.Stride+t.Cols {
  32. panic("blas64: short data slice")
  33. }
  34. for j := 0; j < a.Cols; j++ {
  35. for i, v := range a.Data[j*a.Stride : j*a.Stride+a.Rows] {
  36. t.Data[i*t.Stride+j] = v
  37. }
  38. }
  39. }
  40. // TriangularCols represents a matrix using the conventional column-major storage scheme.
  41. type TriangularCols Triangular
  42. // From fills the receiver with elements from a. The receiver
  43. // must have the same dimensions, uplo and diag as a and have
  44. // adequate backing data storage.
  45. func (t TriangularCols) From(a Triangular) {
  46. if t.N != a.N {
  47. panic("blas64: mismatched dimension")
  48. }
  49. if t.Uplo != a.Uplo {
  50. panic("blas64: mismatched BLAS uplo")
  51. }
  52. if t.Diag != a.Diag {
  53. panic("blas64: mismatched BLAS diag")
  54. }
  55. switch a.Uplo {
  56. default:
  57. panic("blas64: bad BLAS uplo")
  58. case blas.Upper:
  59. for i := 0; i < a.N; i++ {
  60. for j := i; j < a.N; j++ {
  61. t.Data[i+j*t.Stride] = a.Data[i*a.Stride+j]
  62. }
  63. }
  64. case blas.Lower:
  65. for i := 0; i < a.N; i++ {
  66. for j := 0; j <= i; j++ {
  67. t.Data[i+j*t.Stride] = a.Data[i*a.Stride+j]
  68. }
  69. }
  70. case blas.All:
  71. for i := 0; i < a.N; i++ {
  72. for j := 0; j < a.N; j++ {
  73. t.Data[i+j*t.Stride] = a.Data[i*a.Stride+j]
  74. }
  75. }
  76. }
  77. }
  78. // From fills the receiver with elements from a. The receiver
  79. // must have the same dimensions, uplo and diag as a and have
  80. // adequate backing data storage.
  81. func (t Triangular) From(a TriangularCols) {
  82. if t.N != a.N {
  83. panic("blas64: mismatched dimension")
  84. }
  85. if t.Uplo != a.Uplo {
  86. panic("blas64: mismatched BLAS uplo")
  87. }
  88. if t.Diag != a.Diag {
  89. panic("blas64: mismatched BLAS diag")
  90. }
  91. switch a.Uplo {
  92. default:
  93. panic("blas64: bad BLAS uplo")
  94. case blas.Upper:
  95. for i := 0; i < a.N; i++ {
  96. for j := i; j < a.N; j++ {
  97. t.Data[i*t.Stride+j] = a.Data[i+j*a.Stride]
  98. }
  99. }
  100. case blas.Lower:
  101. for i := 0; i < a.N; i++ {
  102. for j := 0; j <= i; j++ {
  103. t.Data[i*t.Stride+j] = a.Data[i+j*a.Stride]
  104. }
  105. }
  106. case blas.All:
  107. for i := 0; i < a.N; i++ {
  108. for j := 0; j < a.N; j++ {
  109. t.Data[i*t.Stride+j] = a.Data[i+j*a.Stride]
  110. }
  111. }
  112. }
  113. }
  114. // BandCols represents a matrix using the band column-major storage scheme.
  115. type BandCols Band
  116. // From fills the receiver with elements from a. The receiver
  117. // must have the same dimensions and bandwidth as a and have
  118. // adequate backing data storage.
  119. func (t BandCols) From(a Band) {
  120. if t.Rows != a.Rows || t.Cols != a.Cols {
  121. panic("blas64: mismatched dimension")
  122. }
  123. if t.KL != a.KL || t.KU != a.KU {
  124. panic("blas64: mismatched bandwidth")
  125. }
  126. if a.Stride < a.KL+a.KU+1 {
  127. panic("blas64: short stride for source")
  128. }
  129. if t.Stride < t.KL+t.KU+1 {
  130. panic("blas64: short stride for destination")
  131. }
  132. for i := 0; i < a.Rows; i++ {
  133. for j := max(0, i-a.KL); j < min(i+a.KU+1, a.Cols); j++ {
  134. t.Data[i+t.KU-j+j*t.Stride] = a.Data[j+a.KL-i+i*a.Stride]
  135. }
  136. }
  137. }
  138. // From fills the receiver with elements from a. The receiver
  139. // must have the same dimensions and bandwidth as a and have
  140. // adequate backing data storage.
  141. func (t Band) From(a BandCols) {
  142. if t.Rows != a.Rows || t.Cols != a.Cols {
  143. panic("blas64: mismatched dimension")
  144. }
  145. if t.KL != a.KL || t.KU != a.KU {
  146. panic("blas64: mismatched bandwidth")
  147. }
  148. if a.Stride < a.KL+a.KU+1 {
  149. panic("blas64: short stride for source")
  150. }
  151. if t.Stride < t.KL+t.KU+1 {
  152. panic("blas64: short stride for destination")
  153. }
  154. for j := 0; j < a.Cols; j++ {
  155. for i := max(0, j-a.KU); i < min(j+a.KL+1, a.Rows); i++ {
  156. t.Data[j+a.KL-i+i*a.Stride] = a.Data[i+t.KU-j+j*t.Stride]
  157. }
  158. }
  159. }
  160. // TriangularBandCols represents a triangular matrix using the band column-major storage scheme.
  161. type TriangularBandCols TriangularBand
  162. // From fills the receiver with elements from a. The receiver
  163. // must have the same dimensions, bandwidth and uplo as a and
  164. // have adequate backing data storage.
  165. func (t TriangularBandCols) From(a TriangularBand) {
  166. if t.N != a.N {
  167. panic("blas64: mismatched dimension")
  168. }
  169. if t.K != a.K {
  170. panic("blas64: mismatched bandwidth")
  171. }
  172. if a.Stride < a.K+1 {
  173. panic("blas64: short stride for source")
  174. }
  175. if t.Stride < t.K+1 {
  176. panic("blas64: short stride for destination")
  177. }
  178. if t.Uplo != a.Uplo {
  179. panic("blas64: mismatched BLAS uplo")
  180. }
  181. if t.Diag != a.Diag {
  182. panic("blas64: mismatched BLAS diag")
  183. }
  184. dst := BandCols{
  185. Rows: t.N, Cols: t.N,
  186. Stride: t.Stride,
  187. Data: t.Data,
  188. }
  189. src := Band{
  190. Rows: a.N, Cols: a.N,
  191. Stride: a.Stride,
  192. Data: a.Data,
  193. }
  194. switch a.Uplo {
  195. default:
  196. panic("blas64: bad BLAS uplo")
  197. case blas.Upper:
  198. dst.KU = t.K
  199. src.KU = a.K
  200. case blas.Lower:
  201. dst.KL = t.K
  202. src.KL = a.K
  203. }
  204. dst.From(src)
  205. }
  206. // From fills the receiver with elements from a. The receiver
  207. // must have the same dimensions, bandwidth and uplo as a and
  208. // have adequate backing data storage.
  209. func (t TriangularBand) From(a TriangularBandCols) {
  210. if t.N != a.N {
  211. panic("blas64: mismatched dimension")
  212. }
  213. if t.K != a.K {
  214. panic("blas64: mismatched bandwidth")
  215. }
  216. if a.Stride < a.K+1 {
  217. panic("blas64: short stride for source")
  218. }
  219. if t.Stride < t.K+1 {
  220. panic("blas64: short stride for destination")
  221. }
  222. if t.Uplo != a.Uplo {
  223. panic("blas64: mismatched BLAS uplo")
  224. }
  225. if t.Diag != a.Diag {
  226. panic("blas64: mismatched BLAS diag")
  227. }
  228. dst := Band{
  229. Rows: t.N, Cols: t.N,
  230. Stride: t.Stride,
  231. Data: t.Data,
  232. }
  233. src := BandCols{
  234. Rows: a.N, Cols: a.N,
  235. Stride: a.Stride,
  236. Data: a.Data,
  237. }
  238. switch a.Uplo {
  239. default:
  240. panic("blas64: bad BLAS uplo")
  241. case blas.Upper:
  242. dst.KU = t.K
  243. src.KU = a.K
  244. case blas.Lower:
  245. dst.KL = t.K
  246. src.KL = a.K
  247. }
  248. dst.From(src)
  249. }
  250. func min(a, b int) int {
  251. if a < b {
  252. return a
  253. }
  254. return b
  255. }
  256. func max(a, b int) int {
  257. if a > b {
  258. return a
  259. }
  260. return b
  261. }