// Copyright ©2013 The Gonum Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package mat import ( "math" "gonum.org/v1/gonum/blas" "gonum.org/v1/gonum/blas/blas64" "gonum.org/v1/gonum/floats" "gonum.org/v1/gonum/lapack" "gonum.org/v1/gonum/lapack/lapack64" ) const ( badSliceLength = "mat: improper slice length" badLU = "mat: invalid LU factorization" ) // LU is a type for creating and using the LU factorization of a matrix. type LU struct { lu *Dense pivot []int cond float64 } // updateCond updates the stored condition number of the matrix. anorm is the // norm of the original matrix. If anorm is negative it will be estimated. func (lu *LU) updateCond(anorm float64, norm lapack.MatrixNorm) { n := lu.lu.mat.Cols work := getFloats(4*n, false) defer putFloats(work) iwork := getInts(n, false) defer putInts(iwork) if anorm < 0 { // This is an approximation. By the definition of a norm, // |AB| <= |A| |B|. // Since A = L*U, we get for the condition number κ that // κ(A) := |A| |A^-1| = |L*U| |A^-1| <= |L| |U| |A^-1|, // so this will overestimate the condition number somewhat. // The norm of the original factorized matrix cannot be stored // because of update possibilities. u := lu.lu.asTriDense(n, blas.NonUnit, blas.Upper) l := lu.lu.asTriDense(n, blas.Unit, blas.Lower) unorm := lapack64.Lantr(norm, u.mat, work) lnorm := lapack64.Lantr(norm, l.mat, work) anorm = unorm * lnorm } v := lapack64.Gecon(norm, lu.lu.mat, anorm, work, iwork) lu.cond = 1 / v } // Factorize computes the LU factorization of the square matrix a and stores the // result. The LU decomposition will complete regardless of the singularity of a. // // The LU factorization is computed with pivoting, and so really the decomposition // is a PLU decomposition where P is a permutation matrix. The individual matrix // factors can be extracted from the factorization using the Permutation method // on Dense, and the LU.LTo and LU.UTo methods. func (lu *LU) Factorize(a Matrix) { lu.factorize(a, CondNorm) } func (lu *LU) factorize(a Matrix, norm lapack.MatrixNorm) { r, c := a.Dims() if r != c { panic(ErrSquare) } if lu.lu == nil { lu.lu = NewDense(r, r, nil) } else { lu.lu.Reset() lu.lu.reuseAsNonZeroed(r, r) } lu.lu.Copy(a) if cap(lu.pivot) < r { lu.pivot = make([]int, r) } lu.pivot = lu.pivot[:r] work := getFloats(r, false) anorm := lapack64.Lange(norm, lu.lu.mat, work) putFloats(work) lapack64.Getrf(lu.lu.mat, lu.pivot) lu.updateCond(anorm, norm) } // isValid returns whether the receiver contains a factorization. func (lu *LU) isValid() bool { return lu.lu != nil && !lu.lu.IsEmpty() } // Cond returns the condition number for the factorized matrix. // Cond will panic if the receiver does not contain a factorization. func (lu *LU) Cond() float64 { if !lu.isValid() { panic(badLU) } return lu.cond } // Reset resets the factorization so that it can be reused as the receiver of a // dimensionally restricted operation. func (lu *LU) Reset() { if lu.lu != nil { lu.lu.Reset() } lu.pivot = lu.pivot[:0] } func (lu *LU) isZero() bool { return len(lu.pivot) == 0 } // Det returns the determinant of the matrix that has been factorized. In many // expressions, using LogDet will be more numerically stable. // Det will panic if the receiver does not contain a factorization. func (lu *LU) Det() float64 { det, sign := lu.LogDet() return math.Exp(det) * sign } // LogDet returns the log of the determinant and the sign of the determinant // for the matrix that has been factorized. Numerical stability in product and // division expressions is generally improved by working in log space. // LogDet will panic if the receiver does not contain a factorization. func (lu *LU) LogDet() (det float64, sign float64) { if !lu.isValid() { panic(badLU) } _, n := lu.lu.Dims() logDiag := getFloats(n, false) defer putFloats(logDiag) sign = 1.0 for i := 0; i < n; i++ { v := lu.lu.at(i, i) if v < 0 { sign *= -1 } if lu.pivot[i] != i { sign *= -1 } logDiag[i] = math.Log(math.Abs(v)) } return floats.Sum(logDiag), sign } // Pivot returns pivot indices that enable the construction of the permutation // matrix P (see Dense.Permutation). If swaps == nil, then new memory will be // allocated, otherwise the length of the input must be equal to the size of the // factorized matrix. // Pivot will panic if the receiver does not contain a factorization. func (lu *LU) Pivot(swaps []int) []int { if !lu.isValid() { panic(badLU) } _, n := lu.lu.Dims() if swaps == nil { swaps = make([]int, n) } if len(swaps) != n { panic(badSliceLength) } // Perform the inverse of the row swaps in order to find the final // row swap position. for i := range swaps { swaps[i] = i } for i := n - 1; i >= 0; i-- { v := lu.pivot[i] swaps[i], swaps[v] = swaps[v], swaps[i] } return swaps } // RankOne updates an LU factorization as if a rank-one update had been applied to // the original matrix A, storing the result into the receiver. That is, if in // the original LU decomposition P * L * U = A, in the updated decomposition // P * L * U = A + alpha * x * yᵀ. // RankOne will panic if orig does not contain a factorization. func (lu *LU) RankOne(orig *LU, alpha float64, x, y Vector) { if !orig.isValid() { panic(badLU) } // RankOne uses algorithm a1 on page 28 of "Multiple-Rank Updates to Matrix // Factorizations for Nonlinear Analysis and Circuit Design" by Linzhong Deng. // http://web.stanford.edu/group/SOL/dissertations/Linzhong-Deng-thesis.pdf _, n := orig.lu.Dims() if r, c := x.Dims(); r != n || c != 1 { panic(ErrShape) } if r, c := y.Dims(); r != n || c != 1 { panic(ErrShape) } if orig != lu { if lu.isZero() { if cap(lu.pivot) < n { lu.pivot = make([]int, n) } lu.pivot = lu.pivot[:n] if lu.lu == nil { lu.lu = NewDense(n, n, nil) } else { lu.lu.reuseAsNonZeroed(n, n) } } else if len(lu.pivot) != n { panic(ErrShape) } copy(lu.pivot, orig.pivot) lu.lu.Copy(orig.lu) } xs := getFloats(n, false) defer putFloats(xs) ys := getFloats(n, false) defer putFloats(ys) for i := 0; i < n; i++ { xs[i] = x.AtVec(i) ys[i] = y.AtVec(i) } // Adjust for the pivoting in the LU factorization for i, v := range lu.pivot { xs[i], xs[v] = xs[v], xs[i] } lum := lu.lu.mat omega := alpha for j := 0; j < n; j++ { ujj := lum.Data[j*lum.Stride+j] ys[j] /= ujj theta := 1 + xs[j]*ys[j]*omega beta := omega * ys[j] / theta gamma := omega * xs[j] omega -= beta * gamma lum.Data[j*lum.Stride+j] *= theta for i := j + 1; i < n; i++ { xs[i] -= lum.Data[i*lum.Stride+j] * xs[j] tmp := ys[i] ys[i] -= lum.Data[j*lum.Stride+i] * ys[j] lum.Data[i*lum.Stride+j] += beta * xs[i] lum.Data[j*lum.Stride+i] += gamma * tmp } } lu.updateCond(-1, CondNorm) } // LTo extracts the lower triangular matrix from an LU factorization. // // If dst is empty, LTo will resize dst to be a lower-triangular n×n matrix. // When dst is non-empty, LTo will panic if dst is not n×n or not Lower. // LTo will also panic if the receiver does not contain a successful // factorization. func (lu *LU) LTo(dst *TriDense) *TriDense { if !lu.isValid() { panic(badLU) } _, n := lu.lu.Dims() if dst.IsEmpty() { dst.ReuseAsTri(n, Lower) } else { n2, kind := dst.Triangle() if n != n2 { panic(ErrShape) } if kind != Lower { panic(ErrTriangle) } } // Extract the lower triangular elements. for i := 0; i < n; i++ { for j := 0; j < i; j++ { dst.mat.Data[i*dst.mat.Stride+j] = lu.lu.mat.Data[i*lu.lu.mat.Stride+j] } } // Set ones on the diagonal. for i := 0; i < n; i++ { dst.mat.Data[i*dst.mat.Stride+i] = 1 } return dst } // UTo extracts the upper triangular matrix from an LU factorization. // // If dst is empty, UTo will resize dst to be an upper-triangular n×n matrix. // When dst is non-empty, UTo will panic if dst is not n×n or not Upper. // UTo will also panic if the receiver does not contain a successful // factorization. func (lu *LU) UTo(dst *TriDense) { if !lu.isValid() { panic(badLU) } _, n := lu.lu.Dims() if dst.IsEmpty() { dst.ReuseAsTri(n, Upper) } else { n2, kind := dst.Triangle() if n != n2 { panic(ErrShape) } if kind != Upper { panic(ErrTriangle) } } // Extract the upper triangular elements. for i := 0; i < n; i++ { for j := i; j < n; j++ { dst.mat.Data[i*dst.mat.Stride+j] = lu.lu.mat.Data[i*lu.lu.mat.Stride+j] } } } // Permutation constructs an r×r permutation matrix with the given row swaps. // A permutation matrix has exactly one element equal to one in each row and column // and all other elements equal to zero. swaps[i] specifies the row with which // i will be swapped, which is equivalent to the non-zero column of row i. func (m *Dense) Permutation(r int, swaps []int) { m.reuseAsNonZeroed(r, r) for i := 0; i < r; i++ { zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+r]) v := swaps[i] if v < 0 || v >= r { panic(ErrRowAccess) } m.mat.Data[i*m.mat.Stride+v] = 1 } } // SolveTo solves a system of linear equations using the LU decomposition of a matrix. // It computes // A * X = B if trans == false // Aᵀ * X = B if trans == true // In both cases, A is represented in LU factorized form, and the matrix X is // stored into dst. // // If A is singular or near-singular a Condition error is returned. See // the documentation for Condition for more information. // SolveTo will panic if the receiver does not contain a factorization. func (lu *LU) SolveTo(dst *Dense, trans bool, b Matrix) error { if !lu.isValid() { panic(badLU) } _, n := lu.lu.Dims() br, bc := b.Dims() if br != n { panic(ErrShape) } // TODO(btracey): Should test the condition number instead of testing that // the determinant is exactly zero. if lu.Det() == 0 { return Condition(math.Inf(1)) } dst.reuseAsNonZeroed(n, bc) bU, _ := untranspose(b) var restore func() if dst == bU { dst, restore = dst.isolatedWorkspace(bU) defer restore() } else if rm, ok := bU.(RawMatrixer); ok { dst.checkOverlap(rm.RawMatrix()) } dst.Copy(b) t := blas.NoTrans if trans { t = blas.Trans } lapack64.Getrs(t, lu.lu.mat, dst.mat, lu.pivot) if lu.cond > ConditionTolerance { return Condition(lu.cond) } return nil } // SolveVecTo solves a system of linear equations using the LU decomposition of a matrix. // It computes // A * x = b if trans == false // Aᵀ * x = b if trans == true // In both cases, A is represented in LU factorized form, and the vector x is // stored into dst. // // If A is singular or near-singular a Condition error is returned. See // the documentation for Condition for more information. // SolveVecTo will panic if the receiver does not contain a factorization. func (lu *LU) SolveVecTo(dst *VecDense, trans bool, b Vector) error { if !lu.isValid() { panic(badLU) } _, n := lu.lu.Dims() if br, bc := b.Dims(); br != n || bc != 1 { panic(ErrShape) } switch rv := b.(type) { default: dst.reuseAsNonZeroed(n) return lu.SolveTo(dst.asDense(), trans, b) case RawVectorer: if dst != b { dst.checkOverlap(rv.RawVector()) } // TODO(btracey): Should test the condition number instead of testing that // the determinant is exactly zero. if lu.Det() == 0 { return Condition(math.Inf(1)) } dst.reuseAsNonZeroed(n) var restore func() if dst == b { dst, restore = dst.isolatedWorkspace(b) defer restore() } dst.CopyVec(b) vMat := blas64.General{ Rows: n, Cols: 1, Stride: dst.mat.Inc, Data: dst.mat.Data, } t := blas.NoTrans if trans { t = blas.Trans } lapack64.Getrs(t, lu.lu.mat, vMat, lu.pivot) if lu.cond > ConditionTolerance { return Condition(lu.cond) } return nil } }