shadow.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. // Copyright ©2015 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mat
  5. import "gonum.org/v1/gonum/blas/blas64"
  6. // checkOverlap returns false if the receiver does not overlap data elements
  7. // referenced by the parameter and panics otherwise.
  8. //
  9. // checkOverlap methods return a boolean to allow the check call to be added to a
  10. // boolean expression, making use of short-circuit operators.
  11. func checkOverlap(a, b blas64.General) bool {
  12. if cap(a.Data) == 0 || cap(b.Data) == 0 {
  13. return false
  14. }
  15. off := offset(a.Data[:1], b.Data[:1])
  16. if off == 0 {
  17. // At least one element overlaps.
  18. if a.Cols == b.Cols && a.Rows == b.Rows && a.Stride == b.Stride {
  19. panic(regionIdentity)
  20. }
  21. panic(regionOverlap)
  22. }
  23. if off > 0 && len(a.Data) <= off {
  24. // We know a is completely before b.
  25. return false
  26. }
  27. if off < 0 && len(b.Data) <= -off {
  28. // We know a is completely after b.
  29. return false
  30. }
  31. if a.Stride != b.Stride && a.Stride != 1 && b.Stride != 1 {
  32. // Too hard, so assume the worst; if either stride
  33. // is one it will be caught in rectanglesOverlap.
  34. panic(mismatchedStrides)
  35. }
  36. if off < 0 {
  37. off = -off
  38. a.Cols, b.Cols = b.Cols, a.Cols
  39. }
  40. if rectanglesOverlap(off, a.Cols, b.Cols, min(a.Stride, b.Stride)) {
  41. panic(regionOverlap)
  42. }
  43. return false
  44. }
  45. func (m *Dense) checkOverlap(a blas64.General) bool {
  46. return checkOverlap(m.RawMatrix(), a)
  47. }
  48. func (m *Dense) checkOverlapMatrix(a Matrix) bool {
  49. if m == a {
  50. return false
  51. }
  52. var amat blas64.General
  53. switch ar := a.(type) {
  54. default:
  55. return false
  56. case RawMatrixer:
  57. amat = ar.RawMatrix()
  58. case RawSymmetricer:
  59. amat = generalFromSymmetric(ar.RawSymmetric())
  60. case RawSymBander:
  61. amat = generalFromSymmetricBand(ar.RawSymBand())
  62. case RawTriangular:
  63. amat = generalFromTriangular(ar.RawTriangular())
  64. case RawVectorer:
  65. r, c := a.Dims()
  66. amat = generalFromVector(ar.RawVector(), r, c)
  67. }
  68. return m.checkOverlap(amat)
  69. }
  70. func (s *SymDense) checkOverlap(a blas64.General) bool {
  71. return checkOverlap(generalFromSymmetric(s.RawSymmetric()), a)
  72. }
  73. func (s *SymDense) checkOverlapMatrix(a Matrix) bool {
  74. if s == a {
  75. return false
  76. }
  77. var amat blas64.General
  78. switch ar := a.(type) {
  79. default:
  80. return false
  81. case RawMatrixer:
  82. amat = ar.RawMatrix()
  83. case RawSymmetricer:
  84. amat = generalFromSymmetric(ar.RawSymmetric())
  85. case RawSymBander:
  86. amat = generalFromSymmetricBand(ar.RawSymBand())
  87. case RawTriangular:
  88. amat = generalFromTriangular(ar.RawTriangular())
  89. case RawVectorer:
  90. r, c := a.Dims()
  91. amat = generalFromVector(ar.RawVector(), r, c)
  92. }
  93. return s.checkOverlap(amat)
  94. }
  95. // generalFromSymmetric returns a blas64.General with the backing
  96. // data and dimensions of a.
  97. func generalFromSymmetric(a blas64.Symmetric) blas64.General {
  98. return blas64.General{
  99. Rows: a.N,
  100. Cols: a.N,
  101. Stride: a.Stride,
  102. Data: a.Data,
  103. }
  104. }
  105. func (t *TriDense) checkOverlap(a blas64.General) bool {
  106. return checkOverlap(generalFromTriangular(t.RawTriangular()), a)
  107. }
  108. func (t *TriDense) checkOverlapMatrix(a Matrix) bool {
  109. if t == a {
  110. return false
  111. }
  112. var amat blas64.General
  113. switch ar := a.(type) {
  114. default:
  115. return false
  116. case RawMatrixer:
  117. amat = ar.RawMatrix()
  118. case RawSymmetricer:
  119. amat = generalFromSymmetric(ar.RawSymmetric())
  120. case RawSymBander:
  121. amat = generalFromSymmetricBand(ar.RawSymBand())
  122. case RawTriangular:
  123. amat = generalFromTriangular(ar.RawTriangular())
  124. case RawVectorer:
  125. r, c := a.Dims()
  126. amat = generalFromVector(ar.RawVector(), r, c)
  127. }
  128. return t.checkOverlap(amat)
  129. }
  130. // generalFromTriangular returns a blas64.General with the backing
  131. // data and dimensions of a.
  132. func generalFromTriangular(a blas64.Triangular) blas64.General {
  133. return blas64.General{
  134. Rows: a.N,
  135. Cols: a.N,
  136. Stride: a.Stride,
  137. Data: a.Data,
  138. }
  139. }
  140. func (v *VecDense) checkOverlap(a blas64.Vector) bool {
  141. mat := v.mat
  142. if cap(mat.Data) == 0 || cap(a.Data) == 0 {
  143. return false
  144. }
  145. off := offset(mat.Data[:1], a.Data[:1])
  146. if off == 0 {
  147. // At least one element overlaps.
  148. if mat.Inc == a.Inc && len(mat.Data) == len(a.Data) {
  149. panic(regionIdentity)
  150. }
  151. panic(regionOverlap)
  152. }
  153. if off > 0 && len(mat.Data) <= off {
  154. // We know v is completely before a.
  155. return false
  156. }
  157. if off < 0 && len(a.Data) <= -off {
  158. // We know v is completely after a.
  159. return false
  160. }
  161. if mat.Inc != a.Inc && mat.Inc != 1 && a.Inc != 1 {
  162. // Too hard, so assume the worst; if either
  163. // increment is one it will be caught below.
  164. panic(mismatchedStrides)
  165. }
  166. inc := min(mat.Inc, a.Inc)
  167. if inc == 1 || off&inc == 0 {
  168. panic(regionOverlap)
  169. }
  170. return false
  171. }
  172. // generalFromVector returns a blas64.General with the backing
  173. // data and dimensions of a.
  174. func generalFromVector(a blas64.Vector, r, c int) blas64.General {
  175. return blas64.General{
  176. Rows: r,
  177. Cols: c,
  178. Stride: a.Inc,
  179. Data: a.Data,
  180. }
  181. }
  182. func (s *SymBandDense) checkOverlap(a blas64.General) bool {
  183. return checkOverlap(generalFromSymmetricBand(s.RawSymBand()), a)
  184. }
  185. func (s *SymBandDense) checkOverlapMatrix(a Matrix) bool {
  186. if s == a {
  187. return false
  188. }
  189. var amat blas64.General
  190. switch ar := a.(type) {
  191. default:
  192. return false
  193. case RawMatrixer:
  194. amat = ar.RawMatrix()
  195. case RawSymmetricer:
  196. amat = generalFromSymmetric(ar.RawSymmetric())
  197. case RawSymBander:
  198. amat = generalFromSymmetricBand(ar.RawSymBand())
  199. case RawTriangular:
  200. amat = generalFromTriangular(ar.RawTriangular())
  201. case RawVectorer:
  202. r, c := a.Dims()
  203. amat = generalFromVector(ar.RawVector(), r, c)
  204. }
  205. return s.checkOverlap(amat)
  206. }
  207. // generalFromSymmetricBand returns a blas64.General with the backing
  208. // data and dimensions of a.
  209. func generalFromSymmetricBand(a blas64.SymmetricBand) blas64.General {
  210. return blas64.General{
  211. Rows: a.N,
  212. Cols: a.K + 1,
  213. Data: a.Data,
  214. Stride: a.Stride,
  215. }
  216. }