1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-28 13:48:56 +08:00

linear_models: More idiomatic

This commit is contained in:
Richard Townsend 2014-07-02 15:48:35 +01:00
parent 521844cbb2
commit 477438d972
4 changed files with 65 additions and 25 deletions

View File

@ -1,34 +1,30 @@
package linear_models package linear_models
import ( import (
"testing" "github.com/sjwhitworth/golearn/base"
. "github.com/smartystreets/goconvey/convey" . "github.com/smartystreets/goconvey/convey"
"testing"
) )
func TestLogisticRegression(t *testing.T) { func TestLogisticRegression(t *testing.T) {
Convey("Given labels, a classifier and data", t, func() { Convey("Given labels, a classifier and data", t, func() {
X := [][]float64{ X, err := base.ParseCSVToInstances("train.csv", false)
{0, 0, 0, 1}, So(err, ShouldEqual, nil)
{0, 0, 1, 0}, Y, err := base.ParseCSVToInstances("test.csv", false)
{0, 1, 0, 0}, So(err, ShouldEqual, nil)
{1, 0, 0, 0},
}
y := []float64{-1, -1, 1, 1}
lr := NewLogisticRegression("l2", 1.0, 1e-6) lr := NewLogisticRegression("l2", 1.0, 1e-6)
lr.Fit(X,y) lr.Fit(X)
Convey("When predicting the label of first vector", func() { Convey("When predicting the label of first vector", func() {
pred_x := [][]float64{ {1,1,0,0} } Z := lr.Predict(Y)
pred_y := lr.Predict(pred_x)
Convey("The result should be 1", func() { Convey("The result should be 1", func() {
So(pred_y[0], ShouldEqual, 1.0) So(Z.Get(0, 0), ShouldEqual, 1.0)
}) })
}) })
Convey("When predicting the label of second vector", func() { Convey("When predicting the label of second vector", func() {
pred_x := [][]float64{ {0,0,1,1} } Z := lr.Predict(Y)
pred_y := lr.Predict(pred_x)
Convey("The result should be -1", func() { Convey("The result should be -1", func() {
So(pred_y[0], ShouldEqual, -1.0) So(Z.Get(1, 0), ShouldEqual, -1.0)
}) })
}) })
}) })

View File

@ -1,6 +1,9 @@
package linear_models package linear_models
import "fmt" import (
"fmt"
base "github.com/sjwhitworth/golearn/base"
)
type LogisticRegression struct { type LogisticRegression struct {
param *Parameter param *Parameter
@ -24,16 +27,51 @@ func NewLogisticRegression(penalty string, C float64, eps float64) *LogisticRegr
return &lr return &lr
} }
func (lr *LogisticRegression) Fit(X [][]float64, y []float64) { func convertInstancesToProblemVec(X *base.Instances) [][]float64 {
prob := NewProblem(X, y, 0) problemVec := make([][]float64, X.Rows)
for i := 0; i < X.Rows; i++ {
problemVecCounter := 0
problemVec[i] = make([]float64, X.Cols-1)
for j := 0; j < X.Cols; j++ {
if j == X.ClassIndex {
continue
}
problemVec[i][problemVecCounter] = X.Get(i, j)
problemVecCounter++
}
}
fmt.Println(problemVec, X)
return problemVec
}
func convertInstancesToLabelVec(X *base.Instances) []float64 {
labelVec := make([]float64, X.Rows)
for i := 0; i < X.Rows; i++ {
labelVec[i] = X.Get(i, X.ClassIndex)
}
return labelVec
}
func (lr *LogisticRegression) Fit(X *base.Instances) {
problemVec := convertInstancesToProblemVec(X)
labelVec := convertInstancesToLabelVec(X)
prob := NewProblem(problemVec, labelVec, 0)
lr.model = Train(prob, lr.param) lr.model = Train(prob, lr.param)
} }
func (lr *LogisticRegression) Predict(X [][]float64) []float64 { func (lr *LogisticRegression) Predict(X *base.Instances) *base.Instances {
n_samples := len(X) ret := X.GeneratePredictionVector()
y := make([]float64, n_samples) row := make([]float64, X.Cols-1)
for i, x := range X { for i := 0; i < X.Rows; i++ {
y[i] = Predict(lr.model, x) rowCounter := 0
for j := 0; j < X.Cols; j++ {
if j != X.ClassIndex {
row[rowCounter] = X.Get(i, j)
rowCounter++
} }
return y }
fmt.Println(Predict(lr.model, row), row)
ret.Set(i, 0, Predict(lr.model, row))
}
return ret
} }

2
linear_models/test.csv Normal file
View File

@ -0,0 +1,2 @@
1.0,1.0,0.0,0.0,1.0
0.0,0.0,1.0,1.0,-1.0
1 1.0 1.0 0.0 0.0 1.0
2 0.0 0.0 1.0 1.0 -1.0

4
linear_models/train.csv Normal file
View File

@ -0,0 +1,4 @@
0.0, 0.0, 0.0, 1.0, -1.0
0.0, 0.0, 1.0, 0.0, -1.0
0.0, 1.0, 0.0, 0.0, 1.0
1.0, 0.0, 0.0, 0.0, 1.0
1 0.0 0.0 0.0 1.0 -1.0
2 0.0 0.0 1.0 0.0 -1.0
3 0.0 1.0 0.0 0.0 1.0
4 1.0 0.0 0.0 0.0 1.0