dlauum.go 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. // Copyright ©2018 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package gonum
  5. import (
  6. "gonum.org/v1/gonum/blas"
  7. "gonum.org/v1/gonum/blas/blas64"
  8. )
  9. // Dlauum computes the product
  10. // U * Uᵀ if uplo is blas.Upper
  11. // Lᵀ * L if uplo is blas.Lower
  12. // where U or L is stored in the upper or lower triangular part of A.
  13. // Only the upper or lower triangle of the result is stored, overwriting
  14. // the corresponding factor in A.
  15. func (impl Implementation) Dlauum(uplo blas.Uplo, n int, a []float64, lda int) {
  16. switch {
  17. case uplo != blas.Upper && uplo != blas.Lower:
  18. panic(badUplo)
  19. case n < 0:
  20. panic(nLT0)
  21. case lda < max(1, n):
  22. panic(badLdA)
  23. }
  24. // Quick return if possible.
  25. if n == 0 {
  26. return
  27. }
  28. if len(a) < (n-1)*lda+n {
  29. panic(shortA)
  30. }
  31. // Determine the block size.
  32. opts := "U"
  33. if uplo == blas.Lower {
  34. opts = "L"
  35. }
  36. nb := impl.Ilaenv(1, "DLAUUM", opts, n, -1, -1, -1)
  37. if nb <= 1 || n <= nb {
  38. // Use unblocked code.
  39. impl.Dlauu2(uplo, n, a, lda)
  40. return
  41. }
  42. // Use blocked code.
  43. bi := blas64.Implementation()
  44. if uplo == blas.Upper {
  45. // Compute the product U*Uᵀ.
  46. for i := 0; i < n; i += nb {
  47. ib := min(nb, n-i)
  48. bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.NonUnit,
  49. i, ib, 1, a[i*lda+i:], lda, a[i:], lda)
  50. impl.Dlauu2(blas.Upper, ib, a[i*lda+i:], lda)
  51. if n-i-ib > 0 {
  52. bi.Dgemm(blas.NoTrans, blas.Trans, i, ib, n-i-ib,
  53. 1, a[i+ib:], lda, a[i*lda+i+ib:], lda, 1, a[i:], lda)
  54. bi.Dsyrk(blas.Upper, blas.NoTrans, ib, n-i-ib,
  55. 1, a[i*lda+i+ib:], lda, 1, a[i*lda+i:], lda)
  56. }
  57. }
  58. } else {
  59. // Compute the product Lᵀ*L.
  60. for i := 0; i < n; i += nb {
  61. ib := min(nb, n-i)
  62. bi.Dtrmm(blas.Left, blas.Lower, blas.Trans, blas.NonUnit,
  63. ib, i, 1, a[i*lda+i:], lda, a[i*lda:], lda)
  64. impl.Dlauu2(blas.Lower, ib, a[i*lda+i:], lda)
  65. if n-i-ib > 0 {
  66. bi.Dgemm(blas.Trans, blas.NoTrans, ib, i, n-i-ib,
  67. 1, a[(i+ib)*lda+i:], lda, a[(i+ib)*lda:], lda, 1, a[i*lda:], lda)
  68. bi.Dsyrk(blas.Lower, blas.Trans, ib, n-i-ib,
  69. 1, a[(i+ib)*lda+i:], lda, 1, a[i*lda+i:], lda)
  70. }
  71. }
  72. }
  73. }