symband.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. // Copyright ©2017 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mat
  5. import (
  6. "gonum.org/v1/gonum/blas"
  7. "gonum.org/v1/gonum/blas/blas64"
  8. )
  9. var (
  10. symBandDense *SymBandDense
  11. _ Matrix = symBandDense
  12. _ allMatrix = symBandDense
  13. _ denseMatrix = symBandDense
  14. _ Symmetric = symBandDense
  15. _ Banded = symBandDense
  16. _ SymBanded = symBandDense
  17. _ RawSymBander = symBandDense
  18. _ MutableSymBanded = symBandDense
  19. _ NonZeroDoer = symBandDense
  20. _ RowNonZeroDoer = symBandDense
  21. _ ColNonZeroDoer = symBandDense
  22. )
  23. // SymBandDense represents a symmetric band matrix in dense storage format.
  24. type SymBandDense struct {
  25. mat blas64.SymmetricBand
  26. }
  27. // SymBanded is a symmetric band matrix interface type.
  28. type SymBanded interface {
  29. Banded
  30. // Symmetric returns the number of rows/columns in the matrix.
  31. Symmetric() int
  32. // SymBand returns the number of rows/columns in the matrix, and the size of
  33. // the bandwidth.
  34. SymBand() (n, k int)
  35. }
  36. // MutableSymBanded is a symmetric band matrix interface type that allows elements
  37. // to be altered.
  38. type MutableSymBanded interface {
  39. SymBanded
  40. SetSymBand(i, j int, v float64)
  41. }
  42. // A RawSymBander can return a blas64.SymmetricBand representation of the receiver.
  43. // Changes to the blas64.SymmetricBand.Data slice will be reflected in the original
  44. // matrix, changes to the N, K, Stride and Uplo fields will not.
  45. type RawSymBander interface {
  46. RawSymBand() blas64.SymmetricBand
  47. }
  48. // NewSymBandDense creates a new SymBand matrix with n rows and columns. If data == nil,
  49. // a new slice is allocated for the backing slice. If len(data) == n*(k+1),
  50. // data is used as the backing slice, and changes to the elements of the returned
  51. // SymBandDense will be reflected in data. If neither of these is true, NewSymBandDense
  52. // will panic. k must be at least zero and less than n, otherwise NewSymBandDense will panic.
  53. //
  54. // The data must be arranged in row-major order constructed by removing the zeros
  55. // from the rows outside the band and aligning the diagonals. SymBandDense matrices
  56. // are stored in the upper triangle. For example, the matrix
  57. // 1 2 3 0 0 0
  58. // 2 4 5 6 0 0
  59. // 3 5 7 8 9 0
  60. // 0 6 8 10 11 12
  61. // 0 0 9 11 13 14
  62. // 0 0 0 12 14 15
  63. // becomes (* entries are never accessed)
  64. // 1 2 3
  65. // 4 5 6
  66. // 7 8 9
  67. // 10 11 12
  68. // 13 14 *
  69. // 15 * *
  70. // which is passed to NewSymBandDense as []float64{1, 2, ..., 15, *, *, *} with k=2.
  71. // Only the values in the band portion of the matrix are used.
  72. func NewSymBandDense(n, k int, data []float64) *SymBandDense {
  73. if n <= 0 || k < 0 {
  74. if n == 0 {
  75. panic(ErrZeroLength)
  76. }
  77. panic("mat: negative dimension")
  78. }
  79. if k+1 > n {
  80. panic("mat: band out of range")
  81. }
  82. bc := k + 1
  83. if data != nil && len(data) != n*bc {
  84. panic(ErrShape)
  85. }
  86. if data == nil {
  87. data = make([]float64, n*bc)
  88. }
  89. return &SymBandDense{
  90. mat: blas64.SymmetricBand{
  91. N: n,
  92. K: k,
  93. Stride: bc,
  94. Uplo: blas.Upper,
  95. Data: data,
  96. },
  97. }
  98. }
  99. // Dims returns the number of rows and columns in the matrix.
  100. func (s *SymBandDense) Dims() (r, c int) {
  101. return s.mat.N, s.mat.N
  102. }
  103. // Symmetric returns the size of the receiver.
  104. func (s *SymBandDense) Symmetric() int {
  105. return s.mat.N
  106. }
  107. // Bandwidth returns the bandwidths of the matrix.
  108. func (s *SymBandDense) Bandwidth() (kl, ku int) {
  109. return s.mat.K, s.mat.K
  110. }
  111. // SymBand returns the number of rows/columns in the matrix, and the size of
  112. // the bandwidth.
  113. func (s *SymBandDense) SymBand() (n, k int) {
  114. return s.mat.N, s.mat.K
  115. }
  116. // T implements the Matrix interface. Symmetric matrices, by definition, are
  117. // equal to their transpose, and this is a no-op.
  118. func (s *SymBandDense) T() Matrix {
  119. return s
  120. }
  121. // TBand implements the Banded interface.
  122. func (s *SymBandDense) TBand() Banded {
  123. return s
  124. }
  125. // RawSymBand returns the underlying blas64.SymBand used by the receiver.
  126. // Changes to elements in the receiver following the call will be reflected
  127. // in returned blas64.SymBand.
  128. func (s *SymBandDense) RawSymBand() blas64.SymmetricBand {
  129. return s.mat
  130. }
  131. // SetRawSymBand sets the underlying blas64.SymmetricBand used by the receiver.
  132. // Changes to elements in the receiver following the call will be reflected
  133. // in the input.
  134. //
  135. // The supplied SymmetricBand must use blas.Upper storage format.
  136. func (s *SymBandDense) SetRawSymBand(mat blas64.SymmetricBand) {
  137. if mat.Uplo != blas.Upper {
  138. panic("mat: blas64.SymmetricBand does not have blas.Upper storage")
  139. }
  140. s.mat = mat
  141. }
  142. // IsEmpty returns whether the receiver is empty. Empty matrices can be the
  143. // receiver for size-restricted operations. The receiver can be emptied using
  144. // Reset.
  145. func (s *SymBandDense) IsEmpty() bool {
  146. return s.mat.Stride == 0
  147. }
  148. // Reset empties the matrix so that it can be reused as the
  149. // receiver of a dimensionally restricted operation.
  150. //
  151. // Reset should not be used when the matrix shares backing data.
  152. // See the Reseter interface for more information.
  153. func (s *SymBandDense) Reset() {
  154. s.mat.N = 0
  155. s.mat.K = 0
  156. s.mat.Stride = 0
  157. s.mat.Uplo = 0
  158. s.mat.Data = s.mat.Data[:0:0]
  159. }
  160. // Zero sets all of the matrix elements to zero.
  161. func (s *SymBandDense) Zero() {
  162. for i := 0; i < s.mat.N; i++ {
  163. u := min(1+s.mat.K, s.mat.N-i)
  164. zero(s.mat.Data[i*s.mat.Stride : i*s.mat.Stride+u])
  165. }
  166. }
  167. // DiagView returns the diagonal as a matrix backed by the original data.
  168. func (s *SymBandDense) DiagView() Diagonal {
  169. n := s.mat.N
  170. return &DiagDense{
  171. mat: blas64.Vector{
  172. N: n,
  173. Inc: s.mat.Stride,
  174. Data: s.mat.Data[:(n-1)*s.mat.Stride+1],
  175. },
  176. }
  177. }
  178. // DoNonZero calls the function fn for each of the non-zero elements of s. The function fn
  179. // takes a row/column index and the element value of s at (i, j).
  180. func (s *SymBandDense) DoNonZero(fn func(i, j int, v float64)) {
  181. for i := 0; i < s.mat.N; i++ {
  182. for j := max(0, i-s.mat.K); j < min(s.mat.N, i+s.mat.K+1); j++ {
  183. v := s.at(i, j)
  184. if v != 0 {
  185. fn(i, j, v)
  186. }
  187. }
  188. }
  189. }
  190. // DoRowNonZero calls the function fn for each of the non-zero elements of row i of s. The function fn
  191. // takes a row/column index and the element value of s at (i, j).
  192. func (s *SymBandDense) DoRowNonZero(i int, fn func(i, j int, v float64)) {
  193. if i < 0 || s.mat.N <= i {
  194. panic(ErrRowAccess)
  195. }
  196. for j := max(0, i-s.mat.K); j < min(s.mat.N, i+s.mat.K+1); j++ {
  197. v := s.at(i, j)
  198. if v != 0 {
  199. fn(i, j, v)
  200. }
  201. }
  202. }
  203. // DoColNonZero calls the function fn for each of the non-zero elements of column j of s. The function fn
  204. // takes a row/column index and the element value of s at (i, j).
  205. func (s *SymBandDense) DoColNonZero(j int, fn func(i, j int, v float64)) {
  206. if j < 0 || s.mat.N <= j {
  207. panic(ErrColAccess)
  208. }
  209. for i := 0; i < s.mat.N; i++ {
  210. if i-s.mat.K <= j && j < i+s.mat.K+1 {
  211. v := s.at(i, j)
  212. if v != 0 {
  213. fn(i, j, v)
  214. }
  215. }
  216. }
  217. }
  218. // Trace returns the trace.
  219. func (s *SymBandDense) Trace() float64 {
  220. rb := s.RawSymBand()
  221. var tr float64
  222. for i := 0; i < rb.N; i++ {
  223. tr += rb.Data[i*rb.Stride]
  224. }
  225. return tr
  226. }
  227. // MulVecTo computes S⋅x storing the result into dst.
  228. func (s *SymBandDense) MulVecTo(dst *VecDense, _ bool, x Vector) {
  229. n := s.mat.N
  230. if x.Len() != n {
  231. panic(ErrShape)
  232. }
  233. dst.reuseAsNonZeroed(n)
  234. xMat, _ := untransposeExtract(x)
  235. if xVec, ok := xMat.(*VecDense); ok {
  236. if dst != xVec {
  237. dst.checkOverlap(xVec.mat)
  238. blas64.Sbmv(1, s.mat, xVec.mat, 0, dst.mat)
  239. } else {
  240. xCopy := getWorkspaceVec(n, false)
  241. xCopy.CloneFromVec(xVec)
  242. blas64.Sbmv(1, s.mat, xCopy.mat, 0, dst.mat)
  243. putWorkspaceVec(xCopy)
  244. }
  245. } else {
  246. xCopy := getWorkspaceVec(n, false)
  247. xCopy.CloneFromVec(x)
  248. blas64.Sbmv(1, s.mat, xCopy.mat, 0, dst.mat)
  249. putWorkspaceVec(xCopy)
  250. }
  251. }