1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-30 13:48:57 +08:00
golearn/optimisation/gradientdescent_test.go

33 lines
899 B
Go
Raw Normal View History

2014-05-05 02:39:00 +01:00
package optimisation
import (
"testing"
"github.com/gonum/blas/cblas"
"github.com/gonum/matrix/mat64"
)
func init() {
mat64.Register(cblas.Blas{})
}
// TestBGD runs BatchGradientDescent on a small fixed problem and checks that
// the first coefficient of the returned parameter matrix lands inside the
// window [0.880, 0.881].
func TestBGD(t *testing.T) {
	// x is a 2x3 matrix (row-major data) whose rows pair with the 2x1 targets
	// in y; column 0 is all ones, presumably an intercept term. theta is the
	// 3x1 starting parameter vector.
	x := mat64.NewDense(2, 3, []float64{1, 1, 2, 1, 2, 3})
	y := mat64.NewDense(2, 1, []float64{3, 4})
	theta := mat64.NewDense(3, 1, []float64{1, 1, 1})

	const (
		alpha      = 0.0001 // value passed as the 4th argument (step size, presumably)
		iterations = 10000  // value passed as the 5th argument
		lo, hi     = 0.880, 0.881
	)

	results := BatchGradientDescent(x, y, theta, alpha, iterations)
	if got := results.At(0, 0); got < lo || got > hi {
		// Report the observed value so a failure is diagnosable without a debugger.
		t.Errorf("Inaccurate convergence of batch gradient descent: theta[0] = %v, want in [%v, %v]", got, lo, hi)
	}
}
// TestSGD runs StochasticGradientDescent on the same small fixed problem used
// for the batch variant and checks that the first coefficient of the returned
// parameter matrix lands inside the window [0.880, 0.881].
func TestSGD(t *testing.T) {
	// x is a 2x3 matrix (row-major data) whose rows pair with the 2x1 targets
	// in y; column 0 is all ones, presumably an intercept term. theta is the
	// 3x1 starting parameter vector.
	x := mat64.NewDense(2, 3, []float64{1, 1, 2, 1, 2, 3})
	y := mat64.NewDense(2, 1, []float64{3, 4})
	theta := mat64.NewDense(3, 1, []float64{1, 1, 1})

	const (
		alpha      = 0.0001 // value passed as the 4th argument (step size, presumably)
		iterations = 10000  // value passed as the 5th argument
		lo, hi     = 0.880, 0.881
	)

	results := StochasticGradientDescent(x, y, theta, alpha, iterations)
	if got := results.At(0, 0); got < lo || got > hi {
		// The message previously said "batch" (copy-paste from TestBGD); it
		// must name the stochastic variant this test actually exercises.
		t.Errorf("Inaccurate convergence of stochastic gradient descent: theta[0] = %v, want in [%v, %v]", got, lo, hi)
	}
}