level3cmplx128.go 40 KB


  1. // Copyright ©2019 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package gonum
  5. import (
  6. "math/cmplx"
  7. "gonum.org/v1/gonum/blas"
  8. "gonum.org/v1/gonum/internal/asm/c128"
  9. )
  10. var _ blas.Complex128Level3 = Implementation{}
  11. // Zgemm performs one of the matrix-matrix operations
  12. // C = alpha * op(A) * op(B) + beta * C
  13. // where op(X) is one of
  14. // op(X) = X or op(X) = Xᵀ or op(X) = Xᴴ,
  15. // alpha and beta are scalars, and A, B and C are matrices, with op(A) an m×k matrix,
  16. // op(B) a k×n matrix and C an m×n matrix.
  17. func (Implementation) Zgemm(tA, tB blas.Transpose, m, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
  18. switch tA {
  19. default:
  20. panic(badTranspose)
  21. case blas.NoTrans, blas.Trans, blas.ConjTrans:
  22. }
  23. switch tB {
  24. default:
  25. panic(badTranspose)
  26. case blas.NoTrans, blas.Trans, blas.ConjTrans:
  27. }
  28. switch {
  29. case m < 0:
  30. panic(mLT0)
  31. case n < 0:
  32. panic(nLT0)
  33. case k < 0:
  34. panic(kLT0)
  35. }
  36. rowA, colA := m, k
  37. if tA != blas.NoTrans {
  38. rowA, colA = k, m
  39. }
  40. if lda < max(1, colA) {
  41. panic(badLdA)
  42. }
  43. rowB, colB := k, n
  44. if tB != blas.NoTrans {
  45. rowB, colB = n, k
  46. }
  47. if ldb < max(1, colB) {
  48. panic(badLdB)
  49. }
  50. if ldc < max(1, n) {
  51. panic(badLdC)
  52. }
  53. // Quick return if possible.
  54. if m == 0 || n == 0 {
  55. return
  56. }
  57. // For zero matrix size the following slice length checks are trivially satisfied.
  58. if len(a) < (rowA-1)*lda+colA {
  59. panic(shortA)
  60. }
  61. if len(b) < (rowB-1)*ldb+colB {
  62. panic(shortB)
  63. }
  64. if len(c) < (m-1)*ldc+n {
  65. panic(shortC)
  66. }
  67. // Quick return if possible.
  68. if (alpha == 0 || k == 0) && beta == 1 {
  69. return
  70. }
  71. if alpha == 0 {
  72. if beta == 0 {
  73. for i := 0; i < m; i++ {
  74. for j := 0; j < n; j++ {
  75. c[i*ldc+j] = 0
  76. }
  77. }
  78. } else {
  79. for i := 0; i < m; i++ {
  80. for j := 0; j < n; j++ {
  81. c[i*ldc+j] *= beta
  82. }
  83. }
  84. }
  85. return
  86. }
  87. switch tA {
  88. case blas.NoTrans:
  89. switch tB {
  90. case blas.NoTrans:
  91. // Form C = alpha * A * B + beta * C.
  92. for i := 0; i < m; i++ {
  93. switch {
  94. case beta == 0:
  95. for j := 0; j < n; j++ {
  96. c[i*ldc+j] = 0
  97. }
  98. case beta != 1:
  99. for j := 0; j < n; j++ {
  100. c[i*ldc+j] *= beta
  101. }
  102. }
  103. for l := 0; l < k; l++ {
  104. tmp := alpha * a[i*lda+l]
  105. for j := 0; j < n; j++ {
  106. c[i*ldc+j] += tmp * b[l*ldb+j]
  107. }
  108. }
  109. }
  110. case blas.Trans:
  111. // Form C = alpha * A * Bᵀ + beta * C.
  112. for i := 0; i < m; i++ {
  113. switch {
  114. case beta == 0:
  115. for j := 0; j < n; j++ {
  116. c[i*ldc+j] = 0
  117. }
  118. case beta != 1:
  119. for j := 0; j < n; j++ {
  120. c[i*ldc+j] *= beta
  121. }
  122. }
  123. for l := 0; l < k; l++ {
  124. tmp := alpha * a[i*lda+l]
  125. for j := 0; j < n; j++ {
  126. c[i*ldc+j] += tmp * b[j*ldb+l]
  127. }
  128. }
  129. }
  130. case blas.ConjTrans:
  131. // Form C = alpha * A * Bᴴ + beta * C.
  132. for i := 0; i < m; i++ {
  133. switch {
  134. case beta == 0:
  135. for j := 0; j < n; j++ {
  136. c[i*ldc+j] = 0
  137. }
  138. case beta != 1:
  139. for j := 0; j < n; j++ {
  140. c[i*ldc+j] *= beta
  141. }
  142. }
  143. for l := 0; l < k; l++ {
  144. tmp := alpha * a[i*lda+l]
  145. for j := 0; j < n; j++ {
  146. c[i*ldc+j] += tmp * cmplx.Conj(b[j*ldb+l])
  147. }
  148. }
  149. }
  150. }
  151. case blas.Trans:
  152. switch tB {
  153. case blas.NoTrans:
  154. // Form C = alpha * Aᵀ * B + beta * C.
  155. for i := 0; i < m; i++ {
  156. for j := 0; j < n; j++ {
  157. var tmp complex128
  158. for l := 0; l < k; l++ {
  159. tmp += a[l*lda+i] * b[l*ldb+j]
  160. }
  161. if beta == 0 {
  162. c[i*ldc+j] = alpha * tmp
  163. } else {
  164. c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
  165. }
  166. }
  167. }
  168. case blas.Trans:
  169. // Form C = alpha * Aᵀ * Bᵀ + beta * C.
  170. for i := 0; i < m; i++ {
  171. for j := 0; j < n; j++ {
  172. var tmp complex128
  173. for l := 0; l < k; l++ {
  174. tmp += a[l*lda+i] * b[j*ldb+l]
  175. }
  176. if beta == 0 {
  177. c[i*ldc+j] = alpha * tmp
  178. } else {
  179. c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
  180. }
  181. }
  182. }
  183. case blas.ConjTrans:
  184. // Form C = alpha * Aᵀ * Bᴴ + beta * C.
  185. for i := 0; i < m; i++ {
  186. for j := 0; j < n; j++ {
  187. var tmp complex128
  188. for l := 0; l < k; l++ {
  189. tmp += a[l*lda+i] * cmplx.Conj(b[j*ldb+l])
  190. }
  191. if beta == 0 {
  192. c[i*ldc+j] = alpha * tmp
  193. } else {
  194. c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
  195. }
  196. }
  197. }
  198. }
  199. case blas.ConjTrans:
  200. switch tB {
  201. case blas.NoTrans:
  202. // Form C = alpha * Aᴴ * B + beta * C.
  203. for i := 0; i < m; i++ {
  204. for j := 0; j < n; j++ {
  205. var tmp complex128
  206. for l := 0; l < k; l++ {
  207. tmp += cmplx.Conj(a[l*lda+i]) * b[l*ldb+j]
  208. }
  209. if beta == 0 {
  210. c[i*ldc+j] = alpha * tmp
  211. } else {
  212. c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
  213. }
  214. }
  215. }
  216. case blas.Trans:
  217. // Form C = alpha * Aᴴ * Bᵀ + beta * C.
  218. for i := 0; i < m; i++ {
  219. for j := 0; j < n; j++ {
  220. var tmp complex128
  221. for l := 0; l < k; l++ {
  222. tmp += cmplx.Conj(a[l*lda+i]) * b[j*ldb+l]
  223. }
  224. if beta == 0 {
  225. c[i*ldc+j] = alpha * tmp
  226. } else {
  227. c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
  228. }
  229. }
  230. }
  231. case blas.ConjTrans:
  232. // Form C = alpha * Aᴴ * Bᴴ + beta * C.
  233. for i := 0; i < m; i++ {
  234. for j := 0; j < n; j++ {
  235. var tmp complex128
  236. for l := 0; l < k; l++ {
  237. tmp += cmplx.Conj(a[l*lda+i]) * cmplx.Conj(b[j*ldb+l])
  238. }
  239. if beta == 0 {
  240. c[i*ldc+j] = alpha * tmp
  241. } else {
  242. c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
  243. }
  244. }
  245. }
  246. }
  247. }
  248. }
  249. // Zhemm performs one of the matrix-matrix operations
  250. // C = alpha*A*B + beta*C if side == blas.Left
  251. // C = alpha*B*A + beta*C if side == blas.Right
  252. // where alpha and beta are scalars, A is an m×m or n×n hermitian matrix and B
  253. // and C are m×n matrices. The imaginary parts of the diagonal elements of A are
  254. // assumed to be zero.
  255. func (Implementation) Zhemm(side blas.Side, uplo blas.Uplo, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
  256. na := m
  257. if side == blas.Right {
  258. na = n
  259. }
  260. switch {
  261. case side != blas.Left && side != blas.Right:
  262. panic(badSide)
  263. case uplo != blas.Lower && uplo != blas.Upper:
  264. panic(badUplo)
  265. case m < 0:
  266. panic(mLT0)
  267. case n < 0:
  268. panic(nLT0)
  269. case lda < max(1, na):
  270. panic(badLdA)
  271. case ldb < max(1, n):
  272. panic(badLdB)
  273. case ldc < max(1, n):
  274. panic(badLdC)
  275. }
  276. // Quick return if possible.
  277. if m == 0 || n == 0 {
  278. return
  279. }
  280. // For zero matrix size the following slice length checks are trivially satisfied.
  281. if len(a) < lda*(na-1)+na {
  282. panic(shortA)
  283. }
  284. if len(b) < ldb*(m-1)+n {
  285. panic(shortB)
  286. }
  287. if len(c) < ldc*(m-1)+n {
  288. panic(shortC)
  289. }
  290. // Quick return if possible.
  291. if alpha == 0 && beta == 1 {
  292. return
  293. }
  294. if alpha == 0 {
  295. if beta == 0 {
  296. for i := 0; i < m; i++ {
  297. ci := c[i*ldc : i*ldc+n]
  298. for j := range ci {
  299. ci[j] = 0
  300. }
  301. }
  302. } else {
  303. for i := 0; i < m; i++ {
  304. ci := c[i*ldc : i*ldc+n]
  305. c128.ScalUnitary(beta, ci)
  306. }
  307. }
  308. return
  309. }
  310. if side == blas.Left {
  311. // Form C = alpha*A*B + beta*C.
  312. for i := 0; i < m; i++ {
  313. atmp := alpha * complex(real(a[i*lda+i]), 0)
  314. bi := b[i*ldb : i*ldb+n]
  315. ci := c[i*ldc : i*ldc+n]
  316. if beta == 0 {
  317. for j, bij := range bi {
  318. ci[j] = atmp * bij
  319. }
  320. } else {
  321. for j, bij := range bi {
  322. ci[j] = atmp*bij + beta*ci[j]
  323. }
  324. }
  325. if uplo == blas.Upper {
  326. for k := 0; k < i; k++ {
  327. atmp = alpha * cmplx.Conj(a[k*lda+i])
  328. c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
  329. }
  330. for k := i + 1; k < m; k++ {
  331. atmp = alpha * a[i*lda+k]
  332. c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
  333. }
  334. } else {
  335. for k := 0; k < i; k++ {
  336. atmp = alpha * a[i*lda+k]
  337. c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
  338. }
  339. for k := i + 1; k < m; k++ {
  340. atmp = alpha * cmplx.Conj(a[k*lda+i])
  341. c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
  342. }
  343. }
  344. }
  345. } else {
  346. // Form C = alpha*B*A + beta*C.
  347. if uplo == blas.Upper {
  348. for i := 0; i < m; i++ {
  349. for j := n - 1; j >= 0; j-- {
  350. abij := alpha * b[i*ldb+j]
  351. aj := a[j*lda+j+1 : j*lda+n]
  352. bi := b[i*ldb+j+1 : i*ldb+n]
  353. ci := c[i*ldc+j+1 : i*ldc+n]
  354. var tmp complex128
  355. for k, ajk := range aj {
  356. ci[k] += abij * ajk
  357. tmp += bi[k] * cmplx.Conj(ajk)
  358. }
  359. ajj := complex(real(a[j*lda+j]), 0)
  360. if beta == 0 {
  361. c[i*ldc+j] = abij*ajj + alpha*tmp
  362. } else {
  363. c[i*ldc+j] = abij*ajj + alpha*tmp + beta*c[i*ldc+j]
  364. }
  365. }
  366. }
  367. } else {
  368. for i := 0; i < m; i++ {
  369. for j := 0; j < n; j++ {
  370. abij := alpha * b[i*ldb+j]
  371. aj := a[j*lda : j*lda+j]
  372. bi := b[i*ldb : i*ldb+j]
  373. ci := c[i*ldc : i*ldc+j]
  374. var tmp complex128
  375. for k, ajk := range aj {
  376. ci[k] += abij * ajk
  377. tmp += bi[k] * cmplx.Conj(ajk)
  378. }
  379. ajj := complex(real(a[j*lda+j]), 0)
  380. if beta == 0 {
  381. c[i*ldc+j] = abij*ajj + alpha*tmp
  382. } else {
  383. c[i*ldc+j] = abij*ajj + alpha*tmp + beta*c[i*ldc+j]
  384. }
  385. }
  386. }
  387. }
  388. }
  389. }
  390. // Zherk performs one of the hermitian rank-k operations
  391. // C = alpha*A*Aᴴ + beta*C if trans == blas.NoTrans
  392. // C = alpha*Aᴴ*A + beta*C if trans == blas.ConjTrans
  393. // where alpha and beta are real scalars, C is an n×n hermitian matrix and A is
  394. // an n×k matrix in the first case and a k×n matrix in the second case.
  395. //
  396. // The imaginary parts of the diagonal elements of C are assumed to be zero, and
  397. // on return they will be set to zero.
  398. func (Implementation) Zherk(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha float64, a []complex128, lda int, beta float64, c []complex128, ldc int) {
  399. var rowA, colA int
  400. switch trans {
  401. default:
  402. panic(badTranspose)
  403. case blas.NoTrans:
  404. rowA, colA = n, k
  405. case blas.ConjTrans:
  406. rowA, colA = k, n
  407. }
  408. switch {
  409. case uplo != blas.Lower && uplo != blas.Upper:
  410. panic(badUplo)
  411. case n < 0:
  412. panic(nLT0)
  413. case k < 0:
  414. panic(kLT0)
  415. case lda < max(1, colA):
  416. panic(badLdA)
  417. case ldc < max(1, n):
  418. panic(badLdC)
  419. }
  420. // Quick return if possible.
  421. if n == 0 {
  422. return
  423. }
  424. // For zero matrix size the following slice length checks are trivially satisfied.
  425. if len(a) < (rowA-1)*lda+colA {
  426. panic(shortA)
  427. }
  428. if len(c) < (n-1)*ldc+n {
  429. panic(shortC)
  430. }
  431. // Quick return if possible.
  432. if (alpha == 0 || k == 0) && beta == 1 {
  433. return
  434. }
  435. if alpha == 0 {
  436. if uplo == blas.Upper {
  437. if beta == 0 {
  438. for i := 0; i < n; i++ {
  439. ci := c[i*ldc+i : i*ldc+n]
  440. for j := range ci {
  441. ci[j] = 0
  442. }
  443. }
  444. } else {
  445. for i := 0; i < n; i++ {
  446. ci := c[i*ldc+i : i*ldc+n]
  447. ci[0] = complex(beta*real(ci[0]), 0)
  448. if i != n-1 {
  449. c128.DscalUnitary(beta, ci[1:])
  450. }
  451. }
  452. }
  453. } else {
  454. if beta == 0 {
  455. for i := 0; i < n; i++ {
  456. ci := c[i*ldc : i*ldc+i+1]
  457. for j := range ci {
  458. ci[j] = 0
  459. }
  460. }
  461. } else {
  462. for i := 0; i < n; i++ {
  463. ci := c[i*ldc : i*ldc+i+1]
  464. if i != 0 {
  465. c128.DscalUnitary(beta, ci[:i])
  466. }
  467. ci[i] = complex(beta*real(ci[i]), 0)
  468. }
  469. }
  470. }
  471. return
  472. }
  473. calpha := complex(alpha, 0)
  474. if trans == blas.NoTrans {
  475. // Form C = alpha*A*Aᴴ + beta*C.
  476. cbeta := complex(beta, 0)
  477. if uplo == blas.Upper {
  478. for i := 0; i < n; i++ {
  479. ci := c[i*ldc+i : i*ldc+n]
  480. ai := a[i*lda : i*lda+k]
  481. switch {
  482. case beta == 0:
  483. // Handle the i-th diagonal element of C.
  484. ci[0] = complex(alpha*real(c128.DotcUnitary(ai, ai)), 0)
  485. // Handle the remaining elements on the i-th row of C.
  486. for jc := range ci[1:] {
  487. j := i + 1 + jc
  488. ci[jc+1] = calpha * c128.DotcUnitary(a[j*lda:j*lda+k], ai)
  489. }
  490. case beta != 1:
  491. cii := calpha*c128.DotcUnitary(ai, ai) + cbeta*ci[0]
  492. ci[0] = complex(real(cii), 0)
  493. for jc, cij := range ci[1:] {
  494. j := i + 1 + jc
  495. ci[jc+1] = calpha*c128.DotcUnitary(a[j*lda:j*lda+k], ai) + cbeta*cij
  496. }
  497. default:
  498. cii := calpha*c128.DotcUnitary(ai, ai) + ci[0]
  499. ci[0] = complex(real(cii), 0)
  500. for jc, cij := range ci[1:] {
  501. j := i + 1 + jc
  502. ci[jc+1] = calpha*c128.DotcUnitary(a[j*lda:j*lda+k], ai) + cij
  503. }
  504. }
  505. }
  506. } else {
  507. for i := 0; i < n; i++ {
  508. ci := c[i*ldc : i*ldc+i+1]
  509. ai := a[i*lda : i*lda+k]
  510. switch {
  511. case beta == 0:
  512. // Handle the first i-1 elements on the i-th row of C.
  513. for j := range ci[:i] {
  514. ci[j] = calpha * c128.DotcUnitary(a[j*lda:j*lda+k], ai)
  515. }
  516. // Handle the i-th diagonal element of C.
  517. ci[i] = complex(alpha*real(c128.DotcUnitary(ai, ai)), 0)
  518. case beta != 1:
  519. for j, cij := range ci[:i] {
  520. ci[j] = calpha*c128.DotcUnitary(a[j*lda:j*lda+k], ai) + cbeta*cij
  521. }
  522. cii := calpha*c128.DotcUnitary(ai, ai) + cbeta*ci[i]
  523. ci[i] = complex(real(cii), 0)
  524. default:
  525. for j, cij := range ci[:i] {
  526. ci[j] = calpha*c128.DotcUnitary(a[j*lda:j*lda+k], ai) + cij
  527. }
  528. cii := calpha*c128.DotcUnitary(ai, ai) + ci[i]
  529. ci[i] = complex(real(cii), 0)
  530. }
  531. }
  532. }
  533. } else {
  534. // Form C = alpha*Aᴴ*A + beta*C.
  535. if uplo == blas.Upper {
  536. for i := 0; i < n; i++ {
  537. ci := c[i*ldc+i : i*ldc+n]
  538. switch {
  539. case beta == 0:
  540. for jc := range ci {
  541. ci[jc] = 0
  542. }
  543. case beta != 1:
  544. c128.DscalUnitary(beta, ci)
  545. ci[0] = complex(real(ci[0]), 0)
  546. default:
  547. ci[0] = complex(real(ci[0]), 0)
  548. }
  549. for j := 0; j < k; j++ {
  550. aji := cmplx.Conj(a[j*lda+i])
  551. if aji != 0 {
  552. c128.AxpyUnitary(calpha*aji, a[j*lda+i:j*lda+n], ci)
  553. }
  554. }
  555. c[i*ldc+i] = complex(real(c[i*ldc+i]), 0)
  556. }
  557. } else {
  558. for i := 0; i < n; i++ {
  559. ci := c[i*ldc : i*ldc+i+1]
  560. switch {
  561. case beta == 0:
  562. for j := range ci {
  563. ci[j] = 0
  564. }
  565. case beta != 1:
  566. c128.DscalUnitary(beta, ci)
  567. ci[i] = complex(real(ci[i]), 0)
  568. default:
  569. ci[i] = complex(real(ci[i]), 0)
  570. }
  571. for j := 0; j < k; j++ {
  572. aji := cmplx.Conj(a[j*lda+i])
  573. if aji != 0 {
  574. c128.AxpyUnitary(calpha*aji, a[j*lda:j*lda+i+1], ci)
  575. }
  576. }
  577. c[i*ldc+i] = complex(real(c[i*ldc+i]), 0)
  578. }
  579. }
  580. }
  581. }
  582. // Zher2k performs one of the hermitian rank-2k operations
  583. // C = alpha*A*Bᴴ + conj(alpha)*B*Aᴴ + beta*C if trans == blas.NoTrans
  584. // C = alpha*Aᴴ*B + conj(alpha)*Bᴴ*A + beta*C if trans == blas.ConjTrans
  585. // where alpha and beta are scalars with beta real, C is an n×n hermitian matrix
  586. // and A and B are n×k matrices in the first case and k×n matrices in the second case.
  587. //
  588. // The imaginary parts of the diagonal elements of C are assumed to be zero, and
  589. // on return they will be set to zero.
  590. func (Implementation) Zher2k(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta float64, c []complex128, ldc int) {
  591. var row, col int
  592. switch trans {
  593. default:
  594. panic(badTranspose)
  595. case blas.NoTrans:
  596. row, col = n, k
  597. case blas.ConjTrans:
  598. row, col = k, n
  599. }
  600. switch {
  601. case uplo != blas.Lower && uplo != blas.Upper:
  602. panic(badUplo)
  603. case n < 0:
  604. panic(nLT0)
  605. case k < 0:
  606. panic(kLT0)
  607. case lda < max(1, col):
  608. panic(badLdA)
  609. case ldb < max(1, col):
  610. panic(badLdB)
  611. case ldc < max(1, n):
  612. panic(badLdC)
  613. }
  614. // Quick return if possible.
  615. if n == 0 {
  616. return
  617. }
  618. // For zero matrix size the following slice length checks are trivially satisfied.
  619. if len(a) < (row-1)*lda+col {
  620. panic(shortA)
  621. }
  622. if len(b) < (row-1)*ldb+col {
  623. panic(shortB)
  624. }
  625. if len(c) < (n-1)*ldc+n {
  626. panic(shortC)
  627. }
  628. // Quick return if possible.
  629. if (alpha == 0 || k == 0) && beta == 1 {
  630. return
  631. }
  632. if alpha == 0 {
  633. if uplo == blas.Upper {
  634. if beta == 0 {
  635. for i := 0; i < n; i++ {
  636. ci := c[i*ldc+i : i*ldc+n]
  637. for j := range ci {
  638. ci[j] = 0
  639. }
  640. }
  641. } else {
  642. for i := 0; i < n; i++ {
  643. ci := c[i*ldc+i : i*ldc+n]
  644. ci[0] = complex(beta*real(ci[0]), 0)
  645. if i != n-1 {
  646. c128.DscalUnitary(beta, ci[1:])
  647. }
  648. }
  649. }
  650. } else {
  651. if beta == 0 {
  652. for i := 0; i < n; i++ {
  653. ci := c[i*ldc : i*ldc+i+1]
  654. for j := range ci {
  655. ci[j] = 0
  656. }
  657. }
  658. } else {
  659. for i := 0; i < n; i++ {
  660. ci := c[i*ldc : i*ldc+i+1]
  661. if i != 0 {
  662. c128.DscalUnitary(beta, ci[:i])
  663. }
  664. ci[i] = complex(beta*real(ci[i]), 0)
  665. }
  666. }
  667. }
  668. return
  669. }
  670. conjalpha := cmplx.Conj(alpha)
  671. cbeta := complex(beta, 0)
  672. if trans == blas.NoTrans {
  673. // Form C = alpha*A*Bᴴ + conj(alpha)*B*Aᴴ + beta*C.
  674. if uplo == blas.Upper {
  675. for i := 0; i < n; i++ {
  676. ci := c[i*ldc+i+1 : i*ldc+n]
  677. ai := a[i*lda : i*lda+k]
  678. bi := b[i*ldb : i*ldb+k]
  679. if beta == 0 {
  680. cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi)
  681. c[i*ldc+i] = complex(real(cii), 0)
  682. for jc := range ci {
  683. j := i + 1 + jc
  684. ci[jc] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi)
  685. }
  686. } else {
  687. cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi) + cbeta*c[i*ldc+i]
  688. c[i*ldc+i] = complex(real(cii), 0)
  689. for jc, cij := range ci {
  690. j := i + 1 + jc
  691. ci[jc] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi) + cbeta*cij
  692. }
  693. }
  694. }
  695. } else {
  696. for i := 0; i < n; i++ {
  697. ci := c[i*ldc : i*ldc+i]
  698. ai := a[i*lda : i*lda+k]
  699. bi := b[i*ldb : i*ldb+k]
  700. if beta == 0 {
  701. for j := range ci {
  702. ci[j] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi)
  703. }
  704. cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi)
  705. c[i*ldc+i] = complex(real(cii), 0)
  706. } else {
  707. for j, cij := range ci {
  708. ci[j] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi) + cbeta*cij
  709. }
  710. cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi) + cbeta*c[i*ldc+i]
  711. c[i*ldc+i] = complex(real(cii), 0)
  712. }
  713. }
  714. }
  715. } else {
  716. // Form C = alpha*Aᴴ*B + conj(alpha)*Bᴴ*A + beta*C.
  717. if uplo == blas.Upper {
  718. for i := 0; i < n; i++ {
  719. ci := c[i*ldc+i : i*ldc+n]
  720. switch {
  721. case beta == 0:
  722. for jc := range ci {
  723. ci[jc] = 0
  724. }
  725. case beta != 1:
  726. c128.DscalUnitary(beta, ci)
  727. ci[0] = complex(real(ci[0]), 0)
  728. default:
  729. ci[0] = complex(real(ci[0]), 0)
  730. }
  731. for j := 0; j < k; j++ {
  732. aji := a[j*lda+i]
  733. bji := b[j*ldb+i]
  734. if aji != 0 {
  735. c128.AxpyUnitary(alpha*cmplx.Conj(aji), b[j*ldb+i:j*ldb+n], ci)
  736. }
  737. if bji != 0 {
  738. c128.AxpyUnitary(conjalpha*cmplx.Conj(bji), a[j*lda+i:j*lda+n], ci)
  739. }
  740. }
  741. ci[0] = complex(real(ci[0]), 0)
  742. }
  743. } else {
  744. for i := 0; i < n; i++ {
  745. ci := c[i*ldc : i*ldc+i+1]
  746. switch {
  747. case beta == 0:
  748. for j := range ci {
  749. ci[j] = 0
  750. }
  751. case beta != 1:
  752. c128.DscalUnitary(beta, ci)
  753. ci[i] = complex(real(ci[i]), 0)
  754. default:
  755. ci[i] = complex(real(ci[i]), 0)
  756. }
  757. for j := 0; j < k; j++ {
  758. aji := a[j*lda+i]
  759. bji := b[j*ldb+i]
  760. if aji != 0 {
  761. c128.AxpyUnitary(alpha*cmplx.Conj(aji), b[j*ldb:j*ldb+i+1], ci)
  762. }
  763. if bji != 0 {
  764. c128.AxpyUnitary(conjalpha*cmplx.Conj(bji), a[j*lda:j*lda+i+1], ci)
  765. }
  766. }
  767. ci[i] = complex(real(ci[i]), 0)
  768. }
  769. }
  770. }
  771. }
  772. // Zsymm performs one of the matrix-matrix operations
  773. // C = alpha*A*B + beta*C if side == blas.Left
  774. // C = alpha*B*A + beta*C if side == blas.Right
  775. // where alpha and beta are scalars, A is an m×m or n×n symmetric matrix and B
  776. // and C are m×n matrices.
  777. func (Implementation) Zsymm(side blas.Side, uplo blas.Uplo, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
  778. na := m
  779. if side == blas.Right {
  780. na = n
  781. }
  782. switch {
  783. case side != blas.Left && side != blas.Right:
  784. panic(badSide)
  785. case uplo != blas.Lower && uplo != blas.Upper:
  786. panic(badUplo)
  787. case m < 0:
  788. panic(mLT0)
  789. case n < 0:
  790. panic(nLT0)
  791. case lda < max(1, na):
  792. panic(badLdA)
  793. case ldb < max(1, n):
  794. panic(badLdB)
  795. case ldc < max(1, n):
  796. panic(badLdC)
  797. }
  798. // Quick return if possible.
  799. if m == 0 || n == 0 {
  800. return
  801. }
  802. // For zero matrix size the following slice length checks are trivially satisfied.
  803. if len(a) < lda*(na-1)+na {
  804. panic(shortA)
  805. }
  806. if len(b) < ldb*(m-1)+n {
  807. panic(shortB)
  808. }
  809. if len(c) < ldc*(m-1)+n {
  810. panic(shortC)
  811. }
  812. // Quick return if possible.
  813. if alpha == 0 && beta == 1 {
  814. return
  815. }
  816. if alpha == 0 {
  817. if beta == 0 {
  818. for i := 0; i < m; i++ {
  819. ci := c[i*ldc : i*ldc+n]
  820. for j := range ci {
  821. ci[j] = 0
  822. }
  823. }
  824. } else {
  825. for i := 0; i < m; i++ {
  826. ci := c[i*ldc : i*ldc+n]
  827. c128.ScalUnitary(beta, ci)
  828. }
  829. }
  830. return
  831. }
  832. if side == blas.Left {
  833. // Form C = alpha*A*B + beta*C.
  834. for i := 0; i < m; i++ {
  835. atmp := alpha * a[i*lda+i]
  836. bi := b[i*ldb : i*ldb+n]
  837. ci := c[i*ldc : i*ldc+n]
  838. if beta == 0 {
  839. for j, bij := range bi {
  840. ci[j] = atmp * bij
  841. }
  842. } else {
  843. for j, bij := range bi {
  844. ci[j] = atmp*bij + beta*ci[j]
  845. }
  846. }
  847. if uplo == blas.Upper {
  848. for k := 0; k < i; k++ {
  849. atmp = alpha * a[k*lda+i]
  850. c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
  851. }
  852. for k := i + 1; k < m; k++ {
  853. atmp = alpha * a[i*lda+k]
  854. c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
  855. }
  856. } else {
  857. for k := 0; k < i; k++ {
  858. atmp = alpha * a[i*lda+k]
  859. c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
  860. }
  861. for k := i + 1; k < m; k++ {
  862. atmp = alpha * a[k*lda+i]
  863. c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
  864. }
  865. }
  866. }
  867. } else {
  868. // Form C = alpha*B*A + beta*C.
  869. if uplo == blas.Upper {
  870. for i := 0; i < m; i++ {
  871. for j := n - 1; j >= 0; j-- {
  872. abij := alpha * b[i*ldb+j]
  873. aj := a[j*lda+j+1 : j*lda+n]
  874. bi := b[i*ldb+j+1 : i*ldb+n]
  875. ci := c[i*ldc+j+1 : i*ldc+n]
  876. var tmp complex128
  877. for k, ajk := range aj {
  878. ci[k] += abij * ajk
  879. tmp += bi[k] * ajk
  880. }
  881. if beta == 0 {
  882. c[i*ldc+j] = abij*a[j*lda+j] + alpha*tmp
  883. } else {
  884. c[i*ldc+j] = abij*a[j*lda+j] + alpha*tmp + beta*c[i*ldc+j]
  885. }
  886. }
  887. }
  888. } else {
  889. for i := 0; i < m; i++ {
  890. for j := 0; j < n; j++ {
  891. abij := alpha * b[i*ldb+j]
  892. aj := a[j*lda : j*lda+j]
  893. bi := b[i*ldb : i*ldb+j]
  894. ci := c[i*ldc : i*ldc+j]
  895. var tmp complex128
  896. for k, ajk := range aj {
  897. ci[k] += abij * ajk
  898. tmp += bi[k] * ajk
  899. }
  900. if beta == 0 {
  901. c[i*ldc+j] = abij*a[j*lda+j] + alpha*tmp
  902. } else {
  903. c[i*ldc+j] = abij*a[j*lda+j] + alpha*tmp + beta*c[i*ldc+j]
  904. }
  905. }
  906. }
  907. }
  908. }
  909. }
  910. // Zsyrk performs one of the symmetric rank-k operations
  911. // C = alpha*A*Aᵀ + beta*C if trans == blas.NoTrans
  912. // C = alpha*Aᵀ*A + beta*C if trans == blas.Trans
  913. // where alpha and beta are scalars, C is an n×n symmetric matrix and A is
  914. // an n×k matrix in the first case and a k×n matrix in the second case.
  915. func (Implementation) Zsyrk(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, beta complex128, c []complex128, ldc int) {
  916. var rowA, colA int
  917. switch trans {
  918. default:
  919. panic(badTranspose)
  920. case blas.NoTrans:
  921. rowA, colA = n, k
  922. case blas.Trans:
  923. rowA, colA = k, n
  924. }
  925. switch {
  926. case uplo != blas.Lower && uplo != blas.Upper:
  927. panic(badUplo)
  928. case n < 0:
  929. panic(nLT0)
  930. case k < 0:
  931. panic(kLT0)
  932. case lda < max(1, colA):
  933. panic(badLdA)
  934. case ldc < max(1, n):
  935. panic(badLdC)
  936. }
  937. // Quick return if possible.
  938. if n == 0 {
  939. return
  940. }
  941. // For zero matrix size the following slice length checks are trivially satisfied.
  942. if len(a) < (rowA-1)*lda+colA {
  943. panic(shortA)
  944. }
  945. if len(c) < (n-1)*ldc+n {
  946. panic(shortC)
  947. }
  948. // Quick return if possible.
  949. if (alpha == 0 || k == 0) && beta == 1 {
  950. return
  951. }
  952. if alpha == 0 {
  953. if uplo == blas.Upper {
  954. if beta == 0 {
  955. for i := 0; i < n; i++ {
  956. ci := c[i*ldc+i : i*ldc+n]
  957. for j := range ci {
  958. ci[j] = 0
  959. }
  960. }
  961. } else {
  962. for i := 0; i < n; i++ {
  963. ci := c[i*ldc+i : i*ldc+n]
  964. c128.ScalUnitary(beta, ci)
  965. }
  966. }
  967. } else {
  968. if beta == 0 {
  969. for i := 0; i < n; i++ {
  970. ci := c[i*ldc : i*ldc+i+1]
  971. for j := range ci {
  972. ci[j] = 0
  973. }
  974. }
  975. } else {
  976. for i := 0; i < n; i++ {
  977. ci := c[i*ldc : i*ldc+i+1]
  978. c128.ScalUnitary(beta, ci)
  979. }
  980. }
  981. }
  982. return
  983. }
  984. if trans == blas.NoTrans {
  985. // Form C = alpha*A*Aᵀ + beta*C.
  986. if uplo == blas.Upper {
  987. for i := 0; i < n; i++ {
  988. ci := c[i*ldc+i : i*ldc+n]
  989. ai := a[i*lda : i*lda+k]
  990. if beta == 0 {
  991. for jc := range ci {
  992. j := i + jc
  993. ci[jc] = alpha * c128.DotuUnitary(ai, a[j*lda:j*lda+k])
  994. }
  995. } else {
  996. for jc, cij := range ci {
  997. j := i + jc
  998. ci[jc] = beta*cij + alpha*c128.DotuUnitary(ai, a[j*lda:j*lda+k])
  999. }
  1000. }
  1001. }
  1002. } else {
  1003. for i := 0; i < n; i++ {
  1004. ci := c[i*ldc : i*ldc+i+1]
  1005. ai := a[i*lda : i*lda+k]
  1006. if beta == 0 {
  1007. for j := range ci {
  1008. ci[j] = alpha * c128.DotuUnitary(ai, a[j*lda:j*lda+k])
  1009. }
  1010. } else {
  1011. for j, cij := range ci {
  1012. ci[j] = beta*cij + alpha*c128.DotuUnitary(ai, a[j*lda:j*lda+k])
  1013. }
  1014. }
  1015. }
  1016. }
  1017. } else {
  1018. // Form C = alpha*Aᵀ*A + beta*C.
  1019. if uplo == blas.Upper {
  1020. for i := 0; i < n; i++ {
  1021. ci := c[i*ldc+i : i*ldc+n]
  1022. switch {
  1023. case beta == 0:
  1024. for jc := range ci {
  1025. ci[jc] = 0
  1026. }
  1027. case beta != 1:
  1028. for jc := range ci {
  1029. ci[jc] *= beta
  1030. }
  1031. }
  1032. for j := 0; j < k; j++ {
  1033. aji := a[j*lda+i]
  1034. if aji != 0 {
  1035. c128.AxpyUnitary(alpha*aji, a[j*lda+i:j*lda+n], ci)
  1036. }
  1037. }
  1038. }
  1039. } else {
  1040. for i := 0; i < n; i++ {
  1041. ci := c[i*ldc : i*ldc+i+1]
  1042. switch {
  1043. case beta == 0:
  1044. for j := range ci {
  1045. ci[j] = 0
  1046. }
  1047. case beta != 1:
  1048. for j := range ci {
  1049. ci[j] *= beta
  1050. }
  1051. }
  1052. for j := 0; j < k; j++ {
  1053. aji := a[j*lda+i]
  1054. if aji != 0 {
  1055. c128.AxpyUnitary(alpha*aji, a[j*lda:j*lda+i+1], ci)
  1056. }
  1057. }
  1058. }
  1059. }
  1060. }
  1061. }
  1062. // Zsyr2k performs one of the symmetric rank-2k operations
  1063. // C = alpha*A*Bᵀ + alpha*B*Aᵀ + beta*C if trans == blas.NoTrans
  1064. // C = alpha*Aᵀ*B + alpha*Bᵀ*A + beta*C if trans == blas.Trans
  1065. // where alpha and beta are scalars, C is an n×n symmetric matrix and A and B
  1066. // are n×k matrices in the first case and k×n matrices in the second case.
  1067. func (Implementation) Zsyr2k(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
  1068. var row, col int
  1069. switch trans {
  1070. default:
  1071. panic(badTranspose)
  1072. case blas.NoTrans:
  1073. row, col = n, k
  1074. case blas.Trans:
  1075. row, col = k, n
  1076. }
  1077. switch {
  1078. case uplo != blas.Lower && uplo != blas.Upper:
  1079. panic(badUplo)
  1080. case n < 0:
  1081. panic(nLT0)
  1082. case k < 0:
  1083. panic(kLT0)
  1084. case lda < max(1, col):
  1085. panic(badLdA)
  1086. case ldb < max(1, col):
  1087. panic(badLdB)
  1088. case ldc < max(1, n):
  1089. panic(badLdC)
  1090. }
  1091. // Quick return if possible.
  1092. if n == 0 {
  1093. return
  1094. }
  1095. // For zero matrix size the following slice length checks are trivially satisfied.
  1096. if len(a) < (row-1)*lda+col {
  1097. panic(shortA)
  1098. }
  1099. if len(b) < (row-1)*ldb+col {
  1100. panic(shortB)
  1101. }
  1102. if len(c) < (n-1)*ldc+n {
  1103. panic(shortC)
  1104. }
  1105. // Quick return if possible.
  1106. if (alpha == 0 || k == 0) && beta == 1 {
  1107. return
  1108. }
  1109. if alpha == 0 {
  1110. if uplo == blas.Upper {
  1111. if beta == 0 {
  1112. for i := 0; i < n; i++ {
  1113. ci := c[i*ldc+i : i*ldc+n]
  1114. for j := range ci {
  1115. ci[j] = 0
  1116. }
  1117. }
  1118. } else {
  1119. for i := 0; i < n; i++ {
  1120. ci := c[i*ldc+i : i*ldc+n]
  1121. c128.ScalUnitary(beta, ci)
  1122. }
  1123. }
  1124. } else {
  1125. if beta == 0 {
  1126. for i := 0; i < n; i++ {
  1127. ci := c[i*ldc : i*ldc+i+1]
  1128. for j := range ci {
  1129. ci[j] = 0
  1130. }
  1131. }
  1132. } else {
  1133. for i := 0; i < n; i++ {
  1134. ci := c[i*ldc : i*ldc+i+1]
  1135. c128.ScalUnitary(beta, ci)
  1136. }
  1137. }
  1138. }
  1139. return
  1140. }
  1141. if trans == blas.NoTrans {
  1142. // Form C = alpha*A*Bᵀ + alpha*B*Aᵀ + beta*C.
  1143. if uplo == blas.Upper {
  1144. for i := 0; i < n; i++ {
  1145. ci := c[i*ldc+i : i*ldc+n]
  1146. ai := a[i*lda : i*lda+k]
  1147. bi := b[i*ldb : i*ldb+k]
  1148. if beta == 0 {
  1149. for jc := range ci {
  1150. j := i + jc
  1151. ci[jc] = alpha*c128.DotuUnitary(ai, b[j*ldb:j*ldb+k]) + alpha*c128.DotuUnitary(bi, a[j*lda:j*lda+k])
  1152. }
  1153. } else {
  1154. for jc, cij := range ci {
  1155. j := i + jc
  1156. ci[jc] = alpha*c128.DotuUnitary(ai, b[j*ldb:j*ldb+k]) + alpha*c128.DotuUnitary(bi, a[j*lda:j*lda+k]) + beta*cij
  1157. }
  1158. }
  1159. }
  1160. } else {
  1161. for i := 0; i < n; i++ {
  1162. ci := c[i*ldc : i*ldc+i+1]
  1163. ai := a[i*lda : i*lda+k]
  1164. bi := b[i*ldb : i*ldb+k]
  1165. if beta == 0 {
  1166. for j := range ci {
  1167. ci[j] = alpha*c128.DotuUnitary(ai, b[j*ldb:j*ldb+k]) + alpha*c128.DotuUnitary(bi, a[j*lda:j*lda+k])
  1168. }
  1169. } else {
  1170. for j, cij := range ci {
  1171. ci[j] = alpha*c128.DotuUnitary(ai, b[j*ldb:j*ldb+k]) + alpha*c128.DotuUnitary(bi, a[j*lda:j*lda+k]) + beta*cij
  1172. }
  1173. }
  1174. }
  1175. }
  1176. } else {
  1177. // Form C = alpha*Aᵀ*B + alpha*Bᵀ*A + beta*C.
  1178. if uplo == blas.Upper {
  1179. for i := 0; i < n; i++ {
  1180. ci := c[i*ldc+i : i*ldc+n]
  1181. switch {
  1182. case beta == 0:
  1183. for jc := range ci {
  1184. ci[jc] = 0
  1185. }
  1186. case beta != 1:
  1187. for jc := range ci {
  1188. ci[jc] *= beta
  1189. }
  1190. }
  1191. for j := 0; j < k; j++ {
  1192. aji := a[j*lda+i]
  1193. bji := b[j*ldb+i]
  1194. if aji != 0 {
  1195. c128.AxpyUnitary(alpha*aji, b[j*ldb+i:j*ldb+n], ci)
  1196. }
  1197. if bji != 0 {
  1198. c128.AxpyUnitary(alpha*bji, a[j*lda+i:j*lda+n], ci)
  1199. }
  1200. }
  1201. }
  1202. } else {
  1203. for i := 0; i < n; i++ {
  1204. ci := c[i*ldc : i*ldc+i+1]
  1205. switch {
  1206. case beta == 0:
  1207. for j := range ci {
  1208. ci[j] = 0
  1209. }
  1210. case beta != 1:
  1211. for j := range ci {
  1212. ci[j] *= beta
  1213. }
  1214. }
  1215. for j := 0; j < k; j++ {
  1216. aji := a[j*lda+i]
  1217. bji := b[j*ldb+i]
  1218. if aji != 0 {
  1219. c128.AxpyUnitary(alpha*aji, b[j*ldb:j*ldb+i+1], ci)
  1220. }
  1221. if bji != 0 {
  1222. c128.AxpyUnitary(alpha*bji, a[j*lda:j*lda+i+1], ci)
  1223. }
  1224. }
  1225. }
  1226. }
  1227. }
  1228. }
// Ztrmm performs one of the matrix-matrix operations
//
//	B = alpha * op(A) * B if side == blas.Left,
//	B = alpha * B * op(A) if side == blas.Right,
//
// where alpha is a scalar, B is an m×n matrix, A is a unit, or non-unit,
// upper or lower triangular matrix and op(A) is one of
//
//	op(A) = A  if trans == blas.NoTrans,
//	op(A) = Aᵀ if trans == blas.Trans,
//	op(A) = Aᴴ if trans == blas.ConjTrans.
//
// A is an m×m matrix when side == blas.Left and an n×n matrix when
// side == blas.Right. Only the triangle of A selected by uplo is referenced.
// The diagonal of A is read only when diag == blas.NonUnit; when
// diag == blas.Unit it is treated as all ones. B is overwritten in place with
// the result. If alpha is zero, B is set to zero without reading the
// elements of A.
//
// Ztrmm panics if side, uplo, trans or diag is invalid, if m or n is
// negative, or if lda or ldb is too small. When m and n are both positive it
// also panics if a or b is too short to hold the matrix it describes.
func (Implementation) Ztrmm(side blas.Side, uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int) {
	// na is the order of the triangular matrix A: it multiplies B from the
	// left (m×m) or from the right (n×n).
	na := m
	if side == blas.Right {
		na = n
	}
	switch {
	case side != blas.Left && side != blas.Right:
		panic(badSide)
	case uplo != blas.Lower && uplo != blas.Upper:
		panic(badUplo)
	case trans != blas.NoTrans && trans != blas.Trans && trans != blas.ConjTrans:
		panic(badTranspose)
	case diag != blas.Unit && diag != blas.NonUnit:
		panic(badDiag)
	case m < 0:
		panic(mLT0)
	case n < 0:
		panic(nLT0)
	case lda < max(1, na):
		panic(badLdA)
	case ldb < max(1, n):
		panic(badLdB)
	}
	// Quick return if possible.
	if m == 0 || n == 0 {
		return
	}
	// For zero matrix size the following slice length checks are trivially satisfied.
	if len(a) < (na-1)*lda+na {
		panic(shortA)
	}
	if len(b) < (m-1)*ldb+n {
		panic(shortB)
	}
	// Quick return if possible.
	if alpha == 0 {
		for i := 0; i < m; i++ {
			bi := b[i*ldb : i*ldb+n]
			for j := range bi {
				bi[j] = 0
			}
		}
		return
	}
	// noConj is false only for ConjTrans, the one case that conjugates A.
	noConj := trans != blas.ConjTrans
	// noUnit reports whether the stored diagonal of A must be read.
	noUnit := diag == blas.NonUnit
	if side == blas.Left {
		if trans == blas.NoTrans {
			// Form B = alpha*A*B.
			if uplo == blas.Upper {
				// Row i of the result combines rows i..m-1 of B. Working top
				// to bottom leaves the rows below i unmodified until read.
				for i := 0; i < m; i++ {
					aii := alpha
					if noUnit {
						aii *= a[i*lda+i]
					}
					bi := b[i*ldb : i*ldb+n]
					for j := range bi {
						bi[j] *= aii
					}
					for ja, aij := range a[i*lda+i+1 : i*lda+m] {
						j := ja + i + 1
						// Zero entries of A contribute nothing.
						if aij != 0 {
							c128.AxpyUnitary(alpha*aij, b[j*ldb:j*ldb+n], bi)
						}
					}
				}
			} else {
				// Row i of the result combines rows 0..i of B. Working bottom
				// to top leaves the rows above i unmodified until read.
				for i := m - 1; i >= 0; i-- {
					aii := alpha
					if noUnit {
						aii *= a[i*lda+i]
					}
					bi := b[i*ldb : i*ldb+n]
					for j := range bi {
						bi[j] *= aii
					}
					for j, aij := range a[i*lda : i*lda+i] {
						if aij != 0 {
							c128.AxpyUnitary(alpha*aij, b[j*ldb:j*ldb+n], bi)
						}
					}
				}
			}
		} else {
			// Form B = alpha*Aᵀ*B or B = alpha*Aᴴ*B.
			if uplo == blas.Upper {
				// op(A) is lower triangular here, so row k of B feeds rows
				// k..m-1 of the result. Working bottom to top, row k is added
				// into the rows below it while still unmodified. It is then
				// scaled by its own (possibly conjugated) diagonal term.
				for k := m - 1; k >= 0; k-- {
					bk := b[k*ldb : k*ldb+n]
					for ja, ajk := range a[k*lda+k+1 : k*lda+m] {
						if ajk == 0 {
							continue
						}
						j := k + 1 + ja
						if noConj {
							c128.AxpyUnitary(alpha*ajk, bk, b[j*ldb:j*ldb+n])
						} else {
							c128.AxpyUnitary(alpha*cmplx.Conj(ajk), bk, b[j*ldb:j*ldb+n])
						}
					}
					akk := alpha
					if noUnit {
						if noConj {
							akk *= a[k*lda+k]
						} else {
							akk *= cmplx.Conj(a[k*lda+k])
						}
					}
					// Skip the scaling when it would be a no-op.
					if akk != 1 {
						c128.ScalUnitary(akk, bk)
					}
				}
			} else {
				// op(A) is upper triangular here, so row k of B feeds rows
				// 0..k of the result. Working top to bottom, row k is added
				// into the rows above it before it is scaled itself.
				for k := 0; k < m; k++ {
					bk := b[k*ldb : k*ldb+n]
					for j, ajk := range a[k*lda : k*lda+k] {
						if ajk == 0 {
							continue
						}
						if noConj {
							c128.AxpyUnitary(alpha*ajk, bk, b[j*ldb:j*ldb+n])
						} else {
							c128.AxpyUnitary(alpha*cmplx.Conj(ajk), bk, b[j*ldb:j*ldb+n])
						}
					}
					akk := alpha
					if noUnit {
						if noConj {
							akk *= a[k*lda+k]
						} else {
							akk *= cmplx.Conj(a[k*lda+k])
						}
					}
					if akk != 1 {
						c128.ScalUnitary(akk, bk)
					}
				}
			}
		}
	} else {
		if trans == blas.NoTrans {
			// Form B = alpha*B*A.
			if uplo == blas.Upper {
				// Each row of B is transformed independently. Element k of
				// the row feeds elements k..n-1 of the result, so k runs
				// downward and bi[k] is still original when it is read.
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					for k := n - 1; k >= 0; k-- {
						abik := alpha * bi[k]
						// An exactly zero product contributes nothing.
						if abik == 0 {
							continue
						}
						bi[k] = abik
						if noUnit {
							bi[k] *= a[k*lda+k]
						}
						c128.AxpyUnitary(abik, a[k*lda+k+1:k*lda+n], bi[k+1:])
					}
				}
			} else {
				// Element k of the row feeds elements 0..k of the result, so
				// k runs upward and bi[k] is still original when it is read.
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					for k := 0; k < n; k++ {
						abik := alpha * bi[k]
						if abik == 0 {
							continue
						}
						bi[k] = abik
						if noUnit {
							bi[k] *= a[k*lda+k]
						}
						c128.AxpyUnitary(abik, a[k*lda:k*lda+k], bi[:k])
					}
				}
			}
		} else {
			// Form B = alpha*B*Aᵀ or B = alpha*B*Aᴴ.
			if uplo == blas.Upper {
				// Element j of the result is a dot product of row j of A with
				// elements j..n-1 of the row of B. j runs upward, so the
				// elements to its right are still unmodified.
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					for j, bij := range bi {
						if noConj {
							if noUnit {
								bij *= a[j*lda+j]
							}
							bij += c128.DotuUnitary(a[j*lda+j+1:j*lda+n], bi[j+1:n])
						} else {
							// ConjTrans: conjugated diagonal, conjugating dot product.
							if noUnit {
								bij *= cmplx.Conj(a[j*lda+j])
							}
							bij += c128.DotcUnitary(a[j*lda+j+1:j*lda+n], bi[j+1:n])
						}
						bi[j] = alpha * bij
					}
				}
			} else {
				// Element j of the result uses elements 0..j of the row of B.
				// j runs downward, so the elements to its left are still
				// unmodified.
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					for j := n - 1; j >= 0; j-- {
						bij := bi[j]
						if noConj {
							if noUnit {
								bij *= a[j*lda+j]
							}
							bij += c128.DotuUnitary(a[j*lda:j*lda+j], bi[:j])
						} else {
							if noUnit {
								bij *= cmplx.Conj(a[j*lda+j])
							}
							bij += c128.DotcUnitary(a[j*lda:j*lda+j], bi[:j])
						}
						bi[j] = alpha * bij
					}
				}
			}
		}
	}
}
// Ztrsm solves one of the matrix equations
//
//	op(A) * X = alpha * B if side == blas.Left,
//	X * op(A) = alpha * B if side == blas.Right,
//
// where alpha is a scalar, X and B are m×n matrices, A is a unit or
// non-unit, upper or lower triangular matrix and op(A) is one of
//
//	op(A) = A  if transA == blas.NoTrans,
//	op(A) = Aᵀ if transA == blas.Trans,
//	op(A) = Aᴴ if transA == blas.ConjTrans.
//
// On return the matrix X is overwritten on B.
//
// A is an m×m matrix when side == blas.Left and an n×n matrix when
// side == blas.Right. Only the triangle of A selected by uplo is referenced.
// Its diagonal is read only when diag == blas.NonUnit. No test for
// singularity is performed: callers must ensure the diagonal of A is non-zero
// when diag == blas.NonUnit. If alpha is zero, B is set to zero without
// reading the elements of A.
//
// Ztrsm panics if side, uplo, transA or diag is invalid, if m or n is
// negative, or if lda or ldb is too small. When m and n are both positive it
// also panics if a or b is too short to hold the matrix it describes.
func (Implementation) Ztrsm(side blas.Side, uplo blas.Uplo, transA blas.Transpose, diag blas.Diag, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int) {
	// na is the order of the triangular matrix A.
	na := m
	if side == blas.Right {
		na = n
	}
	switch {
	case side != blas.Left && side != blas.Right:
		panic(badSide)
	case uplo != blas.Lower && uplo != blas.Upper:
		panic(badUplo)
	case transA != blas.NoTrans && transA != blas.Trans && transA != blas.ConjTrans:
		panic(badTranspose)
	case diag != blas.Unit && diag != blas.NonUnit:
		panic(badDiag)
	case m < 0:
		panic(mLT0)
	case n < 0:
		panic(nLT0)
	case lda < max(1, na):
		panic(badLdA)
	case ldb < max(1, n):
		panic(badLdB)
	}
	// Quick return if possible.
	if m == 0 || n == 0 {
		return
	}
	// For zero matrix size the following slice length checks are trivially satisfied.
	if len(a) < (na-1)*lda+na {
		panic(shortA)
	}
	if len(b) < (m-1)*ldb+n {
		panic(shortB)
	}
	// With alpha == 0 the solution is X = 0; A is not read.
	if alpha == 0 {
		for i := 0; i < m; i++ {
			for j := 0; j < n; j++ {
				b[i*ldb+j] = 0
			}
		}
		return
	}
	// noConj is false only for ConjTrans, the one case that conjugates A.
	noConj := transA != blas.ConjTrans
	// noUnit reports whether the stored diagonal of A must be read.
	noUnit := diag == blas.NonUnit
	if side == blas.Left {
		if transA == blas.NoTrans {
			// Form B = alpha*inv(A)*B.
			if uplo == blas.Upper {
				// Back substitution: row i of X depends on rows i+1..m-1,
				// which are already solved when working bottom to top.
				for i := m - 1; i >= 0; i-- {
					bi := b[i*ldb : i*ldb+n]
					if alpha != 1 {
						c128.ScalUnitary(alpha, bi)
					}
					for ka, aik := range a[i*lda+i+1 : i*lda+m] {
						k := i + 1 + ka
						if aik != 0 {
							c128.AxpyUnitary(-aik, b[k*ldb:k*ldb+n], bi)
						}
					}
					// The diagonal division is done by scaling with the
					// reciprocal of A[i][i].
					if noUnit {
						c128.ScalUnitary(1/a[i*lda+i], bi)
					}
				}
			} else {
				// Forward substitution: row i of X depends on rows 0..i-1,
				// which are already solved when working top to bottom.
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					if alpha != 1 {
						c128.ScalUnitary(alpha, bi)
					}
					for j, aij := range a[i*lda : i*lda+i] {
						if aij != 0 {
							c128.AxpyUnitary(-aij, b[j*ldb:j*ldb+n], bi)
						}
					}
					if noUnit {
						c128.ScalUnitary(1/a[i*lda+i], bi)
					}
				}
			}
		} else {
			// Form B = alpha*inv(Aᵀ)*B or B = alpha*inv(Aᴴ)*B.
			//
			// Both paths first solve op(A)*Y = B row by row. Each row is
			// eliminated from the remaining rows while still unscaled, and
			// only then multiplied by alpha. By linearity this yields
			// X = alpha*Y.
			if uplo == blas.Upper {
				// op(A) is lower triangular: solve top to bottom.
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					if noUnit {
						if noConj {
							c128.ScalUnitary(1/a[i*lda+i], bi)
						} else {
							c128.ScalUnitary(1/cmplx.Conj(a[i*lda+i]), bi)
						}
					}
					for ja, aij := range a[i*lda+i+1 : i*lda+m] {
						if aij == 0 {
							continue
						}
						j := i + 1 + ja
						if noConj {
							c128.AxpyUnitary(-aij, bi, b[j*ldb:j*ldb+n])
						} else {
							c128.AxpyUnitary(-cmplx.Conj(aij), bi, b[j*ldb:j*ldb+n])
						}
					}
					if alpha != 1 {
						c128.ScalUnitary(alpha, bi)
					}
				}
			} else {
				// op(A) is upper triangular: solve bottom to top.
				for i := m - 1; i >= 0; i-- {
					bi := b[i*ldb : i*ldb+n]
					if noUnit {
						if noConj {
							c128.ScalUnitary(1/a[i*lda+i], bi)
						} else {
							c128.ScalUnitary(1/cmplx.Conj(a[i*lda+i]), bi)
						}
					}
					for j, aij := range a[i*lda : i*lda+i] {
						if aij == 0 {
							continue
						}
						if noConj {
							c128.AxpyUnitary(-aij, bi, b[j*ldb:j*ldb+n])
						} else {
							c128.AxpyUnitary(-cmplx.Conj(aij), bi, b[j*ldb:j*ldb+n])
						}
					}
					if alpha != 1 {
						c128.ScalUnitary(alpha, bi)
					}
				}
			}
		}
	} else {
		if transA == blas.NoTrans {
			// Form B = alpha*B*inv(A).
			if uplo == blas.Upper {
				// Each row is solved independently as x*A = alpha*b.
				// x[j] depends on x[0..j-1], so j runs upward. Once x[j] is
				// known it is eliminated from the elements to its right.
				// The range re-reads bi[j] on every iteration, so bij
				// reflects earlier eliminations.
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					if alpha != 1 {
						c128.ScalUnitary(alpha, bi)
					}
					for j, bij := range bi {
						// A zero entry has nothing to eliminate.
						if bij == 0 {
							continue
						}
						if noUnit {
							bi[j] /= a[j*lda+j]
						}
						c128.AxpyUnitary(-bi[j], a[j*lda+j+1:j*lda+n], bi[j+1:n])
					}
				}
			} else {
				// x[j] depends on x[j+1..n-1], so j runs downward. Each
				// solved x[j] is eliminated from the elements to its left.
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					if alpha != 1 {
						c128.ScalUnitary(alpha, bi)
					}
					for j := n - 1; j >= 0; j-- {
						if bi[j] == 0 {
							continue
						}
						if noUnit {
							bi[j] /= a[j*lda+j]
						}
						c128.AxpyUnitary(-bi[j], a[j*lda:j*lda+j], bi[:j])
					}
				}
			}
		} else {
			// Form B = alpha*B*inv(Aᵀ) or B = alpha*B*inv(Aᴴ).
			if uplo == blas.Upper {
				// x[j] = (alpha*b[j] - dot(A[j, j+1:], x[j+1:])) / A[j][j].
				// j runs downward, so x[j+1:] is already solved.
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					for j := n - 1; j >= 0; j-- {
						bij := alpha * bi[j]
						if noConj {
							bij -= c128.DotuUnitary(a[j*lda+j+1:j*lda+n], bi[j+1:n])
							if noUnit {
								bij /= a[j*lda+j]
							}
						} else {
							// ConjTrans: conjugating dot product and
							// division by the conjugated diagonal.
							bij -= c128.DotcUnitary(a[j*lda+j+1:j*lda+n], bi[j+1:n])
							if noUnit {
								bij /= cmplx.Conj(a[j*lda+j])
							}
						}
						bi[j] = bij
					}
				}
			} else {
				// x[j] = (alpha*b[j] - dot(A[j, :j], x[:j])) / A[j][j].
				// j runs upward, so x[:j] is already solved.
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					for j, bij := range bi {
						bij *= alpha
						if noConj {
							bij -= c128.DotuUnitary(a[j*lda:j*lda+j], bi[:j])
							if noUnit {
								bij /= a[j*lda+j]
							}
						} else {
							bij -= c128.DotcUnitary(a[j*lda:j*lda+j], bi[:j])
							if noUnit {
								bij /= cmplx.Conj(a[j*lda+j])
							}
						}
						bi[j] = bij
					}
				}
			}
		}
	}
}