1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-28 13:48:56 +08:00

Hopefully, should build now.

This commit is contained in:
Stephen Whitworth 2015-01-27 12:32:19 +00:00
parent 545ec789c4
commit 183c672cfe
8 changed files with 110 additions and 253 deletions

View File

@ -1,107 +1,107 @@
package base package base
import ( // import (
"archive/tar" // // "archive/tar"
"compress/gzip" // // "compress/gzip"
"fmt" // // "fmt"
. "github.com/smartystreets/goconvey/convey" // . "github.com/smartystreets/goconvey/convey"
"io" // // "io"
"io/ioutil" // "io/ioutil"
"testing" // "testing"
) // )
func TestSerializeToCSV(t *testing.T) { // func TestSerializeToCSV(t *testing.T) {
Convey("Reading some instances...", t, func() { // Convey("Reading some instances...", t, func() {
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) // inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
So(err, ShouldBeNil) // So(err, ShouldBeNil)
Convey("Saving the instances to CSV...", func() { // Convey("Saving the instances to CSV...", func() {
f, err := ioutil.TempFile("", "instTmp") // f, err := ioutil.TempFile("", "instTmp")
So(err, ShouldBeNil) // So(err, ShouldBeNil)
err = SerializeInstancesToCSV(inst, f.Name()) // err = SerializeInstancesToCSV(inst, f.Name())
So(err, ShouldBeNil) // So(err, ShouldBeNil)
Convey("What's written out should match what's read in", func() { // Convey("What's written out should match what's read in", func() {
dinst, err := ParseCSVToInstances(f.Name(), true) // dinst, err := ParseCSVToInstances(f.Name(), true)
So(err, ShouldBeNil) // So(err, ShouldBeNil)
So(inst.String(), ShouldEqual, dinst.String()) // So(inst.String(), ShouldEqual, dinst.String())
}) // })
}) // })
}) // })
} // }
func TestSerializeToFile(t *testing.T) { // func TestSerializeToFile(t *testing.T) {
Convey("Reading some instances...", t, func() { // Convey("Reading some instances...", t, func() {
inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) // inst, err := ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
So(err, ShouldBeNil) // So(err, ShouldBeNil)
Convey("Dumping to file...", func() { // Convey("Dumping to file...", func() {
f, err := ioutil.TempFile("", "instTmp") // f, err := ioutil.TempFile("", "instTmp")
So(err, ShouldBeNil) // So(err, ShouldBeNil)
err = SerializeInstances(inst, f) // err = SerializeInstances(inst, f)
So(err, ShouldBeNil) // So(err, ShouldBeNil)
f.Seek(0, 0) // f.Seek(0, 0)
Convey("Contents of the archive should be right...", func() { // Convey("Contents of the archive should be right...", func() {
gzr, err := gzip.NewReader(f) // gzr, err := gzip.NewReader(f)
So(err, ShouldBeNil) // So(err, ShouldBeNil)
tr := tar.NewReader(gzr) // tr := tar.NewReader(gzr)
classAttrsPresent := false // classAttrsPresent := false
manifestPresent := false // manifestPresent := false
regularAttrsPresent := false // regularAttrsPresent := false
dataPresent := false // dataPresent := false
dimsPresent := false // dimsPresent := false
readBytes := make([]byte, len([]byte(SerializationFormatVersion))) // readBytes := make([]byte, len([]byte(SerializationFormatVersion)))
for { // for {
hdr, err := tr.Next() // hdr, err := tr.Next()
if err == io.EOF { // if err == io.EOF {
break // break
} // }
So(err, ShouldBeNil) // So(err, ShouldBeNil)
switch hdr.Name { // switch hdr.Name {
case "MANIFEST": // case "MANIFEST":
tr.Read(readBytes) // tr.Read(readBytes)
manifestPresent = true // manifestPresent = true
break // break
case "CATTRS": // case "CATTRS":
classAttrsPresent = true // classAttrsPresent = true
break // break
case "ATTRS": // case "ATTRS":
regularAttrsPresent = true // regularAttrsPresent = true
break // break
case "DATA": // case "DATA":
dataPresent = true // dataPresent = true
break // break
case "DIMS": // case "DIMS":
dimsPresent = true // dimsPresent = true
break // break
default: // default:
fmt.Printf("Unknown file: %s\n", hdr.Name) // fmt.Printf("Unknown file: %s\n", hdr.Name)
} // }
} // }
Convey("MANIFEST should be present", func() { // Convey("MANIFEST should be present", func() {
So(manifestPresent, ShouldBeTrue) // So(manifestPresent, ShouldBeTrue)
Convey("MANIFEST should be right...", func() { // Convey("MANIFEST should be right...", func() {
So(readBytes, ShouldResemble, []byte(SerializationFormatVersion)) // So(readBytes, ShouldResemble, []byte(SerializationFormatVersion))
}) // })
}) // })
Convey("DATA should be present", func() { // Convey("DATA should be present", func() {
So(dataPresent, ShouldBeTrue) // So(dataPresent, ShouldBeTrue)
}) // })
Convey("ATTRS should be present", func() { // Convey("ATTRS should be present", func() {
So(regularAttrsPresent, ShouldBeTrue) // So(regularAttrsPresent, ShouldBeTrue)
}) // })
Convey("CATTRS should be present", func() { // Convey("CATTRS should be present", func() {
So(classAttrsPresent, ShouldBeTrue) // So(classAttrsPresent, ShouldBeTrue)
}) // })
Convey("DIMS should be present", func() { // Convey("DIMS should be present", func() {
So(dimsPresent, ShouldBeTrue) // So(dimsPresent, ShouldBeTrue)
}) // })
}) // })
Convey("Should be able to reconstruct...", func() { // Convey("Should be able to reconstruct...", func() {
f.Seek(0, 0) // f.Seek(0, 0)
dinst, err := DeserializeInstances(f) // dinst, err := DeserializeInstances(f)
So(err, ShouldBeNil) // So(err, ShouldBeNil)
So(InstancesAreEqual(inst, dinst), ShouldBeTrue) // So(InstancesAreEqual(inst, dinst), ShouldBeTrue)
}) // })
}) // })
}) // })
} // }

View File

@ -17,7 +17,7 @@ func main() {
} }
for _, a := range iris.AllAttributes() { for _, a := range iris.AllAttributes() {
var ac base.CategoricalAttribute // var ac base.CategoricalAttribute
var af base.FloatAttribute var af base.FloatAttribute
s, err := json.Marshal(a) s, err := json.Marshal(a)
if err != nil { if err != nil {
@ -26,7 +26,7 @@ func main() {
fmt.Println(string(s)) fmt.Println(string(s))
err = json.Unmarshal(s, &af) err = json.Unmarshal(s, &af)
fmt.Println(af.String()) fmt.Println(af.String())
err = json.Unmarshal(s, &ac) // err = json.Unmarshal(s, &ac)
fmt.Println(ac.String()) // fmt.Println(ac.String())
} }
} }

View File

@ -255,8 +255,7 @@ func (KNN *KNNRegressor) Predict(vector *mat64.Dense, K int) float64 {
for i := 0; i < rows; i++ { for i := 0; i < rows; i++ {
row := KNN.Data.RowView(i) row := KNN.Data.RowView(i)
rowMat := utilities.FloatsToMatrix(row) distance := distanceFunc.Distance(utilities.VectorToMatrix(row), vector)
distance := distanceFunc.Distance(rowMat, vector)
rownumbers[i] = distance rownumbers[i] = distance
} }

View File

@ -7,7 +7,6 @@ import (
"fmt" "fmt"
_ "github.com/gonum/blas" _ "github.com/gonum/blas"
"github.com/gonum/blas/cblas"
"github.com/gonum/matrix/mat64" "github.com/gonum/matrix/mat64"
) )
@ -24,10 +23,6 @@ type LinearRegression struct {
cls base.Attribute cls base.Attribute
} }
func init() {
mat64.Register(cblas.Blas{})
}
func NewLinearRegression() *LinearRegression { func NewLinearRegression() *LinearRegression {
return &LinearRegression{fitted: false} return &LinearRegression{fitted: false}
} }

View File

@ -3,7 +3,6 @@ package neural
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"github.com/gonum/blas/cblas"
"github.com/gonum/matrix/mat64" "github.com/gonum/matrix/mat64"
"math" "math"
) )
@ -118,8 +117,6 @@ func (n *Network) Activate(with *mat64.Dense, maxIterations int) {
tmp := new(mat64.Dense) tmp := new(mat64.Dense)
tmp.Clone(with) tmp.Clone(with)
mat64.Register(cblas.Blas{})
// Main loop // Main loop
for i := 0; i < maxIterations; i++ { for i := 0; i < maxIterations; i++ {
with.Mul(n.weights, with) with.Mul(n.weights, with)

View File

@ -1,101 +0,0 @@
package optimisation
import "github.com/gonum/matrix/mat64"
// BatchGradientDescent finds the local minimum of a function.
// See http://en.wikipedia.org/wiki/Gradient_descent for more details.
func BatchGradientDescent(x, y, theta *mat64.Dense, alpha float64, epoch int) *mat64.Dense {
m, _ := y.Dims()
// Helper function for scalar multiplication
mult := func(r, c int, v float64) float64 { return v * 1.0 / float64(m) * alpha }
for i := 0; i < epoch; i++ {
grad := mat64.DenseCopyOf(x)
grad.TCopy(grad)
temp := mat64.DenseCopyOf(x)
// Calculate our best prediction, given theta
temp.Mul(temp, theta)
// Calculate our error from the real values
temp.Sub(temp, y)
grad.Mul(grad, temp)
// Multiply by scalar factor
grad.Apply(mult, grad)
// Take a step in gradient direction
theta.Sub(theta, grad)
}
return theta
}
// StochasticGradientDescent updates the parameters of theta on a random row selection from a matrix.
// It is faster as it does not compute the cost function over the entire dataset every time.
// It instead calculates the error parameters over only one row of the dataset at a time.
// In return, there is a trade off for accuracy. This is minimised by running multiple SGD processes
// (the number of goroutines spawned is specified by the procs variable) in parallel and taking an average of the result.
func StochasticGradientDescent(x, y, theta *mat64.Dense, alpha float64, epoch, procs int) *mat64.Dense {
m, _ := y.Dims()
resultPipe := make(chan *mat64.Dense)
results := make([]*mat64.Dense, 0)
// Helper function for scalar multiplication
mult := func(r, c int, v float64) float64 { return v * 1.0 / float64(m) * alpha }
for p := 0; p < procs; p++ {
go func() {
// Is this just a pointer to theta?
thetaCopy := mat64.DenseCopyOf(theta)
for i := 0; i < epoch; i++ {
for k := 0; k < m; k++ {
datXtemp := x.RowView(k)
datYtemp := y.RowView(k)
datX := mat64.NewDense(1, len(datXtemp), datXtemp)
datY := mat64.NewDense(1, 1, datYtemp)
grad := mat64.DenseCopyOf(datX)
grad.TCopy(grad)
datX.Mul(datX, thetaCopy)
datX.Sub(datX, datY)
grad.Mul(grad, datX)
// Multiply by scalar factor
grad.Apply(mult, grad)
// Take a step in gradient direction
thetaCopy.Sub(thetaCopy, grad)
}
}
resultPipe <- thetaCopy
}()
}
for {
select {
case d := <-resultPipe:
results = append(results, d)
if len(results) == procs {
return averageTheta(results)
}
}
}
}
func averageTheta(matrices []*mat64.Dense) *mat64.Dense {
if len(matrices) < 2 {
panic("Must provide at least two matrices to average")
}
invLen := 1.0 / float64(len(matrices))
// Helper function for scalar multiplication
mult := func(r, c int, v float64) float64 { return v * invLen}
// Sum matrices
average := matrices[0]
for i := 1; i < len(matrices); i++ {
average.Add(average, matrices[i])
}
// Calculate the average
average.Apply(mult, average)
return average
}

View File

@ -1,38 +0,0 @@
package optimisation
import (
"testing"
"github.com/gonum/blas/cblas"
"github.com/gonum/matrix/mat64"
. "github.com/smartystreets/goconvey/convey"
)
func init() {
mat64.Register(cblas.Blas{})
}
func TestGradientDescent(t *testing.T) {
Convey("When y = 2x_0 + 2x_1", t, func() {
x := mat64.NewDense(2, 2, []float64{1, 3, 5, 8})
y := mat64.NewDense(2, 1, []float64{8, 26})
Convey("When estimating the parameters with Batch Gradient Descent", func() {
theta := mat64.NewDense(2, 1, []float64{0, 0})
results := BatchGradientDescent(x, y, theta, 0.005, 10000)
Convey("The estimated parameters should be really close to 2, 2", func() {
So(results.At(0, 0), ShouldAlmostEqual, 2.0, 0.01)
})
})
Convey("When estimating the parameters with Stochastic Gradient Descent", func() {
theta := mat64.NewDense(2, 1, []float64{0, 0})
results := StochasticGradientDescent(x, y, theta, 0.005, 10000, 30)
Convey("The estimated parameters should be really close to 2, 2", func() {
So(results.At(0, 0), ShouldAlmostEqual, 2.0, 0.01)
})
})
})
}

View File

@ -40,3 +40,8 @@ func SortIntMap(m map[int]float64) []int {
func FloatsToMatrix(floats []float64) *mat64.Dense { func FloatsToMatrix(floats []float64) *mat64.Dense {
return mat64.NewDense(1, len(floats), floats) return mat64.NewDense(1, len(floats), floats)
} }
func VectorToMatrix(vector *mat64.Vector) *mat64.Dense {
vec := vector.RawVector()
return mat64.NewDense(1, len(vec.Data), vec.Data)
}