mirror of
https://github.com/mainflux/mainflux.git
synced 2025-04-26 13:48:53 +08:00

Signed-off-by: dusanb94 <dusan.borovcanin@mainflux.com> Signed-off-by: dusanb94 <dusan.borovcanin@mainflux.com>
462 lines
12 KiB
Go
462 lines
12 KiB
Go
// Copyright ©2015 The Gonum Authors. All rights reserved.
|
||
// Use of this source code is governed by a BSD-style
|
||
// license that can be found in the LICENSE file.
|
||
|
||
package gonum
|
||
|
||
import (
|
||
"gonum.org/v1/gonum/blas"
|
||
"gonum.org/v1/gonum/blas/blas64"
|
||
"gonum.org/v1/gonum/lapack"
|
||
)
|
||
|
||
// Dlarfb applies a block reflector to a matrix.
|
||
//
|
||
// In the call to Dlarfb, the mxn c is multiplied by the implicitly defined matrix h as follows:
|
||
//
|
||
// c = h * c if side == Left and trans == NoTrans
|
||
// c = c * h if side == Right and trans == NoTrans
|
||
// c = hᵀ * c if side == Left and trans == Trans
|
||
// c = c * hᵀ if side == Right and trans == Trans
|
||
//
|
||
// h is a product of elementary reflectors. direct sets the direction of multiplication
|
||
//
|
||
// h = h_1 * h_2 * ... * h_k if direct == Forward
|
||
// h = h_k * h_k-1 * ... * h_1 if direct == Backward
|
||
//
|
||
// The combination of direct and store defines the orientation of the elementary
|
||
// reflectors. In all cases the ones on the diagonal are implicitly represented.
|
||
//
|
||
// If direct == lapack.Forward and store == lapack.ColumnWise
|
||
//
|
||
// V = [ 1 ]
|
||
// [v1 1 ]
|
||
// [v1 v2 1]
|
||
// [v1 v2 v3]
|
||
// [v1 v2 v3]
|
||
//
|
||
// If direct == lapack.Forward and store == lapack.RowWise
|
||
//
|
||
// V = [ 1 v1 v1 v1 v1]
|
||
// [ 1 v2 v2 v2]
|
||
// [ 1 v3 v3]
|
||
//
|
||
// If direct == lapack.Backward and store == lapack.ColumnWise
|
||
//
|
||
// V = [v1 v2 v3]
|
||
// [v1 v2 v3]
|
||
// [ 1 v2 v3]
|
||
// [ 1 v3]
|
||
// [ 1]
|
||
//
|
||
// If direct == lapack.Backward and store == lapack.RowWise
|
||
//
|
||
// V = [v1 v1 1 ]
|
||
// [v2 v2 v2 1 ]
|
||
// [v3 v3 v3 v3 1]
|
||
//
|
||
// An elementary reflector can be explicitly constructed by extracting the
|
||
// corresponding elements of v, placing a 1 where the diagonal would be, and
|
||
// placing zeros in the remaining elements.
|
||
//
|
||
// t is a k×k matrix containing the block reflector, and this function will panic
|
||
// if t is not of sufficient size. See Dlarft for more information.
|
||
//
|
||
// work is a temporary storage matrix with stride ldwork.
|
||
// work must be of size at least n×k side == Left and m×k if side == Right, and
|
||
// this function will panic if this size is not met.
|
||
//
|
||
// Dlarfb is an internal routine. It is exported for testing purposes.
|
||
func (Implementation) Dlarfb(side blas.Side, trans blas.Transpose, direct lapack.Direct, store lapack.StoreV, m, n, k int, v []float64, ldv int, t []float64, ldt int, c []float64, ldc int, work []float64, ldwork int) {
|
||
nv := m
|
||
if side == blas.Right {
|
||
nv = n
|
||
}
|
||
switch {
|
||
case side != blas.Left && side != blas.Right:
|
||
panic(badSide)
|
||
case trans != blas.Trans && trans != blas.NoTrans:
|
||
panic(badTrans)
|
||
case direct != lapack.Forward && direct != lapack.Backward:
|
||
panic(badDirect)
|
||
case store != lapack.ColumnWise && store != lapack.RowWise:
|
||
panic(badStoreV)
|
||
case m < 0:
|
||
panic(mLT0)
|
||
case n < 0:
|
||
panic(nLT0)
|
||
case k < 0:
|
||
panic(kLT0)
|
||
case store == lapack.ColumnWise && ldv < max(1, k):
|
||
panic(badLdV)
|
||
case store == lapack.RowWise && ldv < max(1, nv):
|
||
panic(badLdV)
|
||
case ldt < max(1, k):
|
||
panic(badLdT)
|
||
case ldc < max(1, n):
|
||
panic(badLdC)
|
||
case ldwork < max(1, k):
|
||
panic(badLdWork)
|
||
}
|
||
|
||
if m == 0 || n == 0 {
|
||
return
|
||
}
|
||
|
||
nw := n
|
||
if side == blas.Right {
|
||
nw = m
|
||
}
|
||
switch {
|
||
case store == lapack.ColumnWise && len(v) < (nv-1)*ldv+k:
|
||
panic(shortV)
|
||
case store == lapack.RowWise && len(v) < (k-1)*ldv+nv:
|
||
panic(shortV)
|
||
case len(t) < (k-1)*ldt+k:
|
||
panic(shortT)
|
||
case len(c) < (m-1)*ldc+n:
|
||
panic(shortC)
|
||
case len(work) < (nw-1)*ldwork+k:
|
||
panic(shortWork)
|
||
}
|
||
|
||
bi := blas64.Implementation()
|
||
|
||
transt := blas.Trans
|
||
if trans == blas.Trans {
|
||
transt = blas.NoTrans
|
||
}
|
||
// TODO(btracey): This follows the original Lapack code where the
|
||
// elements are copied into the columns of the working array. The
|
||
// loops should go in the other direction so the data is written
|
||
// into the rows of work so the copy is not strided. A bigger change
|
||
// would be to replace work with workᵀ, but benchmarks would be
|
||
// needed to see if the change is merited.
|
||
if store == lapack.ColumnWise {
|
||
if direct == lapack.Forward {
|
||
// V1 is the first k rows of C. V2 is the remaining rows.
|
||
if side == blas.Left {
|
||
// W = Cᵀ V = C1ᵀ V1 + C2ᵀ V2 (stored in work).
|
||
|
||
// W = C1.
|
||
for j := 0; j < k; j++ {
|
||
bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork)
|
||
}
|
||
// W = W * V1.
|
||
bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit,
|
||
n, k, 1,
|
||
v, ldv,
|
||
work, ldwork)
|
||
if m > k {
|
||
// W = W + C2ᵀ V2.
|
||
bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k,
|
||
1, c[k*ldc:], ldc, v[k*ldv:], ldv,
|
||
1, work, ldwork)
|
||
}
|
||
// W = W * Tᵀ or W * T.
|
||
bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k,
|
||
1, t, ldt,
|
||
work, ldwork)
|
||
// C -= V * Wᵀ.
|
||
if m > k {
|
||
// C2 -= V2 * Wᵀ.
|
||
bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k,
|
||
-1, v[k*ldv:], ldv, work, ldwork,
|
||
1, c[k*ldc:], ldc)
|
||
}
|
||
// W *= V1ᵀ.
|
||
bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k,
|
||
1, v, ldv,
|
||
work, ldwork)
|
||
// C1 -= Wᵀ.
|
||
// TODO(btracey): This should use blas.Axpy.
|
||
for i := 0; i < n; i++ {
|
||
for j := 0; j < k; j++ {
|
||
c[j*ldc+i] -= work[i*ldwork+j]
|
||
}
|
||
}
|
||
return
|
||
}
|
||
// Form C = C * H or C * Hᵀ, where C = (C1 C2).
|
||
|
||
// W = C1.
|
||
for i := 0; i < k; i++ {
|
||
bi.Dcopy(m, c[i:], ldc, work[i:], ldwork)
|
||
}
|
||
// W *= V1.
|
||
bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k,
|
||
1, v, ldv,
|
||
work, ldwork)
|
||
if n > k {
|
||
bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k,
|
||
1, c[k:], ldc, v[k*ldv:], ldv,
|
||
1, work, ldwork)
|
||
}
|
||
// W *= T or Tᵀ.
|
||
bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k,
|
||
1, t, ldt,
|
||
work, ldwork)
|
||
if n > k {
|
||
bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k,
|
||
-1, work, ldwork, v[k*ldv:], ldv,
|
||
1, c[k:], ldc)
|
||
}
|
||
// C -= W * Vᵀ.
|
||
bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k,
|
||
1, v, ldv,
|
||
work, ldwork)
|
||
// C -= W.
|
||
// TODO(btracey): This should use blas.Axpy.
|
||
for i := 0; i < m; i++ {
|
||
for j := 0; j < k; j++ {
|
||
c[i*ldc+j] -= work[i*ldwork+j]
|
||
}
|
||
}
|
||
return
|
||
}
|
||
// V = (V1)
|
||
// = (V2) (last k rows)
|
||
// Where V2 is unit upper triangular.
|
||
if side == blas.Left {
|
||
// Form H * C or
|
||
// W = Cᵀ V.
|
||
|
||
// W = C2ᵀ.
|
||
for j := 0; j < k; j++ {
|
||
bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork)
|
||
}
|
||
// W *= V2.
|
||
bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k,
|
||
1, v[(m-k)*ldv:], ldv,
|
||
work, ldwork)
|
||
if m > k {
|
||
// W += C1ᵀ * V1.
|
||
bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k,
|
||
1, c, ldc, v, ldv,
|
||
1, work, ldwork)
|
||
}
|
||
// W *= T or Tᵀ.
|
||
bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k,
|
||
1, t, ldt,
|
||
work, ldwork)
|
||
// C -= V * Wᵀ.
|
||
if m > k {
|
||
bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k,
|
||
-1, v, ldv, work, ldwork,
|
||
1, c, ldc)
|
||
}
|
||
// W *= V2ᵀ.
|
||
bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k,
|
||
1, v[(m-k)*ldv:], ldv,
|
||
work, ldwork)
|
||
// C2 -= Wᵀ.
|
||
// TODO(btracey): This should use blas.Axpy.
|
||
for i := 0; i < n; i++ {
|
||
for j := 0; j < k; j++ {
|
||
c[(m-k+j)*ldc+i] -= work[i*ldwork+j]
|
||
}
|
||
}
|
||
return
|
||
}
|
||
// Form C * H or C * Hᵀ where C = (C1 C2).
|
||
// W = C * V.
|
||
|
||
// W = C2.
|
||
for j := 0; j < k; j++ {
|
||
bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork)
|
||
}
|
||
|
||
// W = W * V2.
|
||
bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k,
|
||
1, v[(n-k)*ldv:], ldv,
|
||
work, ldwork)
|
||
if n > k {
|
||
bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k,
|
||
1, c, ldc, v, ldv,
|
||
1, work, ldwork)
|
||
}
|
||
// W *= T or Tᵀ.
|
||
bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k,
|
||
1, t, ldt,
|
||
work, ldwork)
|
||
// C -= W * Vᵀ.
|
||
if n > k {
|
||
// C1 -= W * V1ᵀ.
|
||
bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k,
|
||
-1, work, ldwork, v, ldv,
|
||
1, c, ldc)
|
||
}
|
||
// W *= V2ᵀ.
|
||
bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k,
|
||
1, v[(n-k)*ldv:], ldv,
|
||
work, ldwork)
|
||
// C2 -= W.
|
||
// TODO(btracey): This should use blas.Axpy.
|
||
for i := 0; i < m; i++ {
|
||
for j := 0; j < k; j++ {
|
||
c[i*ldc+n-k+j] -= work[i*ldwork+j]
|
||
}
|
||
}
|
||
return
|
||
}
|
||
// Store = Rowwise.
|
||
if direct == lapack.Forward {
|
||
// V = (V1 V2) where v1 is unit upper triangular.
|
||
if side == blas.Left {
|
||
// Form H * C or Hᵀ * C where C = (C1; C2).
|
||
// W = Cᵀ * Vᵀ.
|
||
|
||
// W = C1ᵀ.
|
||
for j := 0; j < k; j++ {
|
||
bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork)
|
||
}
|
||
// W *= V1ᵀ.
|
||
bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k,
|
||
1, v, ldv,
|
||
work, ldwork)
|
||
if m > k {
|
||
bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k,
|
||
1, c[k*ldc:], ldc, v[k:], ldv,
|
||
1, work, ldwork)
|
||
}
|
||
// W *= T or Tᵀ.
|
||
bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k,
|
||
1, t, ldt,
|
||
work, ldwork)
|
||
// C -= Vᵀ * Wᵀ.
|
||
if m > k {
|
||
bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k,
|
||
-1, v[k:], ldv, work, ldwork,
|
||
1, c[k*ldc:], ldc)
|
||
}
|
||
// W *= V1.
|
||
bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k,
|
||
1, v, ldv,
|
||
work, ldwork)
|
||
// C1 -= Wᵀ.
|
||
// TODO(btracey): This should use blas.Axpy.
|
||
for i := 0; i < n; i++ {
|
||
for j := 0; j < k; j++ {
|
||
c[j*ldc+i] -= work[i*ldwork+j]
|
||
}
|
||
}
|
||
return
|
||
}
|
||
// Form C * H or C * Hᵀ where C = (C1 C2).
|
||
// W = C * Vᵀ.
|
||
|
||
// W = C1.
|
||
for j := 0; j < k; j++ {
|
||
bi.Dcopy(m, c[j:], ldc, work[j:], ldwork)
|
||
}
|
||
// W *= V1ᵀ.
|
||
bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k,
|
||
1, v, ldv,
|
||
work, ldwork)
|
||
if n > k {
|
||
bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k,
|
||
1, c[k:], ldc, v[k:], ldv,
|
||
1, work, ldwork)
|
||
}
|
||
// W *= T or Tᵀ.
|
||
bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k,
|
||
1, t, ldt,
|
||
work, ldwork)
|
||
// C -= W * V.
|
||
if n > k {
|
||
bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k,
|
||
-1, work, ldwork, v[k:], ldv,
|
||
1, c[k:], ldc)
|
||
}
|
||
// W *= V1.
|
||
bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k,
|
||
1, v, ldv,
|
||
work, ldwork)
|
||
// C1 -= W.
|
||
// TODO(btracey): This should use blas.Axpy.
|
||
for i := 0; i < m; i++ {
|
||
for j := 0; j < k; j++ {
|
||
c[i*ldc+j] -= work[i*ldwork+j]
|
||
}
|
||
}
|
||
return
|
||
}
|
||
// V = (V1 V2) where V2 is the last k columns and is lower unit triangular.
|
||
if side == blas.Left {
|
||
// Form H * C or Hᵀ C where C = (C1 ; C2).
|
||
// W = Cᵀ * Vᵀ.
|
||
|
||
// W = C2ᵀ.
|
||
for j := 0; j < k; j++ {
|
||
bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork)
|
||
}
|
||
// W *= V2ᵀ.
|
||
bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k,
|
||
1, v[m-k:], ldv,
|
||
work, ldwork)
|
||
if m > k {
|
||
bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k,
|
||
1, c, ldc, v, ldv,
|
||
1, work, ldwork)
|
||
}
|
||
// W *= T or Tᵀ.
|
||
bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k,
|
||
1, t, ldt,
|
||
work, ldwork)
|
||
// C -= Vᵀ * Wᵀ.
|
||
if m > k {
|
||
bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k,
|
||
-1, v, ldv, work, ldwork,
|
||
1, c, ldc)
|
||
}
|
||
// W *= V2.
|
||
bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, k,
|
||
1, v[m-k:], ldv,
|
||
work, ldwork)
|
||
// C2 -= Wᵀ.
|
||
// TODO(btracey): This should use blas.Axpy.
|
||
for i := 0; i < n; i++ {
|
||
for j := 0; j < k; j++ {
|
||
c[(m-k+j)*ldc+i] -= work[i*ldwork+j]
|
||
}
|
||
}
|
||
return
|
||
}
|
||
// Form C * H or C * Hᵀ where C = (C1 C2).
|
||
// W = C * Vᵀ.
|
||
// W = C2.
|
||
for j := 0; j < k; j++ {
|
||
bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork)
|
||
}
|
||
// W *= V2ᵀ.
|
||
bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k,
|
||
1, v[n-k:], ldv,
|
||
work, ldwork)
|
||
if n > k {
|
||
bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k,
|
||
1, c, ldc, v, ldv,
|
||
1, work, ldwork)
|
||
}
|
||
// W *= T or Tᵀ.
|
||
bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k,
|
||
1, t, ldt,
|
||
work, ldwork)
|
||
// C -= W * V.
|
||
if n > k {
|
||
bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k,
|
||
-1, work, ldwork, v, ldv,
|
||
1, c, ldc)
|
||
}
|
||
// W *= V2.
|
||
bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k,
|
||
1, v[n-k:], ldv,
|
||
work, ldwork)
|
||
// C1 -= W.
|
||
// TODO(btracey): This should use blas.Axpy.
|
||
for i := 0; i < m; i++ {
|
||
for j := 0; j < k; j++ {
|
||
c[i*ldc+n-k+j] -= work[i*ldwork+j]
|
||
}
|
||
}
|
||
}
|