mirror of https://github.com/sjwhitworth/golearn.git (synced 2025-04-26 13:49:14 +08:00)
Made a start on gradient descent
This commit is contained in:
parent b4abf54c07
commit dc96e818d8
@@ -2,6 +2,12 @@
// It also provides a raw base for those objects.
package base

import (
	"bytes"
	"encoding/gob"
	"io/ioutil"
)

import (
	mat64 "github.com/gonum/matrix/mat64"
)
@@ -22,8 +28,21 @@ type Model interface {
	Score()
}

// @todo: Implement BaseEstimator setters and getters.
type BaseEstimator struct {
	Estimator
	Data *mat64.Dense
}

// SaveEstimatorToGob serialises an estimator to the provided filepath, in gob format.
// See http://golang.org/pkg/encoding/gob for further details.
func SaveEstimatorToGob(path string, e *Estimator) {
	b := new(bytes.Buffer)
	enc := gob.NewEncoder(b)
	err := enc.Encode(e)
	if err != nil {
		panic(err)
	}
	err = ioutil.WriteFile(path, b.Bytes(), 0644)
	if err != nil {
		panic(err)
	}
}
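For symmetry, a loader could reverse this process: read the file back and gob-decode it into the estimator. The following is a minimal sketch, not part of this commit, assuming the same imports and the Estimator type above; the function name LoadEstimatorFromGob is hypothetical.

// LoadEstimatorFromGob is a hypothetical counterpart to SaveEstimatorToGob:
// it reads the gob-encoded file at path and decodes it into e.
func LoadEstimatorFromGob(path string, e *Estimator) {
	raw, err := ioutil.ReadFile(path)
	if err != nil {
		panic(err)
	}
	if err := gob.NewDecoder(bytes.NewReader(raw)).Decode(e); err != nil {
		panic(err)
	}
}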
optimisation/gradient_descent.go (new file, 69 lines)
@@ -0,0 +1,69 @@
package optimisation

import "github.com/gonum/matrix/mat64"

// BatchGradientDescent finds a local minimum of the cost function by updating
// theta over the whole dataset on every iteration.
// See http://en.wikipedia.org/wiki/Gradient_descent for more details.
func BatchGradientDescent(x, y, theta *mat64.Dense, alpha float64, epoch int) *mat64.Dense {
	m, _ := y.Dims()
	for i := 0; i < epoch; i++ {
		xFlat := mat64.DenseCopyOf(x)
		xFlat.TCopy(xFlat)
		temp := mat64.DenseCopyOf(x)
		temp.Mul(temp, theta)
		temp.Sub(temp, y)
		xFlat.Mul(xFlat, temp)

		// Horrible hack to get around the fact there is no scalar division in mat64:
		// scale each element of the gradient by alpha/m one value at a time.
		xFlatRow, _ := xFlat.Dims()
		gradient := make([]float64, 0)
		for i := 0; i < xFlatRow; i++ {
			row := xFlat.RowView(i)
			for i := range row {
				divd := row[i] / float64(m) * alpha
				gradient = append(gradient, divd)
			}
		}
		grows := len(gradient)
		grad := mat64.NewDense(grows, 1, gradient)
		theta.Sub(theta, grad)
	}
	return theta
}

// StochasticGradientDescent updates the parameters of theta from one example
// of X, Y at a time, rather than from the whole dataset.
// It is faster because it does not compute the cost over the entire dataset on
// every update; in return, there is a trade-off in accuracy.
// @todo: use goroutines to parallelise training.
func StochasticGradientDescent(x, y, theta *mat64.Dense, alpha float64, epoch int) *mat64.Dense {
	m, _ := y.Dims()
	for i := 0; i < epoch; i++ {
		for k := 0; k < m; k++ {
			datXtemp := x.RowView(k)
			datYtemp := y.RowView(k)
			datX := mat64.NewDense(1, len(datXtemp), datXtemp)
			datY := mat64.NewDense(1, 1, datYtemp)
			datXFlat := mat64.DenseCopyOf(datX)
			datXFlat.TCopy(datXFlat)
			datX.Mul(datX, theta)
			datX.Sub(datX, datY)
			datXFlat.Mul(datXFlat, datX)

			// Horrible hack to get around the fact there is no scalar division in mat64:
			// scale each element of the gradient by alpha/m one value at a time.
			xFlatRow, _ := datXFlat.Dims()
			gradient := make([]float64, 0)
			for i := 0; i < xFlatRow; i++ {
				row := datXFlat.RowView(i)
				for i := range row {
					divd := row[i] / float64(m) * alpha
					gradient = append(gradient, divd)
				}
			}
			grows := len(gradient)
			grad := mat64.NewDense(grows, 1, gradient)
			theta.Sub(theta, grad)
		}
	}
	return theta
}
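Each BatchGradientDescent iteration implements the standard linear-regression update theta := theta - (alpha/m) * X^T (X*theta - y); the inner loops exist only to apply the alpha/m scaling element by element. A minimal usage sketch follows, assuming the package is imported as github.com/sjwhitworth/golearn/optimisation and a cblas implementation is registered, as the tests below do.

package main

import (
	"fmt"

	"github.com/gonum/blas/cblas"
	"github.com/gonum/matrix/mat64"
	"github.com/sjwhitworth/golearn/optimisation"
)

func main() {
	mat64.Register(cblas.Blas{})
	// Two examples, three features; the first column is a column of ones,
	// which acts as an intercept term.
	x := mat64.NewDense(2, 3, []float64{1, 1, 2, 1, 2, 3})
	y := mat64.NewDense(2, 1, []float64{3, 4})
	theta := mat64.NewDense(3, 1, []float64{1, 1, 1})
	theta = optimisation.BatchGradientDescent(x, y, theta, 0.0001, 10000)
	fmt.Println(theta.At(0, 0), theta.At(1, 0), theta.At(2, 0))
}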
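StochasticGradientDescent above walks through the examples in their stored order each epoch. Stochastic gradient descent is often run over a shuffled example order instead; the sketch below is one possible variant, not part of this commit, assuming it sits alongside the functions above in package optimisation with "math/rand" added to the imports.

// shuffledSGD is a hypothetical variant that visits the rows of x and y in a
// random order each epoch; the per-example update is the same as above.
func shuffledSGD(x, y, theta *mat64.Dense, alpha float64, epoch int) *mat64.Dense {
	m, _ := y.Dims()
	for i := 0; i < epoch; i++ {
		for _, k := range rand.Perm(m) { // random visit order for this epoch
			xRow := x.RowView(k)
			datX := mat64.NewDense(1, len(xRow), xRow)
			datY := mat64.NewDense(1, 1, y.RowView(k))
			datXFlat := mat64.DenseCopyOf(datX)
			datXFlat.TCopy(datXFlat)
			datX.Mul(datX, theta)
			datX.Sub(datX, datY)
			datXFlat.Mul(datXFlat, datX)
			// Same element-wise alpha/m scaling as above.
			rows, _ := datXFlat.Dims()
			gradient := make([]float64, 0, rows)
			for r := 0; r < rows; r++ {
				for _, v := range datXFlat.RowView(r) {
					gradient = append(gradient, v/float64(m)*alpha)
				}
			}
			grad := mat64.NewDense(len(gradient), 1, gradient)
			theta.Sub(theta, grad)
		}
	}
	return theta
}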
optimisation/gradientdescent_test.go (new file, 32 lines)
@@ -0,0 +1,32 @@
package optimisation

import (
	"testing"

	"github.com/gonum/blas/cblas"
	"github.com/gonum/matrix/mat64"
)

func init() {
	mat64.Register(cblas.Blas{})
}

func TestBGD(t *testing.T) {
	x := mat64.NewDense(2, 3, []float64{1, 1, 2, 1, 2, 3})
	y := mat64.NewDense(2, 1, []float64{3, 4})
	theta := mat64.NewDense(3, 1, []float64{1, 1, 1})
	results := BatchGradientDescent(x, y, theta, 0.0001, 10000)
	if results.At(0, 0) < 0.880 || results.At(0, 0) > 0.881 {
		t.Error("Inaccurate convergence of batch gradient descent")
	}
}

func TestSGD(t *testing.T) {
	x := mat64.NewDense(2, 3, []float64{1, 1, 2, 1, 2, 3})
	y := mat64.NewDense(2, 1, []float64{3, 4})
	theta := mat64.NewDense(3, 1, []float64{1, 1, 1})
	results := StochasticGradientDescent(x, y, theta, 0.0001, 10000)
	if results.At(0, 0) < 0.880 || results.At(0, 0) > 0.881 {
		t.Error("Inaccurate convergence of stochastic gradient descent")
	}
}
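As a hand check on the BGD expectation above: with theta = (1, 1, 1)^T, the first iteration computes X*theta = (4, 6)^T, the residual X*theta - y = (1, 2)^T, and X^T(X*theta - y) = (3, 5, 8)^T; scaling by alpha/m = 0.00005 updates theta to (0.99985, 0.99975, 0.9996)^T. Repeating this update for 10000 epochs is what brings theta[0] into the (0.880, 0.881) band the test asserts.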
optimisation/optimisation.go (new file, 2 lines)
@@ -0,0 +1,2 @@
// Package optimisation provides a number of optimisation functions.
package optimisation