12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728 |
- // Copyright ©2019 The Gonum Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package gonum
- import (
- "math/cmplx"
- "gonum.org/v1/gonum/blas"
- "gonum.org/v1/gonum/internal/asm/c128"
- )
- var _ blas.Complex128Level3 = Implementation{}
- // Zgemm performs one of the matrix-matrix operations
- // C = alpha * op(A) * op(B) + beta * C
- // where op(X) is one of
- // op(X) = X or op(X) = Xᵀ or op(X) = Xᴴ,
- // alpha and beta are scalars, and A, B and C are matrices, with op(A) an m×k matrix,
- // op(B) a k×n matrix and C an m×n matrix.
- func (Implementation) Zgemm(tA, tB blas.Transpose, m, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
- switch tA {
- default:
- panic(badTranspose)
- case blas.NoTrans, blas.Trans, blas.ConjTrans:
- }
- switch tB {
- default:
- panic(badTranspose)
- case blas.NoTrans, blas.Trans, blas.ConjTrans:
- }
- switch {
- case m < 0:
- panic(mLT0)
- case n < 0:
- panic(nLT0)
- case k < 0:
- panic(kLT0)
- }
- rowA, colA := m, k
- if tA != blas.NoTrans {
- rowA, colA = k, m
- }
- if lda < max(1, colA) {
- panic(badLdA)
- }
- rowB, colB := k, n
- if tB != blas.NoTrans {
- rowB, colB = n, k
- }
- if ldb < max(1, colB) {
- panic(badLdB)
- }
- if ldc < max(1, n) {
- panic(badLdC)
- }
- // Quick return if possible.
- if m == 0 || n == 0 {
- return
- }
- // For zero matrix size the following slice length checks are trivially satisfied.
- if len(a) < (rowA-1)*lda+colA {
- panic(shortA)
- }
- if len(b) < (rowB-1)*ldb+colB {
- panic(shortB)
- }
- if len(c) < (m-1)*ldc+n {
- panic(shortC)
- }
- // Quick return if possible.
- if (alpha == 0 || k == 0) && beta == 1 {
- return
- }
- if alpha == 0 {
- if beta == 0 {
- for i := 0; i < m; i++ {
- for j := 0; j < n; j++ {
- c[i*ldc+j] = 0
- }
- }
- } else {
- for i := 0; i < m; i++ {
- for j := 0; j < n; j++ {
- c[i*ldc+j] *= beta
- }
- }
- }
- return
- }
- switch tA {
- case blas.NoTrans:
- switch tB {
- case blas.NoTrans:
- // Form C = alpha * A * B + beta * C.
- for i := 0; i < m; i++ {
- switch {
- case beta == 0:
- for j := 0; j < n; j++ {
- c[i*ldc+j] = 0
- }
- case beta != 1:
- for j := 0; j < n; j++ {
- c[i*ldc+j] *= beta
- }
- }
- for l := 0; l < k; l++ {
- tmp := alpha * a[i*lda+l]
- for j := 0; j < n; j++ {
- c[i*ldc+j] += tmp * b[l*ldb+j]
- }
- }
- }
- case blas.Trans:
- // Form C = alpha * A * Bᵀ + beta * C.
- for i := 0; i < m; i++ {
- switch {
- case beta == 0:
- for j := 0; j < n; j++ {
- c[i*ldc+j] = 0
- }
- case beta != 1:
- for j := 0; j < n; j++ {
- c[i*ldc+j] *= beta
- }
- }
- for l := 0; l < k; l++ {
- tmp := alpha * a[i*lda+l]
- for j := 0; j < n; j++ {
- c[i*ldc+j] += tmp * b[j*ldb+l]
- }
- }
- }
- case blas.ConjTrans:
- // Form C = alpha * A * Bᴴ + beta * C.
- for i := 0; i < m; i++ {
- switch {
- case beta == 0:
- for j := 0; j < n; j++ {
- c[i*ldc+j] = 0
- }
- case beta != 1:
- for j := 0; j < n; j++ {
- c[i*ldc+j] *= beta
- }
- }
- for l := 0; l < k; l++ {
- tmp := alpha * a[i*lda+l]
- for j := 0; j < n; j++ {
- c[i*ldc+j] += tmp * cmplx.Conj(b[j*ldb+l])
- }
- }
- }
- }
- case blas.Trans:
- switch tB {
- case blas.NoTrans:
- // Form C = alpha * Aᵀ * B + beta * C.
- for i := 0; i < m; i++ {
- for j := 0; j < n; j++ {
- var tmp complex128
- for l := 0; l < k; l++ {
- tmp += a[l*lda+i] * b[l*ldb+j]
- }
- if beta == 0 {
- c[i*ldc+j] = alpha * tmp
- } else {
- c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
- }
- }
- }
- case blas.Trans:
- // Form C = alpha * Aᵀ * Bᵀ + beta * C.
- for i := 0; i < m; i++ {
- for j := 0; j < n; j++ {
- var tmp complex128
- for l := 0; l < k; l++ {
- tmp += a[l*lda+i] * b[j*ldb+l]
- }
- if beta == 0 {
- c[i*ldc+j] = alpha * tmp
- } else {
- c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
- }
- }
- }
- case blas.ConjTrans:
- // Form C = alpha * Aᵀ * Bᴴ + beta * C.
- for i := 0; i < m; i++ {
- for j := 0; j < n; j++ {
- var tmp complex128
- for l := 0; l < k; l++ {
- tmp += a[l*lda+i] * cmplx.Conj(b[j*ldb+l])
- }
- if beta == 0 {
- c[i*ldc+j] = alpha * tmp
- } else {
- c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
- }
- }
- }
- }
- case blas.ConjTrans:
- switch tB {
- case blas.NoTrans:
- // Form C = alpha * Aᴴ * B + beta * C.
- for i := 0; i < m; i++ {
- for j := 0; j < n; j++ {
- var tmp complex128
- for l := 0; l < k; l++ {
- tmp += cmplx.Conj(a[l*lda+i]) * b[l*ldb+j]
- }
- if beta == 0 {
- c[i*ldc+j] = alpha * tmp
- } else {
- c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
- }
- }
- }
- case blas.Trans:
- // Form C = alpha * Aᴴ * Bᵀ + beta * C.
- for i := 0; i < m; i++ {
- for j := 0; j < n; j++ {
- var tmp complex128
- for l := 0; l < k; l++ {
- tmp += cmplx.Conj(a[l*lda+i]) * b[j*ldb+l]
- }
- if beta == 0 {
- c[i*ldc+j] = alpha * tmp
- } else {
- c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
- }
- }
- }
- case blas.ConjTrans:
- // Form C = alpha * Aᴴ * Bᴴ + beta * C.
- for i := 0; i < m; i++ {
- for j := 0; j < n; j++ {
- var tmp complex128
- for l := 0; l < k; l++ {
- tmp += cmplx.Conj(a[l*lda+i]) * cmplx.Conj(b[j*ldb+l])
- }
- if beta == 0 {
- c[i*ldc+j] = alpha * tmp
- } else {
- c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
- }
- }
- }
- }
- }
- }
// Zhemm performs one of the matrix-matrix operations
//
//	C = alpha*A*B + beta*C  if side == blas.Left
//	C = alpha*B*A + beta*C  if side == blas.Right
//
// where alpha and beta are scalars, A is an m×m or n×n hermitian matrix and B
// and C are m×n matrices. The imaginary parts of the diagonal elements of A are
// assumed to be zero; they are never read.
//
// Only the triangle of A selected by uplo is referenced, the other triangle
// being implied by A = Aᴴ. All matrices are stored row-major with leading
// dimensions lda, ldb and ldc. When beta is zero the input values of C are not
// read.
func (Implementation) Zhemm(side blas.Side, uplo blas.Uplo, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
	// na is the order of the square matrix A: m when it multiplies from the
	// left, n when it multiplies from the right.
	na := m
	if side == blas.Right {
		na = n
	}
	switch {
	case side != blas.Left && side != blas.Right:
		panic(badSide)
	case uplo != blas.Lower && uplo != blas.Upper:
		panic(badUplo)
	case m < 0:
		panic(mLT0)
	case n < 0:
		panic(nLT0)
	case lda < max(1, na):
		panic(badLdA)
	case ldb < max(1, n):
		panic(badLdB)
	case ldc < max(1, n):
		panic(badLdC)
	}

	// Quick return if possible.
	if m == 0 || n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if len(a) < lda*(na-1)+na {
		panic(shortA)
	}
	if len(b) < ldb*(m-1)+n {
		panic(shortB)
	}
	if len(c) < ldc*(m-1)+n {
		panic(shortC)
	}

	// Quick return if possible: C is left unchanged.
	if alpha == 0 && beta == 1 {
		return
	}

	if alpha == 0 {
		// C = beta*C. For beta == 0 zero is assigned rather than multiplied in,
		// so the previous contents of C (including NaN) are discarded.
		if beta == 0 {
			for i := 0; i < m; i++ {
				ci := c[i*ldc : i*ldc+n]
				for j := range ci {
					ci[j] = 0
				}
			}
		} else {
			for i := 0; i < m; i++ {
				ci := c[i*ldc : i*ldc+n]
				c128.ScalUnitary(beta, ci)
			}
		}
		return
	}

	if side == blas.Left {
		// Form C = alpha*A*B + beta*C.
		// Row i of C is initialised with the diagonal term alpha*A[i][i]*B[i,:]
		// (plus beta*C[i,:] when beta != 0); then every off-diagonal A[i][k]
		// adds alpha*A[i][k]*B[k,:] to it.
		for i := 0; i < m; i++ {
			// Only the real part of the diagonal element is used.
			atmp := alpha * complex(real(a[i*lda+i]), 0)
			bi := b[i*ldb : i*ldb+n]
			ci := c[i*ldc : i*ldc+n]
			if beta == 0 {
				for j, bij := range bi {
					ci[j] = atmp * bij
				}
			} else {
				for j, bij := range bi {
					ci[j] = atmp*bij + beta*ci[j]
				}
			}
			if uplo == blas.Upper {
				// A[i][k] with k < i is not stored; it equals conj(A[k][i]).
				for k := 0; k < i; k++ {
					atmp = alpha * cmplx.Conj(a[k*lda+i])
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
				for k := i + 1; k < m; k++ {
					atmp = alpha * a[i*lda+k]
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
			} else {
				for k := 0; k < i; k++ {
					atmp = alpha * a[i*lda+k]
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
				// A[i][k] with k > i is not stored; it equals conj(A[k][i]).
				for k := i + 1; k < m; k++ {
					atmp = alpha * cmplx.Conj(a[k*lda+i])
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
			}
		}
	} else {
		// Form C = alpha*B*A + beta*C.
		if uplo == blas.Upper {
			for i := 0; i < m; i++ {
				// Columns are visited right to left so that C[i][j+1:], which
				// receives alpha*B[i][j]*A[j][j+1:] below, has already been
				// initialised (including its beta term) before accumulation.
				for j := n - 1; j >= 0; j-- {
					abij := alpha * b[i*ldb+j]
					aj := a[j*lda+j+1 : j*lda+n]
					bi := b[i*ldb+j+1 : i*ldb+n]
					ci := c[i*ldc+j+1 : i*ldc+n]
					// tmp gathers sum_{k>j} B[i][k]*conj(A[j][k]), i.e. the
					// contributions of the unstored lower elements A[k][j].
					var tmp complex128
					for k, ajk := range aj {
						ci[k] += abij * ajk
						tmp += bi[k] * cmplx.Conj(ajk)
					}
					ajj := complex(real(a[j*lda+j]), 0)
					if beta == 0 {
						c[i*ldc+j] = abij*ajj + alpha*tmp
					} else {
						c[i*ldc+j] = abij*ajj + alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		} else {
			for i := 0; i < m; i++ {
				// Columns are visited left to right so that C[i][:j] has
				// already been initialised before it is accumulated into.
				for j := 0; j < n; j++ {
					abij := alpha * b[i*ldb+j]
					aj := a[j*lda : j*lda+j]
					bi := b[i*ldb : i*ldb+j]
					ci := c[i*ldc : i*ldc+j]
					// tmp gathers sum_{k<j} B[i][k]*conj(A[j][k]), i.e. the
					// contributions of the unstored upper elements A[k][j].
					var tmp complex128
					for k, ajk := range aj {
						ci[k] += abij * ajk
						tmp += bi[k] * cmplx.Conj(ajk)
					}
					ajj := complex(real(a[j*lda+j]), 0)
					if beta == 0 {
						c[i*ldc+j] = abij*ajj + alpha*tmp
					} else {
						c[i*ldc+j] = abij*ajj + alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		}
	}
}
- // Zherk performs one of the hermitian rank-k operations
- // C = alpha*A*Aᴴ + beta*C if trans == blas.NoTrans
- // C = alpha*Aᴴ*A + beta*C if trans == blas.ConjTrans
- // where alpha and beta are real scalars, C is an n×n hermitian matrix and A is
- // an n×k matrix in the first case and a k×n matrix in the second case.
- //
- // The imaginary parts of the diagonal elements of C are assumed to be zero, and
- // on return they will be set to zero.
- func (Implementation) Zherk(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha float64, a []complex128, lda int, beta float64, c []complex128, ldc int) {
- var rowA, colA int
- switch trans {
- default:
- panic(badTranspose)
- case blas.NoTrans:
- rowA, colA = n, k
- case blas.ConjTrans:
- rowA, colA = k, n
- }
- switch {
- case uplo != blas.Lower && uplo != blas.Upper:
- panic(badUplo)
- case n < 0:
- panic(nLT0)
- case k < 0:
- panic(kLT0)
- case lda < max(1, colA):
- panic(badLdA)
- case ldc < max(1, n):
- panic(badLdC)
- }
- // Quick return if possible.
- if n == 0 {
- return
- }
- // For zero matrix size the following slice length checks are trivially satisfied.
- if len(a) < (rowA-1)*lda+colA {
- panic(shortA)
- }
- if len(c) < (n-1)*ldc+n {
- panic(shortC)
- }
- // Quick return if possible.
- if (alpha == 0 || k == 0) && beta == 1 {
- return
- }
- if alpha == 0 {
- if uplo == blas.Upper {
- if beta == 0 {
- for i := 0; i < n; i++ {
- ci := c[i*ldc+i : i*ldc+n]
- for j := range ci {
- ci[j] = 0
- }
- }
- } else {
- for i := 0; i < n; i++ {
- ci := c[i*ldc+i : i*ldc+n]
- ci[0] = complex(beta*real(ci[0]), 0)
- if i != n-1 {
- c128.DscalUnitary(beta, ci[1:])
- }
- }
- }
- } else {
- if beta == 0 {
- for i := 0; i < n; i++ {
- ci := c[i*ldc : i*ldc+i+1]
- for j := range ci {
- ci[j] = 0
- }
- }
- } else {
- for i := 0; i < n; i++ {
- ci := c[i*ldc : i*ldc+i+1]
- if i != 0 {
- c128.DscalUnitary(beta, ci[:i])
- }
- ci[i] = complex(beta*real(ci[i]), 0)
- }
- }
- }
- return
- }
- calpha := complex(alpha, 0)
- if trans == blas.NoTrans {
- // Form C = alpha*A*Aᴴ + beta*C.
- cbeta := complex(beta, 0)
- if uplo == blas.Upper {
- for i := 0; i < n; i++ {
- ci := c[i*ldc+i : i*ldc+n]
- ai := a[i*lda : i*lda+k]
- switch {
- case beta == 0:
- // Handle the i-th diagonal element of C.
- ci[0] = complex(alpha*real(c128.DotcUnitary(ai, ai)), 0)
- // Handle the remaining elements on the i-th row of C.
- for jc := range ci[1:] {
- j := i + 1 + jc
- ci[jc+1] = calpha * c128.DotcUnitary(a[j*lda:j*lda+k], ai)
- }
- case beta != 1:
- cii := calpha*c128.DotcUnitary(ai, ai) + cbeta*ci[0]
- ci[0] = complex(real(cii), 0)
- for jc, cij := range ci[1:] {
- j := i + 1 + jc
- ci[jc+1] = calpha*c128.DotcUnitary(a[j*lda:j*lda+k], ai) + cbeta*cij
- }
- default:
- cii := calpha*c128.DotcUnitary(ai, ai) + ci[0]
- ci[0] = complex(real(cii), 0)
- for jc, cij := range ci[1:] {
- j := i + 1 + jc
- ci[jc+1] = calpha*c128.DotcUnitary(a[j*lda:j*lda+k], ai) + cij
- }
- }
- }
- } else {
- for i := 0; i < n; i++ {
- ci := c[i*ldc : i*ldc+i+1]
- ai := a[i*lda : i*lda+k]
- switch {
- case beta == 0:
- // Handle the first i-1 elements on the i-th row of C.
- for j := range ci[:i] {
- ci[j] = calpha * c128.DotcUnitary(a[j*lda:j*lda+k], ai)
- }
- // Handle the i-th diagonal element of C.
- ci[i] = complex(alpha*real(c128.DotcUnitary(ai, ai)), 0)
- case beta != 1:
- for j, cij := range ci[:i] {
- ci[j] = calpha*c128.DotcUnitary(a[j*lda:j*lda+k], ai) + cbeta*cij
- }
- cii := calpha*c128.DotcUnitary(ai, ai) + cbeta*ci[i]
- ci[i] = complex(real(cii), 0)
- default:
- for j, cij := range ci[:i] {
- ci[j] = calpha*c128.DotcUnitary(a[j*lda:j*lda+k], ai) + cij
- }
- cii := calpha*c128.DotcUnitary(ai, ai) + ci[i]
- ci[i] = complex(real(cii), 0)
- }
- }
- }
- } else {
- // Form C = alpha*Aᴴ*A + beta*C.
- if uplo == blas.Upper {
- for i := 0; i < n; i++ {
- ci := c[i*ldc+i : i*ldc+n]
- switch {
- case beta == 0:
- for jc := range ci {
- ci[jc] = 0
- }
- case beta != 1:
- c128.DscalUnitary(beta, ci)
- ci[0] = complex(real(ci[0]), 0)
- default:
- ci[0] = complex(real(ci[0]), 0)
- }
- for j := 0; j < k; j++ {
- aji := cmplx.Conj(a[j*lda+i])
- if aji != 0 {
- c128.AxpyUnitary(calpha*aji, a[j*lda+i:j*lda+n], ci)
- }
- }
- c[i*ldc+i] = complex(real(c[i*ldc+i]), 0)
- }
- } else {
- for i := 0; i < n; i++ {
- ci := c[i*ldc : i*ldc+i+1]
- switch {
- case beta == 0:
- for j := range ci {
- ci[j] = 0
- }
- case beta != 1:
- c128.DscalUnitary(beta, ci)
- ci[i] = complex(real(ci[i]), 0)
- default:
- ci[i] = complex(real(ci[i]), 0)
- }
- for j := 0; j < k; j++ {
- aji := cmplx.Conj(a[j*lda+i])
- if aji != 0 {
- c128.AxpyUnitary(calpha*aji, a[j*lda:j*lda+i+1], ci)
- }
- }
- c[i*ldc+i] = complex(real(c[i*ldc+i]), 0)
- }
- }
- }
- }
- // Zher2k performs one of the hermitian rank-2k operations
- // C = alpha*A*Bᴴ + conj(alpha)*B*Aᴴ + beta*C if trans == blas.NoTrans
- // C = alpha*Aᴴ*B + conj(alpha)*Bᴴ*A + beta*C if trans == blas.ConjTrans
- // where alpha and beta are scalars with beta real, C is an n×n hermitian matrix
- // and A and B are n×k matrices in the first case and k×n matrices in the second case.
- //
- // The imaginary parts of the diagonal elements of C are assumed to be zero, and
- // on return they will be set to zero.
- func (Implementation) Zher2k(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta float64, c []complex128, ldc int) {
- var row, col int
- switch trans {
- default:
- panic(badTranspose)
- case blas.NoTrans:
- row, col = n, k
- case blas.ConjTrans:
- row, col = k, n
- }
- switch {
- case uplo != blas.Lower && uplo != blas.Upper:
- panic(badUplo)
- case n < 0:
- panic(nLT0)
- case k < 0:
- panic(kLT0)
- case lda < max(1, col):
- panic(badLdA)
- case ldb < max(1, col):
- panic(badLdB)
- case ldc < max(1, n):
- panic(badLdC)
- }
- // Quick return if possible.
- if n == 0 {
- return
- }
- // For zero matrix size the following slice length checks are trivially satisfied.
- if len(a) < (row-1)*lda+col {
- panic(shortA)
- }
- if len(b) < (row-1)*ldb+col {
- panic(shortB)
- }
- if len(c) < (n-1)*ldc+n {
- panic(shortC)
- }
- // Quick return if possible.
- if (alpha == 0 || k == 0) && beta == 1 {
- return
- }
- if alpha == 0 {
- if uplo == blas.Upper {
- if beta == 0 {
- for i := 0; i < n; i++ {
- ci := c[i*ldc+i : i*ldc+n]
- for j := range ci {
- ci[j] = 0
- }
- }
- } else {
- for i := 0; i < n; i++ {
- ci := c[i*ldc+i : i*ldc+n]
- ci[0] = complex(beta*real(ci[0]), 0)
- if i != n-1 {
- c128.DscalUnitary(beta, ci[1:])
- }
- }
- }
- } else {
- if beta == 0 {
- for i := 0; i < n; i++ {
- ci := c[i*ldc : i*ldc+i+1]
- for j := range ci {
- ci[j] = 0
- }
- }
- } else {
- for i := 0; i < n; i++ {
- ci := c[i*ldc : i*ldc+i+1]
- if i != 0 {
- c128.DscalUnitary(beta, ci[:i])
- }
- ci[i] = complex(beta*real(ci[i]), 0)
- }
- }
- }
- return
- }
- conjalpha := cmplx.Conj(alpha)
- cbeta := complex(beta, 0)
- if trans == blas.NoTrans {
- // Form C = alpha*A*Bᴴ + conj(alpha)*B*Aᴴ + beta*C.
- if uplo == blas.Upper {
- for i := 0; i < n; i++ {
- ci := c[i*ldc+i+1 : i*ldc+n]
- ai := a[i*lda : i*lda+k]
- bi := b[i*ldb : i*ldb+k]
- if beta == 0 {
- cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi)
- c[i*ldc+i] = complex(real(cii), 0)
- for jc := range ci {
- j := i + 1 + jc
- ci[jc] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi)
- }
- } else {
- cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi) + cbeta*c[i*ldc+i]
- c[i*ldc+i] = complex(real(cii), 0)
- for jc, cij := range ci {
- j := i + 1 + jc
- ci[jc] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi) + cbeta*cij
- }
- }
- }
- } else {
- for i := 0; i < n; i++ {
- ci := c[i*ldc : i*ldc+i]
- ai := a[i*lda : i*lda+k]
- bi := b[i*ldb : i*ldb+k]
- if beta == 0 {
- for j := range ci {
- ci[j] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi)
- }
- cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi)
- c[i*ldc+i] = complex(real(cii), 0)
- } else {
- for j, cij := range ci {
- ci[j] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi) + cbeta*cij
- }
- cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi) + cbeta*c[i*ldc+i]
- c[i*ldc+i] = complex(real(cii), 0)
- }
- }
- }
- } else {
- // Form C = alpha*Aᴴ*B + conj(alpha)*Bᴴ*A + beta*C.
- if uplo == blas.Upper {
- for i := 0; i < n; i++ {
- ci := c[i*ldc+i : i*ldc+n]
- switch {
- case beta == 0:
- for jc := range ci {
- ci[jc] = 0
- }
- case beta != 1:
- c128.DscalUnitary(beta, ci)
- ci[0] = complex(real(ci[0]), 0)
- default:
- ci[0] = complex(real(ci[0]), 0)
- }
- for j := 0; j < k; j++ {
- aji := a[j*lda+i]
- bji := b[j*ldb+i]
- if aji != 0 {
- c128.AxpyUnitary(alpha*cmplx.Conj(aji), b[j*ldb+i:j*ldb+n], ci)
- }
- if bji != 0 {
- c128.AxpyUnitary(conjalpha*cmplx.Conj(bji), a[j*lda+i:j*lda+n], ci)
- }
- }
- ci[0] = complex(real(ci[0]), 0)
- }
- } else {
- for i := 0; i < n; i++ {
- ci := c[i*ldc : i*ldc+i+1]
- switch {
- case beta == 0:
- for j := range ci {
- ci[j] = 0
- }
- case beta != 1:
- c128.DscalUnitary(beta, ci)
- ci[i] = complex(real(ci[i]), 0)
- default:
- ci[i] = complex(real(ci[i]), 0)
- }
- for j := 0; j < k; j++ {
- aji := a[j*lda+i]
- bji := b[j*ldb+i]
- if aji != 0 {
- c128.AxpyUnitary(alpha*cmplx.Conj(aji), b[j*ldb:j*ldb+i+1], ci)
- }
- if bji != 0 {
- c128.AxpyUnitary(conjalpha*cmplx.Conj(bji), a[j*lda:j*lda+i+1], ci)
- }
- }
- ci[i] = complex(real(ci[i]), 0)
- }
- }
- }
- }
// Zsymm performs one of the matrix-matrix operations
//
//	C = alpha*A*B + beta*C  if side == blas.Left
//	C = alpha*B*A + beta*C  if side == blas.Right
//
// where alpha and beta are scalars, A is an m×m or n×n symmetric matrix and B
// and C are m×n matrices.
//
// Only the triangle of A selected by uplo is referenced, the other triangle
// being implied by A = Aᵀ (no conjugation). All matrices are stored row-major
// with leading dimensions lda, ldb and ldc. When beta is zero the input values
// of C are not read.
func (Implementation) Zsymm(side blas.Side, uplo blas.Uplo, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
	// na is the order of the square matrix A: m when it multiplies from the
	// left, n when it multiplies from the right.
	na := m
	if side == blas.Right {
		na = n
	}
	switch {
	case side != blas.Left && side != blas.Right:
		panic(badSide)
	case uplo != blas.Lower && uplo != blas.Upper:
		panic(badUplo)
	case m < 0:
		panic(mLT0)
	case n < 0:
		panic(nLT0)
	case lda < max(1, na):
		panic(badLdA)
	case ldb < max(1, n):
		panic(badLdB)
	case ldc < max(1, n):
		panic(badLdC)
	}

	// Quick return if possible.
	if m == 0 || n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if len(a) < lda*(na-1)+na {
		panic(shortA)
	}
	if len(b) < ldb*(m-1)+n {
		panic(shortB)
	}
	if len(c) < ldc*(m-1)+n {
		panic(shortC)
	}

	// Quick return if possible: C is left unchanged.
	if alpha == 0 && beta == 1 {
		return
	}

	if alpha == 0 {
		// C = beta*C. For beta == 0 zero is assigned rather than multiplied in,
		// so the previous contents of C (including NaN) are discarded.
		if beta == 0 {
			for i := 0; i < m; i++ {
				ci := c[i*ldc : i*ldc+n]
				for j := range ci {
					ci[j] = 0
				}
			}
		} else {
			for i := 0; i < m; i++ {
				ci := c[i*ldc : i*ldc+n]
				c128.ScalUnitary(beta, ci)
			}
		}
		return
	}

	if side == blas.Left {
		// Form C = alpha*A*B + beta*C.
		// Row i of C is initialised with the diagonal term alpha*A[i][i]*B[i,:]
		// (plus beta*C[i,:] when beta != 0); then every off-diagonal A[i][k]
		// adds alpha*A[i][k]*B[k,:] to it.
		for i := 0; i < m; i++ {
			atmp := alpha * a[i*lda+i]
			bi := b[i*ldb : i*ldb+n]
			ci := c[i*ldc : i*ldc+n]
			if beta == 0 {
				for j, bij := range bi {
					ci[j] = atmp * bij
				}
			} else {
				for j, bij := range bi {
					ci[j] = atmp*bij + beta*ci[j]
				}
			}
			if uplo == blas.Upper {
				// A[i][k] with k < i is not stored; it equals A[k][i].
				for k := 0; k < i; k++ {
					atmp = alpha * a[k*lda+i]
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
				for k := i + 1; k < m; k++ {
					atmp = alpha * a[i*lda+k]
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
			} else {
				for k := 0; k < i; k++ {
					atmp = alpha * a[i*lda+k]
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
				// A[i][k] with k > i is not stored; it equals A[k][i].
				for k := i + 1; k < m; k++ {
					atmp = alpha * a[k*lda+i]
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
			}
		}
	} else {
		// Form C = alpha*B*A + beta*C.
		if uplo == blas.Upper {
			for i := 0; i < m; i++ {
				// Columns are visited right to left so that C[i][j+1:], which
				// receives alpha*B[i][j]*A[j][j+1:] below, has already been
				// initialised (including its beta term) before accumulation.
				for j := n - 1; j >= 0; j-- {
					abij := alpha * b[i*ldb+j]
					aj := a[j*lda+j+1 : j*lda+n]
					bi := b[i*ldb+j+1 : i*ldb+n]
					ci := c[i*ldc+j+1 : i*ldc+n]
					// tmp gathers sum_{k>j} B[i][k]*A[j][k], i.e. the
					// contributions of the unstored lower elements A[k][j].
					var tmp complex128
					for k, ajk := range aj {
						ci[k] += abij * ajk
						tmp += bi[k] * ajk
					}
					if beta == 0 {
						c[i*ldc+j] = abij*a[j*lda+j] + alpha*tmp
					} else {
						c[i*ldc+j] = abij*a[j*lda+j] + alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		} else {
			for i := 0; i < m; i++ {
				// Columns are visited left to right so that C[i][:j] has
				// already been initialised before it is accumulated into.
				for j := 0; j < n; j++ {
					abij := alpha * b[i*ldb+j]
					aj := a[j*lda : j*lda+j]
					bi := b[i*ldb : i*ldb+j]
					ci := c[i*ldc : i*ldc+j]
					// tmp gathers sum_{k<j} B[i][k]*A[j][k], i.e. the
					// contributions of the unstored upper elements A[k][j].
					var tmp complex128
					for k, ajk := range aj {
						ci[k] += abij * ajk
						tmp += bi[k] * ajk
					}
					if beta == 0 {
						c[i*ldc+j] = abij*a[j*lda+j] + alpha*tmp
					} else {
						c[i*ldc+j] = abij*a[j*lda+j] + alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		}
	}
}
- // Zsyrk performs one of the symmetric rank-k operations
- // C = alpha*A*Aᵀ + beta*C if trans == blas.NoTrans
- // C = alpha*Aᵀ*A + beta*C if trans == blas.Trans
- // where alpha and beta are scalars, C is an n×n symmetric matrix and A is
- // an n×k matrix in the first case and a k×n matrix in the second case.
- func (Implementation) Zsyrk(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, beta complex128, c []complex128, ldc int) {
- var rowA, colA int
- switch trans {
- default:
- panic(badTranspose)
- case blas.NoTrans:
- rowA, colA = n, k
- case blas.Trans:
- rowA, colA = k, n
- }
- switch {
- case uplo != blas.Lower && uplo != blas.Upper:
- panic(badUplo)
- case n < 0:
- panic(nLT0)
- case k < 0:
- panic(kLT0)
- case lda < max(1, colA):
- panic(badLdA)
- case ldc < max(1, n):
- panic(badLdC)
- }
- // Quick return if possible.
- if n == 0 {
- return
- }
- // For zero matrix size the following slice length checks are trivially satisfied.
- if len(a) < (rowA-1)*lda+colA {
- panic(shortA)
- }
- if len(c) < (n-1)*ldc+n {
- panic(shortC)
- }
- // Quick return if possible.
- if (alpha == 0 || k == 0) && beta == 1 {
- return
- }
- if alpha == 0 {
- if uplo == blas.Upper {
- if beta == 0 {
- for i := 0; i < n; i++ {
- ci := c[i*ldc+i : i*ldc+n]
- for j := range ci {
- ci[j] = 0
- }
- }
- } else {
- for i := 0; i < n; i++ {
- ci := c[i*ldc+i : i*ldc+n]
- c128.ScalUnitary(beta, ci)
- }
- }
- } else {
- if beta == 0 {
- for i := 0; i < n; i++ {
- ci := c[i*ldc : i*ldc+i+1]
- for j := range ci {
- ci[j] = 0
- }
- }
- } else {
- for i := 0; i < n; i++ {
- ci := c[i*ldc : i*ldc+i+1]
- c128.ScalUnitary(beta, ci)
- }
- }
- }
- return
- }
- if trans == blas.NoTrans {
- // Form C = alpha*A*Aᵀ + beta*C.
- if uplo == blas.Upper {
- for i := 0; i < n; i++ {
- ci := c[i*ldc+i : i*ldc+n]
- ai := a[i*lda : i*lda+k]
- if beta == 0 {
- for jc := range ci {
- j := i + jc
- ci[jc] = alpha * c128.DotuUnitary(ai, a[j*lda:j*lda+k])
- }
- } else {
- for jc, cij := range ci {
- j := i + jc
- ci[jc] = beta*cij + alpha*c128.DotuUnitary(ai, a[j*lda:j*lda+k])
- }
- }
- }
- } else {
- for i := 0; i < n; i++ {
- ci := c[i*ldc : i*ldc+i+1]
- ai := a[i*lda : i*lda+k]
- if beta == 0 {
- for j := range ci {
- ci[j] = alpha * c128.DotuUnitary(ai, a[j*lda:j*lda+k])
- }
- } else {
- for j, cij := range ci {
- ci[j] = beta*cij + alpha*c128.DotuUnitary(ai, a[j*lda:j*lda+k])
- }
- }
- }
- }
- } else {
- // Form C = alpha*Aᵀ*A + beta*C.
- if uplo == blas.Upper {
- for i := 0; i < n; i++ {
- ci := c[i*ldc+i : i*ldc+n]
- switch {
- case beta == 0:
- for jc := range ci {
- ci[jc] = 0
- }
- case beta != 1:
- for jc := range ci {
- ci[jc] *= beta
- }
- }
- for j := 0; j < k; j++ {
- aji := a[j*lda+i]
- if aji != 0 {
- c128.AxpyUnitary(alpha*aji, a[j*lda+i:j*lda+n], ci)
- }
- }
- }
- } else {
- for i := 0; i < n; i++ {
- ci := c[i*ldc : i*ldc+i+1]
- switch {
- case beta == 0:
- for j := range ci {
- ci[j] = 0
- }
- case beta != 1:
- for j := range ci {
- ci[j] *= beta
- }
- }
- for j := 0; j < k; j++ {
- aji := a[j*lda+i]
- if aji != 0 {
- c128.AxpyUnitary(alpha*aji, a[j*lda:j*lda+i+1], ci)
- }
- }
- }
- }
- }
- }
- // Zsyr2k performs one of the symmetric rank-2k operations
- // C = alpha*A*Bᵀ + alpha*B*Aᵀ + beta*C if trans == blas.NoTrans
- // C = alpha*Aᵀ*B + alpha*Bᵀ*A + beta*C if trans == blas.Trans
- // where alpha and beta are scalars, C is an n×n symmetric matrix and A and B
- // are n×k matrices in the first case and k×n matrices in the second case.
- func (Implementation) Zsyr2k(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
- var row, col int
- switch trans {
- default:
- panic(badTranspose)
- case blas.NoTrans:
- row, col = n, k
- case blas.Trans:
- row, col = k, n
- }
- switch {
- case uplo != blas.Lower && uplo != blas.Upper:
- panic(badUplo)
- case n < 0:
- panic(nLT0)
- case k < 0:
- panic(kLT0)
- case lda < max(1, col):
- panic(badLdA)
- case ldb < max(1, col):
- panic(badLdB)
- case ldc < max(1, n):
- panic(badLdC)
- }
- // Quick return if possible.
- if n == 0 {
- return
- }
- // For zero matrix size the following slice length checks are trivially satisfied.
- if len(a) < (row-1)*lda+col {
- panic(shortA)
- }
- if len(b) < (row-1)*ldb+col {
- panic(shortB)
- }
- if len(c) < (n-1)*ldc+n {
- panic(shortC)
- }
- // Quick return if possible.
- if (alpha == 0 || k == 0) && beta == 1 {
- return
- }
- if alpha == 0 {
- if uplo == blas.Upper {
- if beta == 0 {
- for i := 0; i < n; i++ {
- ci := c[i*ldc+i : i*ldc+n]
- for j := range ci {
- ci[j] = 0
- }
- }
- } else {
- for i := 0; i < n; i++ {
- ci := c[i*ldc+i : i*ldc+n]
- c128.ScalUnitary(beta, ci)
- }
- }
- } else {
- if beta == 0 {
- for i := 0; i < n; i++ {
- ci := c[i*ldc : i*ldc+i+1]
- for j := range ci {
- ci[j] = 0
- }
- }
- } else {
- for i := 0; i < n; i++ {
- ci := c[i*ldc : i*ldc+i+1]
- c128.ScalUnitary(beta, ci)
- }
- }
- }
- return
- }
- if trans == blas.NoTrans {
- // Form C = alpha*A*Bᵀ + alpha*B*Aᵀ + beta*C.
- if uplo == blas.Upper {
- for i := 0; i < n; i++ {
- ci := c[i*ldc+i : i*ldc+n]
- ai := a[i*lda : i*lda+k]
- bi := b[i*ldb : i*ldb+k]
- if beta == 0 {
- for jc := range ci {
- j := i + jc
- ci[jc] = alpha*c128.DotuUnitary(ai, b[j*ldb:j*ldb+k]) + alpha*c128.DotuUnitary(bi, a[j*lda:j*lda+k])
- }
- } else {
- for jc, cij := range ci {
- j := i + jc
- ci[jc] = alpha*c128.DotuUnitary(ai, b[j*ldb:j*ldb+k]) + alpha*c128.DotuUnitary(bi, a[j*lda:j*lda+k]) + beta*cij
- }
- }
- }
- } else {
- for i := 0; i < n; i++ {
- ci := c[i*ldc : i*ldc+i+1]
- ai := a[i*lda : i*lda+k]
- bi := b[i*ldb : i*ldb+k]
- if beta == 0 {
- for j := range ci {
- ci[j] = alpha*c128.DotuUnitary(ai, b[j*ldb:j*ldb+k]) + alpha*c128.DotuUnitary(bi, a[j*lda:j*lda+k])
- }
- } else {
- for j, cij := range ci {
- ci[j] = alpha*c128.DotuUnitary(ai, b[j*ldb:j*ldb+k]) + alpha*c128.DotuUnitary(bi, a[j*lda:j*lda+k]) + beta*cij
- }
- }
- }
- }
- } else {
- // Form C = alpha*Aᵀ*B + alpha*Bᵀ*A + beta*C.
- if uplo == blas.Upper {
- for i := 0; i < n; i++ {
- ci := c[i*ldc+i : i*ldc+n]
- switch {
- case beta == 0:
- for jc := range ci {
- ci[jc] = 0
- }
- case beta != 1:
- for jc := range ci {
- ci[jc] *= beta
- }
- }
- for j := 0; j < k; j++ {
- aji := a[j*lda+i]
- bji := b[j*ldb+i]
- if aji != 0 {
- c128.AxpyUnitary(alpha*aji, b[j*ldb+i:j*ldb+n], ci)
- }
- if bji != 0 {
- c128.AxpyUnitary(alpha*bji, a[j*lda+i:j*lda+n], ci)
- }
- }
- }
- } else {
- for i := 0; i < n; i++ {
- ci := c[i*ldc : i*ldc+i+1]
- switch {
- case beta == 0:
- for j := range ci {
- ci[j] = 0
- }
- case beta != 1:
- for j := range ci {
- ci[j] *= beta
- }
- }
- for j := 0; j < k; j++ {
- aji := a[j*lda+i]
- bji := b[j*ldb+i]
- if aji != 0 {
- c128.AxpyUnitary(alpha*aji, b[j*ldb:j*ldb+i+1], ci)
- }
- if bji != 0 {
- c128.AxpyUnitary(alpha*bji, a[j*lda:j*lda+i+1], ci)
- }
- }
- }
- }
- }
- }
- // Ztrmm performs one of the matrix-matrix operations
- // B = alpha * op(A) * B if side == blas.Left,
- // B = alpha * B * op(A) if side == blas.Right,
- // where alpha is a scalar, B is an m×n matrix, A is a unit, or non-unit,
- // upper or lower triangular matrix and op(A) is one of
- // op(A) = A if trans == blas.NoTrans,
- // op(A) = Aᵀ if trans == blas.Trans,
- // op(A) = Aᴴ if trans == blas.ConjTrans.
- func (Implementation) Ztrmm(side blas.Side, uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int) {
- na := m
- if side == blas.Right {
- na = n
- }
- switch {
- case side != blas.Left && side != blas.Right:
- panic(badSide)
- case uplo != blas.Lower && uplo != blas.Upper:
- panic(badUplo)
- case trans != blas.NoTrans && trans != blas.Trans && trans != blas.ConjTrans:
- panic(badTranspose)
- case diag != blas.Unit && diag != blas.NonUnit:
- panic(badDiag)
- case m < 0:
- panic(mLT0)
- case n < 0:
- panic(nLT0)
- case lda < max(1, na):
- panic(badLdA)
- case ldb < max(1, n):
- panic(badLdB)
- }
- // Quick return if possible.
- if m == 0 || n == 0 {
- return
- }
- // For zero matrix size the following slice length checks are trivially satisfied.
- if len(a) < (na-1)*lda+na {
- panic(shortA)
- }
- if len(b) < (m-1)*ldb+n {
- panic(shortB)
- }
- // Quick return if possible.
- if alpha == 0 {
- for i := 0; i < m; i++ {
- bi := b[i*ldb : i*ldb+n]
- for j := range bi {
- bi[j] = 0
- }
- }
- return
- }
- noConj := trans != blas.ConjTrans
- noUnit := diag == blas.NonUnit
- if side == blas.Left {
- if trans == blas.NoTrans {
- // Form B = alpha*A*B.
- if uplo == blas.Upper {
- for i := 0; i < m; i++ {
- aii := alpha
- if noUnit {
- aii *= a[i*lda+i]
- }
- bi := b[i*ldb : i*ldb+n]
- for j := range bi {
- bi[j] *= aii
- }
- for ja, aij := range a[i*lda+i+1 : i*lda+m] {
- j := ja + i + 1
- if aij != 0 {
- c128.AxpyUnitary(alpha*aij, b[j*ldb:j*ldb+n], bi)
- }
- }
- }
- } else {
- for i := m - 1; i >= 0; i-- {
- aii := alpha
- if noUnit {
- aii *= a[i*lda+i]
- }
- bi := b[i*ldb : i*ldb+n]
- for j := range bi {
- bi[j] *= aii
- }
- for j, aij := range a[i*lda : i*lda+i] {
- if aij != 0 {
- c128.AxpyUnitary(alpha*aij, b[j*ldb:j*ldb+n], bi)
- }
- }
- }
- }
- } else {
- // Form B = alpha*Aᵀ*B or B = alpha*Aᴴ*B.
- if uplo == blas.Upper {
- for k := m - 1; k >= 0; k-- {
- bk := b[k*ldb : k*ldb+n]
- for ja, ajk := range a[k*lda+k+1 : k*lda+m] {
- if ajk == 0 {
- continue
- }
- j := k + 1 + ja
- if noConj {
- c128.AxpyUnitary(alpha*ajk, bk, b[j*ldb:j*ldb+n])
- } else {
- c128.AxpyUnitary(alpha*cmplx.Conj(ajk), bk, b[j*ldb:j*ldb+n])
- }
- }
- akk := alpha
- if noUnit {
- if noConj {
- akk *= a[k*lda+k]
- } else {
- akk *= cmplx.Conj(a[k*lda+k])
- }
- }
- if akk != 1 {
- c128.ScalUnitary(akk, bk)
- }
- }
- } else {
- for k := 0; k < m; k++ {
- bk := b[k*ldb : k*ldb+n]
- for j, ajk := range a[k*lda : k*lda+k] {
- if ajk == 0 {
- continue
- }
- if noConj {
- c128.AxpyUnitary(alpha*ajk, bk, b[j*ldb:j*ldb+n])
- } else {
- c128.AxpyUnitary(alpha*cmplx.Conj(ajk), bk, b[j*ldb:j*ldb+n])
- }
- }
- akk := alpha
- if noUnit {
- if noConj {
- akk *= a[k*lda+k]
- } else {
- akk *= cmplx.Conj(a[k*lda+k])
- }
- }
- if akk != 1 {
- c128.ScalUnitary(akk, bk)
- }
- }
- }
- }
- } else {
- if trans == blas.NoTrans {
- // Form B = alpha*B*A.
- if uplo == blas.Upper {
- for i := 0; i < m; i++ {
- bi := b[i*ldb : i*ldb+n]
- for k := n - 1; k >= 0; k-- {
- abik := alpha * bi[k]
- if abik == 0 {
- continue
- }
- bi[k] = abik
- if noUnit {
- bi[k] *= a[k*lda+k]
- }
- c128.AxpyUnitary(abik, a[k*lda+k+1:k*lda+n], bi[k+1:])
- }
- }
- } else {
- for i := 0; i < m; i++ {
- bi := b[i*ldb : i*ldb+n]
- for k := 0; k < n; k++ {
- abik := alpha * bi[k]
- if abik == 0 {
- continue
- }
- bi[k] = abik
- if noUnit {
- bi[k] *= a[k*lda+k]
- }
- c128.AxpyUnitary(abik, a[k*lda:k*lda+k], bi[:k])
- }
- }
- }
- } else {
- // Form B = alpha*B*Aᵀ or B = alpha*B*Aᴴ.
- if uplo == blas.Upper {
- for i := 0; i < m; i++ {
- bi := b[i*ldb : i*ldb+n]
- for j, bij := range bi {
- if noConj {
- if noUnit {
- bij *= a[j*lda+j]
- }
- bij += c128.DotuUnitary(a[j*lda+j+1:j*lda+n], bi[j+1:n])
- } else {
- if noUnit {
- bij *= cmplx.Conj(a[j*lda+j])
- }
- bij += c128.DotcUnitary(a[j*lda+j+1:j*lda+n], bi[j+1:n])
- }
- bi[j] = alpha * bij
- }
- }
- } else {
- for i := 0; i < m; i++ {
- bi := b[i*ldb : i*ldb+n]
- for j := n - 1; j >= 0; j-- {
- bij := bi[j]
- if noConj {
- if noUnit {
- bij *= a[j*lda+j]
- }
- bij += c128.DotuUnitary(a[j*lda:j*lda+j], bi[:j])
- } else {
- if noUnit {
- bij *= cmplx.Conj(a[j*lda+j])
- }
- bij += c128.DotcUnitary(a[j*lda:j*lda+j], bi[:j])
- }
- bi[j] = alpha * bij
- }
- }
- }
- }
- }
- }
- // Ztrsm solves one of the matrix equations
- // op(A) * X = alpha * B if side == blas.Left,
- // X * op(A) = alpha * B if side == blas.Right,
- // where alpha is a scalar, X and B are m×n matrices, A is a unit or
- // non-unit, upper or lower triangular matrix and op(A) is one of
- // op(A) = A if transA == blas.NoTrans,
- // op(A) = Aᵀ if transA == blas.Trans,
- // op(A) = Aᴴ if transA == blas.ConjTrans.
- // On return the matrix X is overwritten on B.
- func (Implementation) Ztrsm(side blas.Side, uplo blas.Uplo, transA blas.Transpose, diag blas.Diag, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int) {
- na := m
- if side == blas.Right {
- na = n
- }
- switch {
- case side != blas.Left && side != blas.Right:
- panic(badSide)
- case uplo != blas.Lower && uplo != blas.Upper:
- panic(badUplo)
- case transA != blas.NoTrans && transA != blas.Trans && transA != blas.ConjTrans:
- panic(badTranspose)
- case diag != blas.Unit && diag != blas.NonUnit:
- panic(badDiag)
- case m < 0:
- panic(mLT0)
- case n < 0:
- panic(nLT0)
- case lda < max(1, na):
- panic(badLdA)
- case ldb < max(1, n):
- panic(badLdB)
- }
- // Quick return if possible.
- if m == 0 || n == 0 {
- return
- }
- // For zero matrix size the following slice length checks are trivially satisfied.
- if len(a) < (na-1)*lda+na {
- panic(shortA)
- }
- if len(b) < (m-1)*ldb+n {
- panic(shortB)
- }
- if alpha == 0 {
- for i := 0; i < m; i++ {
- for j := 0; j < n; j++ {
- b[i*ldb+j] = 0
- }
- }
- return
- }
- noConj := transA != blas.ConjTrans
- noUnit := diag == blas.NonUnit
- if side == blas.Left {
- if transA == blas.NoTrans {
- // Form B = alpha*inv(A)*B.
- if uplo == blas.Upper {
- for i := m - 1; i >= 0; i-- {
- bi := b[i*ldb : i*ldb+n]
- if alpha != 1 {
- c128.ScalUnitary(alpha, bi)
- }
- for ka, aik := range a[i*lda+i+1 : i*lda+m] {
- k := i + 1 + ka
- if aik != 0 {
- c128.AxpyUnitary(-aik, b[k*ldb:k*ldb+n], bi)
- }
- }
- if noUnit {
- c128.ScalUnitary(1/a[i*lda+i], bi)
- }
- }
- } else {
- for i := 0; i < m; i++ {
- bi := b[i*ldb : i*ldb+n]
- if alpha != 1 {
- c128.ScalUnitary(alpha, bi)
- }
- for j, aij := range a[i*lda : i*lda+i] {
- if aij != 0 {
- c128.AxpyUnitary(-aij, b[j*ldb:j*ldb+n], bi)
- }
- }
- if noUnit {
- c128.ScalUnitary(1/a[i*lda+i], bi)
- }
- }
- }
- } else {
- // Form B = alpha*inv(Aᵀ)*B or B = alpha*inv(Aᴴ)*B.
- if uplo == blas.Upper {
- for i := 0; i < m; i++ {
- bi := b[i*ldb : i*ldb+n]
- if noUnit {
- if noConj {
- c128.ScalUnitary(1/a[i*lda+i], bi)
- } else {
- c128.ScalUnitary(1/cmplx.Conj(a[i*lda+i]), bi)
- }
- }
- for ja, aij := range a[i*lda+i+1 : i*lda+m] {
- if aij == 0 {
- continue
- }
- j := i + 1 + ja
- if noConj {
- c128.AxpyUnitary(-aij, bi, b[j*ldb:j*ldb+n])
- } else {
- c128.AxpyUnitary(-cmplx.Conj(aij), bi, b[j*ldb:j*ldb+n])
- }
- }
- if alpha != 1 {
- c128.ScalUnitary(alpha, bi)
- }
- }
- } else {
- for i := m - 1; i >= 0; i-- {
- bi := b[i*ldb : i*ldb+n]
- if noUnit {
- if noConj {
- c128.ScalUnitary(1/a[i*lda+i], bi)
- } else {
- c128.ScalUnitary(1/cmplx.Conj(a[i*lda+i]), bi)
- }
- }
- for j, aij := range a[i*lda : i*lda+i] {
- if aij == 0 {
- continue
- }
- if noConj {
- c128.AxpyUnitary(-aij, bi, b[j*ldb:j*ldb+n])
- } else {
- c128.AxpyUnitary(-cmplx.Conj(aij), bi, b[j*ldb:j*ldb+n])
- }
- }
- if alpha != 1 {
- c128.ScalUnitary(alpha, bi)
- }
- }
- }
- }
- } else {
- if transA == blas.NoTrans {
- // Form B = alpha*B*inv(A).
- if uplo == blas.Upper {
- for i := 0; i < m; i++ {
- bi := b[i*ldb : i*ldb+n]
- if alpha != 1 {
- c128.ScalUnitary(alpha, bi)
- }
- for j, bij := range bi {
- if bij == 0 {
- continue
- }
- if noUnit {
- bi[j] /= a[j*lda+j]
- }
- c128.AxpyUnitary(-bi[j], a[j*lda+j+1:j*lda+n], bi[j+1:n])
- }
- }
- } else {
- for i := 0; i < m; i++ {
- bi := b[i*ldb : i*ldb+n]
- if alpha != 1 {
- c128.ScalUnitary(alpha, bi)
- }
- for j := n - 1; j >= 0; j-- {
- if bi[j] == 0 {
- continue
- }
- if noUnit {
- bi[j] /= a[j*lda+j]
- }
- c128.AxpyUnitary(-bi[j], a[j*lda:j*lda+j], bi[:j])
- }
- }
- }
- } else {
- // Form B = alpha*B*inv(Aᵀ) or B = alpha*B*inv(Aᴴ).
- if uplo == blas.Upper {
- for i := 0; i < m; i++ {
- bi := b[i*ldb : i*ldb+n]
- for j := n - 1; j >= 0; j-- {
- bij := alpha * bi[j]
- if noConj {
- bij -= c128.DotuUnitary(a[j*lda+j+1:j*lda+n], bi[j+1:n])
- if noUnit {
- bij /= a[j*lda+j]
- }
- } else {
- bij -= c128.DotcUnitary(a[j*lda+j+1:j*lda+n], bi[j+1:n])
- if noUnit {
- bij /= cmplx.Conj(a[j*lda+j])
- }
- }
- bi[j] = bij
- }
- }
- } else {
- for i := 0; i < m; i++ {
- bi := b[i*ldb : i*ldb+n]
- for j, bij := range bi {
- bij *= alpha
- if noConj {
- bij -= c128.DotuUnitary(a[j*lda:j*lda+j], bi[:j])
- if noUnit {
- bij /= a[j*lda+j]
- }
- } else {
- bij -= c128.DotcUnitary(a[j*lda:j*lda+j], bi[:j])
- if noUnit {
- bij /= cmplx.Conj(a[j*lda+j])
- }
- }
- bi[j] = bij
- }
- }
- }
- }
- }
- }
|