123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439 |
- // Copyright ©2013 The Gonum Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package mat
- import (
- "math"
- "gonum.org/v1/gonum/blas"
- "gonum.org/v1/gonum/blas/blas64"
- "gonum.org/v1/gonum/floats"
- "gonum.org/v1/gonum/lapack"
- "gonum.org/v1/gonum/lapack/lapack64"
- )
- const (
- badSliceLength = "mat: improper slice length"
- badLU = "mat: invalid LU factorization"
- )
- // LU is a type for creating and using the LU factorization of a matrix.
- type LU struct {
- lu *Dense
- pivot []int
- cond float64
- }
- // updateCond updates the stored condition number of the matrix. anorm is the
- // norm of the original matrix. If anorm is negative it will be estimated.
- func (lu *LU) updateCond(anorm float64, norm lapack.MatrixNorm) {
- n := lu.lu.mat.Cols
- work := getFloats(4*n, false)
- defer putFloats(work)
- iwork := getInts(n, false)
- defer putInts(iwork)
- if anorm < 0 {
- // This is an approximation. By the definition of a norm,
- // |AB| <= |A| |B|.
- // Since A = L*U, we get for the condition number κ that
- // κ(A) := |A| |A^-1| = |L*U| |A^-1| <= |L| |U| |A^-1|,
- // so this will overestimate the condition number somewhat.
- // The norm of the original factorized matrix cannot be stored
- // because of update possibilities.
- u := lu.lu.asTriDense(n, blas.NonUnit, blas.Upper)
- l := lu.lu.asTriDense(n, blas.Unit, blas.Lower)
- unorm := lapack64.Lantr(norm, u.mat, work)
- lnorm := lapack64.Lantr(norm, l.mat, work)
- anorm = unorm * lnorm
- }
- v := lapack64.Gecon(norm, lu.lu.mat, anorm, work, iwork)
- lu.cond = 1 / v
- }
- // Factorize computes the LU factorization of the square matrix a and stores the
- // result. The LU decomposition will complete regardless of the singularity of a.
- //
- // The LU factorization is computed with pivoting, and so really the decomposition
- // is a PLU decomposition where P is a permutation matrix. The individual matrix
- // factors can be extracted from the factorization using the Permutation method
- // on Dense, and the LU.LTo and LU.UTo methods.
- func (lu *LU) Factorize(a Matrix) {
- lu.factorize(a, CondNorm)
- }
- func (lu *LU) factorize(a Matrix, norm lapack.MatrixNorm) {
- r, c := a.Dims()
- if r != c {
- panic(ErrSquare)
- }
- if lu.lu == nil {
- lu.lu = NewDense(r, r, nil)
- } else {
- lu.lu.Reset()
- lu.lu.reuseAsNonZeroed(r, r)
- }
- lu.lu.Copy(a)
- if cap(lu.pivot) < r {
- lu.pivot = make([]int, r)
- }
- lu.pivot = lu.pivot[:r]
- work := getFloats(r, false)
- anorm := lapack64.Lange(norm, lu.lu.mat, work)
- putFloats(work)
- lapack64.Getrf(lu.lu.mat, lu.pivot)
- lu.updateCond(anorm, norm)
- }
- // isValid returns whether the receiver contains a factorization.
- func (lu *LU) isValid() bool {
- return lu.lu != nil && !lu.lu.IsEmpty()
- }
- // Cond returns the condition number for the factorized matrix.
- // Cond will panic if the receiver does not contain a factorization.
- func (lu *LU) Cond() float64 {
- if !lu.isValid() {
- panic(badLU)
- }
- return lu.cond
- }
- // Reset resets the factorization so that it can be reused as the receiver of a
- // dimensionally restricted operation.
- func (lu *LU) Reset() {
- if lu.lu != nil {
- lu.lu.Reset()
- }
- lu.pivot = lu.pivot[:0]
- }
- func (lu *LU) isZero() bool {
- return len(lu.pivot) == 0
- }
- // Det returns the determinant of the matrix that has been factorized. In many
- // expressions, using LogDet will be more numerically stable.
- // Det will panic if the receiver does not contain a factorization.
- func (lu *LU) Det() float64 {
- det, sign := lu.LogDet()
- return math.Exp(det) * sign
- }
- // LogDet returns the log of the determinant and the sign of the determinant
- // for the matrix that has been factorized. Numerical stability in product and
- // division expressions is generally improved by working in log space.
- // LogDet will panic if the receiver does not contain a factorization.
- func (lu *LU) LogDet() (det float64, sign float64) {
- if !lu.isValid() {
- panic(badLU)
- }
- _, n := lu.lu.Dims()
- logDiag := getFloats(n, false)
- defer putFloats(logDiag)
- sign = 1.0
- for i := 0; i < n; i++ {
- v := lu.lu.at(i, i)
- if v < 0 {
- sign *= -1
- }
- if lu.pivot[i] != i {
- sign *= -1
- }
- logDiag[i] = math.Log(math.Abs(v))
- }
- return floats.Sum(logDiag), sign
- }
- // Pivot returns pivot indices that enable the construction of the permutation
- // matrix P (see Dense.Permutation). If swaps == nil, then new memory will be
- // allocated, otherwise the length of the input must be equal to the size of the
- // factorized matrix.
- // Pivot will panic if the receiver does not contain a factorization.
- func (lu *LU) Pivot(swaps []int) []int {
- if !lu.isValid() {
- panic(badLU)
- }
- _, n := lu.lu.Dims()
- if swaps == nil {
- swaps = make([]int, n)
- }
- if len(swaps) != n {
- panic(badSliceLength)
- }
- // Perform the inverse of the row swaps in order to find the final
- // row swap position.
- for i := range swaps {
- swaps[i] = i
- }
- for i := n - 1; i >= 0; i-- {
- v := lu.pivot[i]
- swaps[i], swaps[v] = swaps[v], swaps[i]
- }
- return swaps
- }
- // RankOne updates an LU factorization as if a rank-one update had been applied to
- // the original matrix A, storing the result into the receiver. That is, if in
- // the original LU decomposition P * L * U = A, in the updated decomposition
- // P * L * U = A + alpha * x * yᵀ.
- // RankOne will panic if orig does not contain a factorization.
- func (lu *LU) RankOne(orig *LU, alpha float64, x, y Vector) {
- if !orig.isValid() {
- panic(badLU)
- }
- // RankOne uses algorithm a1 on page 28 of "Multiple-Rank Updates to Matrix
- // Factorizations for Nonlinear Analysis and Circuit Design" by Linzhong Deng.
- // http://web.stanford.edu/group/SOL/dissertations/Linzhong-Deng-thesis.pdf
- _, n := orig.lu.Dims()
- if r, c := x.Dims(); r != n || c != 1 {
- panic(ErrShape)
- }
- if r, c := y.Dims(); r != n || c != 1 {
- panic(ErrShape)
- }
- if orig != lu {
- if lu.isZero() {
- if cap(lu.pivot) < n {
- lu.pivot = make([]int, n)
- }
- lu.pivot = lu.pivot[:n]
- if lu.lu == nil {
- lu.lu = NewDense(n, n, nil)
- } else {
- lu.lu.reuseAsNonZeroed(n, n)
- }
- } else if len(lu.pivot) != n {
- panic(ErrShape)
- }
- copy(lu.pivot, orig.pivot)
- lu.lu.Copy(orig.lu)
- }
- xs := getFloats(n, false)
- defer putFloats(xs)
- ys := getFloats(n, false)
- defer putFloats(ys)
- for i := 0; i < n; i++ {
- xs[i] = x.AtVec(i)
- ys[i] = y.AtVec(i)
- }
- // Adjust for the pivoting in the LU factorization
- for i, v := range lu.pivot {
- xs[i], xs[v] = xs[v], xs[i]
- }
- lum := lu.lu.mat
- omega := alpha
- for j := 0; j < n; j++ {
- ujj := lum.Data[j*lum.Stride+j]
- ys[j] /= ujj
- theta := 1 + xs[j]*ys[j]*omega
- beta := omega * ys[j] / theta
- gamma := omega * xs[j]
- omega -= beta * gamma
- lum.Data[j*lum.Stride+j] *= theta
- for i := j + 1; i < n; i++ {
- xs[i] -= lum.Data[i*lum.Stride+j] * xs[j]
- tmp := ys[i]
- ys[i] -= lum.Data[j*lum.Stride+i] * ys[j]
- lum.Data[i*lum.Stride+j] += beta * xs[i]
- lum.Data[j*lum.Stride+i] += gamma * tmp
- }
- }
- lu.updateCond(-1, CondNorm)
- }
- // LTo extracts the lower triangular matrix from an LU factorization.
- //
- // If dst is empty, LTo will resize dst to be a lower-triangular n×n matrix.
- // When dst is non-empty, LTo will panic if dst is not n×n or not Lower.
- // LTo will also panic if the receiver does not contain a successful
- // factorization.
- func (lu *LU) LTo(dst *TriDense) *TriDense {
- if !lu.isValid() {
- panic(badLU)
- }
- _, n := lu.lu.Dims()
- if dst.IsEmpty() {
- dst.ReuseAsTri(n, Lower)
- } else {
- n2, kind := dst.Triangle()
- if n != n2 {
- panic(ErrShape)
- }
- if kind != Lower {
- panic(ErrTriangle)
- }
- }
- // Extract the lower triangular elements.
- for i := 0; i < n; i++ {
- for j := 0; j < i; j++ {
- dst.mat.Data[i*dst.mat.Stride+j] = lu.lu.mat.Data[i*lu.lu.mat.Stride+j]
- }
- }
- // Set ones on the diagonal.
- for i := 0; i < n; i++ {
- dst.mat.Data[i*dst.mat.Stride+i] = 1
- }
- return dst
- }
- // UTo extracts the upper triangular matrix from an LU factorization.
- //
- // If dst is empty, UTo will resize dst to be an upper-triangular n×n matrix.
- // When dst is non-empty, UTo will panic if dst is not n×n or not Upper.
- // UTo will also panic if the receiver does not contain a successful
- // factorization.
- func (lu *LU) UTo(dst *TriDense) {
- if !lu.isValid() {
- panic(badLU)
- }
- _, n := lu.lu.Dims()
- if dst.IsEmpty() {
- dst.ReuseAsTri(n, Upper)
- } else {
- n2, kind := dst.Triangle()
- if n != n2 {
- panic(ErrShape)
- }
- if kind != Upper {
- panic(ErrTriangle)
- }
- }
- // Extract the upper triangular elements.
- for i := 0; i < n; i++ {
- for j := i; j < n; j++ {
- dst.mat.Data[i*dst.mat.Stride+j] = lu.lu.mat.Data[i*lu.lu.mat.Stride+j]
- }
- }
- }
- // Permutation constructs an r×r permutation matrix with the given row swaps.
- // A permutation matrix has exactly one element equal to one in each row and column
- // and all other elements equal to zero. swaps[i] specifies the row with which
- // i will be swapped, which is equivalent to the non-zero column of row i.
- func (m *Dense) Permutation(r int, swaps []int) {
- m.reuseAsNonZeroed(r, r)
- for i := 0; i < r; i++ {
- zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+r])
- v := swaps[i]
- if v < 0 || v >= r {
- panic(ErrRowAccess)
- }
- m.mat.Data[i*m.mat.Stride+v] = 1
- }
- }
- // SolveTo solves a system of linear equations using the LU decomposition of a matrix.
- // It computes
- // A * X = B if trans == false
- // Aᵀ * X = B if trans == true
- // In both cases, A is represented in LU factorized form, and the matrix X is
- // stored into dst.
- //
- // If A is singular or near-singular a Condition error is returned. See
- // the documentation for Condition for more information.
- // SolveTo will panic if the receiver does not contain a factorization.
- func (lu *LU) SolveTo(dst *Dense, trans bool, b Matrix) error {
- if !lu.isValid() {
- panic(badLU)
- }
- _, n := lu.lu.Dims()
- br, bc := b.Dims()
- if br != n {
- panic(ErrShape)
- }
- // TODO(btracey): Should test the condition number instead of testing that
- // the determinant is exactly zero.
- if lu.Det() == 0 {
- return Condition(math.Inf(1))
- }
- dst.reuseAsNonZeroed(n, bc)
- bU, _ := untranspose(b)
- var restore func()
- if dst == bU {
- dst, restore = dst.isolatedWorkspace(bU)
- defer restore()
- } else if rm, ok := bU.(RawMatrixer); ok {
- dst.checkOverlap(rm.RawMatrix())
- }
- dst.Copy(b)
- t := blas.NoTrans
- if trans {
- t = blas.Trans
- }
- lapack64.Getrs(t, lu.lu.mat, dst.mat, lu.pivot)
- if lu.cond > ConditionTolerance {
- return Condition(lu.cond)
- }
- return nil
- }
- // SolveVecTo solves a system of linear equations using the LU decomposition of a matrix.
- // It computes
- // A * x = b if trans == false
- // Aᵀ * x = b if trans == true
- // In both cases, A is represented in LU factorized form, and the vector x is
- // stored into dst.
- //
- // If A is singular or near-singular a Condition error is returned. See
- // the documentation for Condition for more information.
- // SolveVecTo will panic if the receiver does not contain a factorization.
- func (lu *LU) SolveVecTo(dst *VecDense, trans bool, b Vector) error {
- if !lu.isValid() {
- panic(badLU)
- }
- _, n := lu.lu.Dims()
- if br, bc := b.Dims(); br != n || bc != 1 {
- panic(ErrShape)
- }
- switch rv := b.(type) {
- default:
- dst.reuseAsNonZeroed(n)
- return lu.SolveTo(dst.asDense(), trans, b)
- case RawVectorer:
- if dst != b {
- dst.checkOverlap(rv.RawVector())
- }
- // TODO(btracey): Should test the condition number instead of testing that
- // the determinant is exactly zero.
- if lu.Det() == 0 {
- return Condition(math.Inf(1))
- }
- dst.reuseAsNonZeroed(n)
- var restore func()
- if dst == b {
- dst, restore = dst.isolatedWorkspace(b)
- defer restore()
- }
- dst.CopyVec(b)
- vMat := blas64.General{
- Rows: n,
- Cols: 1,
- Stride: dst.mat.Inc,
- Data: dst.mat.Data,
- }
- t := blas.NoTrans
- if trans {
- t = blas.Trans
- }
- lapack64.Getrs(t, lu.lu.mat, vMat, lu.pivot)
- if lu.cond > ConditionTolerance {
- return Condition(lu.cond)
- }
- return nil
- }
- }
|