diagonal.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
// Copyright ©2018 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
  4. package mat
  5. import (
  6. "gonum.org/v1/gonum/blas"
  7. "gonum.org/v1/gonum/blas/blas64"
  8. )
  9. var (
  10. diagDense *DiagDense
  11. _ Matrix = diagDense
  12. _ allMatrix = diagDense
  13. _ denseMatrix = diagDense
  14. _ Diagonal = diagDense
  15. _ MutableDiagonal = diagDense
  16. _ Triangular = diagDense
  17. _ TriBanded = diagDense
  18. _ Symmetric = diagDense
  19. _ SymBanded = diagDense
  20. _ Banded = diagDense
  21. _ RawBander = diagDense
  22. _ RawSymBander = diagDense
  23. diag Diagonal
  24. _ Matrix = diag
  25. _ Diagonal = diag
  26. _ Triangular = diag
  27. _ TriBanded = diag
  28. _ Symmetric = diag
  29. _ SymBanded = diag
  30. _ Banded = diag
  31. )
  32. // Diagonal represents a diagonal matrix, that is a square matrix that only
  33. // has non-zero terms on the diagonal.
  34. type Diagonal interface {
  35. Matrix
  36. // Diag returns the number of rows/columns in the matrix.
  37. Diag() int
  38. // Bandwidth and TBand are included in the Diagonal interface
  39. // to allow the use of Diagonal types in banded functions.
  40. // Bandwidth will always return (0, 0).
  41. Bandwidth() (kl, ku int)
  42. TBand() Banded
  43. // Triangle and TTri are included in the Diagonal interface
  44. // to allow the use of Diagonal types in triangular functions.
  45. Triangle() (int, TriKind)
  46. TTri() Triangular
  47. // Symmetric and SymBand are included in the Diagonal interface
  48. // to allow the use of Diagonal types in symmetric and banded symmetric
  49. // functions respectively.
  50. Symmetric() int
  51. SymBand() (n, k int)
  52. // TriBand and TTriBand are included in the Diagonal interface
  53. // to allow the use of Diagonal types in triangular banded functions.
  54. TriBand() (n, k int, kind TriKind)
  55. TTriBand() TriBanded
  56. }
  57. // MutableDiagonal is a Diagonal matrix whose elements can be set.
  58. type MutableDiagonal interface {
  59. Diagonal
  60. SetDiag(i int, v float64)
  61. }
  62. // DiagDense represents a diagonal matrix in dense storage format.
  63. type DiagDense struct {
  64. mat blas64.Vector
  65. }
  66. // NewDiagDense creates a new Diagonal matrix with n rows and n columns.
  67. // The length of data must be n or data must be nil, otherwise NewDiagDense
  68. // will panic. NewDiagDense will panic if n is zero.
  69. func NewDiagDense(n int, data []float64) *DiagDense {
  70. if n <= 0 {
  71. if n == 0 {
  72. panic(ErrZeroLength)
  73. }
  74. panic("mat: negative dimension")
  75. }
  76. if data == nil {
  77. data = make([]float64, n)
  78. }
  79. if len(data) != n {
  80. panic(ErrShape)
  81. }
  82. return &DiagDense{
  83. mat: blas64.Vector{N: n, Data: data, Inc: 1},
  84. }
  85. }
  86. // Diag returns the dimension of the receiver.
  87. func (d *DiagDense) Diag() int {
  88. return d.mat.N
  89. }
  90. // Dims returns the dimensions of the matrix.
  91. func (d *DiagDense) Dims() (r, c int) {
  92. return d.mat.N, d.mat.N
  93. }
  94. // T returns the transpose of the matrix.
  95. func (d *DiagDense) T() Matrix {
  96. return d
  97. }
  98. // TTri returns the transpose of the matrix. Note that Diagonal matrices are
  99. // Upper by default.
  100. func (d *DiagDense) TTri() Triangular {
  101. return TransposeTri{d}
  102. }
  103. // TBand performs an implicit transpose by returning the receiver inside a
  104. // TransposeBand.
  105. func (d *DiagDense) TBand() Banded {
  106. return TransposeBand{d}
  107. }
  108. // TTriBand performs an implicit transpose by returning the receiver inside a
  109. // TransposeTriBand. Note that Diagonal matrices are Upper by default.
  110. func (d *DiagDense) TTriBand() TriBanded {
  111. return TransposeTriBand{d}
  112. }
  113. // Bandwidth returns the upper and lower bandwidths of the matrix.
  114. // These values are always zero for diagonal matrices.
  115. func (d *DiagDense) Bandwidth() (kl, ku int) {
  116. return 0, 0
  117. }
  118. // Symmetric implements the Symmetric interface.
  119. func (d *DiagDense) Symmetric() int {
  120. return d.mat.N
  121. }
  122. // SymBand returns the number of rows/columns in the matrix, and the size of
  123. // the bandwidth.
  124. func (d *DiagDense) SymBand() (n, k int) {
  125. return d.mat.N, 0
  126. }
  127. // Triangle implements the Triangular interface.
  128. func (d *DiagDense) Triangle() (int, TriKind) {
  129. return d.mat.N, Upper
  130. }
  131. // TriBand returns the number of rows/columns in the matrix, the
  132. // size of the bandwidth, and the orientation. Note that Diagonal matrices are
  133. // Upper by default.
  134. func (d *DiagDense) TriBand() (n, k int, kind TriKind) {
  135. return d.mat.N, 0, Upper
  136. }
  137. // Reset empties the matrix so that it can be reused as the
  138. // receiver of a dimensionally restricted operation.
  139. //
  140. // Reset should not be used when the matrix shares backing data.
  141. // See the Reseter interface for more information.
  142. func (d *DiagDense) Reset() {
  143. // No change of Inc or n to 0 may be
  144. // made unless both are set to 0.
  145. d.mat.Inc = 0
  146. d.mat.N = 0
  147. d.mat.Data = d.mat.Data[:0]
  148. }
  149. // Zero sets all of the matrix elements to zero.
  150. func (d *DiagDense) Zero() {
  151. for i := 0; i < d.mat.N; i++ {
  152. d.mat.Data[d.mat.Inc*i] = 0
  153. }
  154. }
  155. // DiagView returns the diagonal as a matrix backed by the original data.
  156. func (d *DiagDense) DiagView() Diagonal {
  157. return d
  158. }
  159. // DiagFrom copies the diagonal of m into the receiver. The receiver must
  160. // be min(r, c) long or empty, otherwise DiagFrom will panic.
  161. func (d *DiagDense) DiagFrom(m Matrix) {
  162. n := min(m.Dims())
  163. d.reuseAsNonZeroed(n)
  164. var vec blas64.Vector
  165. switch r := m.(type) {
  166. case *DiagDense:
  167. vec = r.mat
  168. case RawBander:
  169. mat := r.RawBand()
  170. vec = blas64.Vector{
  171. N: n,
  172. Inc: mat.Stride,
  173. Data: mat.Data[mat.KL : (n-1)*mat.Stride+mat.KL+1],
  174. }
  175. case RawMatrixer:
  176. mat := r.RawMatrix()
  177. vec = blas64.Vector{
  178. N: n,
  179. Inc: mat.Stride + 1,
  180. Data: mat.Data[:(n-1)*mat.Stride+n],
  181. }
  182. case RawSymBander:
  183. mat := r.RawSymBand()
  184. vec = blas64.Vector{
  185. N: n,
  186. Inc: mat.Stride,
  187. Data: mat.Data[:(n-1)*mat.Stride+1],
  188. }
  189. case RawSymmetricer:
  190. mat := r.RawSymmetric()
  191. vec = blas64.Vector{
  192. N: n,
  193. Inc: mat.Stride + 1,
  194. Data: mat.Data[:(n-1)*mat.Stride+n],
  195. }
  196. case RawTriBander:
  197. mat := r.RawTriBand()
  198. data := mat.Data
  199. if mat.Uplo == blas.Lower {
  200. data = data[mat.K:]
  201. }
  202. vec = blas64.Vector{
  203. N: n,
  204. Inc: mat.Stride,
  205. Data: data[:(n-1)*mat.Stride+1],
  206. }
  207. case RawTriangular:
  208. mat := r.RawTriangular()
  209. if mat.Diag == blas.Unit {
  210. for i := 0; i < n; i += d.mat.Inc {
  211. d.mat.Data[i] = 1
  212. }
  213. return
  214. }
  215. vec = blas64.Vector{
  216. N: n,
  217. Inc: mat.Stride + 1,
  218. Data: mat.Data[:(n-1)*mat.Stride+n],
  219. }
  220. case RawVectorer:
  221. d.mat.Data[0] = r.RawVector().Data[0]
  222. return
  223. default:
  224. for i := 0; i < n; i++ {
  225. d.setDiag(i, m.At(i, i))
  226. }
  227. return
  228. }
  229. blas64.Copy(vec, d.mat)
  230. }
  231. // RawBand returns the underlying data used by the receiver represented
  232. // as a blas64.Band.
  233. // Changes to elements in the receiver following the call will be reflected
  234. // in returned blas64.Band.
  235. func (d *DiagDense) RawBand() blas64.Band {
  236. return blas64.Band{
  237. Rows: d.mat.N,
  238. Cols: d.mat.N,
  239. KL: 0,
  240. KU: 0,
  241. Stride: d.mat.Inc,
  242. Data: d.mat.Data,
  243. }
  244. }
  245. // RawSymBand returns the underlying data used by the receiver represented
  246. // as a blas64.SymmetricBand.
  247. // Changes to elements in the receiver following the call will be reflected
  248. // in returned blas64.Band.
  249. func (d *DiagDense) RawSymBand() blas64.SymmetricBand {
  250. return blas64.SymmetricBand{
  251. N: d.mat.N,
  252. K: 0,
  253. Stride: d.mat.Inc,
  254. Uplo: blas.Upper,
  255. Data: d.mat.Data,
  256. }
  257. }
  258. // reuseAsNonZeroed resizes an empty diagonal to a r×r diagonal,
  259. // or checks that a non-empty matrix is r×r.
  260. func (d *DiagDense) reuseAsNonZeroed(r int) {
  261. if r == 0 {
  262. panic(ErrZeroLength)
  263. }
  264. if d.IsEmpty() {
  265. d.mat = blas64.Vector{
  266. Inc: 1,
  267. Data: use(d.mat.Data, r),
  268. }
  269. d.mat.N = r
  270. return
  271. }
  272. if r != d.mat.N {
  273. panic(ErrShape)
  274. }
  275. }
  276. // IsEmpty returns whether the receiver is empty. Empty matrices can be the
  277. // receiver for size-restricted operations. The receiver can be emptied using
  278. // Reset.
  279. func (d *DiagDense) IsEmpty() bool {
  280. // It must be the case that d.Dims() returns
  281. // zeros in this case. See comment in Reset().
  282. return d.mat.Inc == 0
  283. }
  284. // Trace returns the trace.
  285. func (d *DiagDense) Trace() float64 {
  286. rb := d.RawBand()
  287. var tr float64
  288. for i := 0; i < rb.Rows; i++ {
  289. tr += rb.Data[rb.KL+i*rb.Stride]
  290. }
  291. return tr
  292. }