
linear_models: More idiomatic

Richard Townsend 2014-07-02 15:48:35 +01:00
parent 521844cbb2
commit 477438d972
4 changed files with 65 additions and 25 deletions


@@ -1,34 +1,30 @@
 package linear_models

 import (
-	"testing"
+	"github.com/sjwhitworth/golearn/base"
 	. "github.com/smartystreets/goconvey/convey"
+	"testing"
 )

 func TestLogisticRegression(t *testing.T) {
 	Convey("Given labels, a classifier and data", t, func() {
-		X := [][]float64{
-			{0, 0, 0, 1},
-			{0, 0, 1, 0},
-			{0, 1, 0, 0},
-			{1, 0, 0, 0},
-		}
-		y := []float64{-1, -1, 1, 1}
+		X, err := base.ParseCSVToInstances("train.csv", false)
+		So(err, ShouldEqual, nil)
+		Y, err := base.ParseCSVToInstances("test.csv", false)
+		So(err, ShouldEqual, nil)

 		lr := NewLogisticRegression("l2", 1.0, 1e-6)
-		lr.Fit(X, y)
+		lr.Fit(X)
 		Convey("When predicting the label of first vector", func() {
-			pred_x := [][]float64{ {1,1,0,0} }
-			pred_y := lr.Predict(pred_x)
+			Z := lr.Predict(Y)
 			Convey("The result should be 1", func() {
-				So(pred_y[0], ShouldEqual, 1.0)
+				So(Z.Get(0, 0), ShouldEqual, 1.0)
 			})
 		})
 		Convey("When predicting the label of second vector", func() {
-			pred_x := [][]float64{ {0,0,1,1} }
-			pred_y := lr.Predict(pred_x)
+			Z := lr.Predict(Y)
 			Convey("The result should be -1", func() {
-				So(pred_y[0], ShouldEqual, -1.0)
+				So(Z.Get(1, 0), ShouldEqual, -1.0)
 			})
 		})
 	})


@@ -1,6 +1,9 @@
 package linear_models

-import "fmt"
+import (
+	"fmt"
+	base "github.com/sjwhitworth/golearn/base"
+)

 type LogisticRegression struct {
 	param *Parameter
@@ -24,16 +27,51 @@ func NewLogisticRegression(penalty string, C float64, eps float64) *LogisticRegr
 	return &lr
 }

-func (lr *LogisticRegression) Fit(X [][]float64, y []float64) {
-	prob := NewProblem(X, y, 0)
+func convertInstancesToProblemVec(X *base.Instances) [][]float64 {
+	problemVec := make([][]float64, X.Rows)
+	for i := 0; i < X.Rows; i++ {
+		problemVecCounter := 0
+		problemVec[i] = make([]float64, X.Cols-1)
+		for j := 0; j < X.Cols; j++ {
+			if j == X.ClassIndex {
+				continue
+			}
+			problemVec[i][problemVecCounter] = X.Get(i, j)
+			problemVecCounter++
+		}
+	}
+	fmt.Println(problemVec, X)
+	return problemVec
+}
+
+func convertInstancesToLabelVec(X *base.Instances) []float64 {
+	labelVec := make([]float64, X.Rows)
+	for i := 0; i < X.Rows; i++ {
+		labelVec[i] = X.Get(i, X.ClassIndex)
+	}
+	return labelVec
+}
+
+func (lr *LogisticRegression) Fit(X *base.Instances) {
+	problemVec := convertInstancesToProblemVec(X)
+	labelVec := convertInstancesToLabelVec(X)
+	prob := NewProblem(problemVec, labelVec, 0)
 	lr.model = Train(prob, lr.param)
 }

-func (lr *LogisticRegression) Predict(X [][]float64) []float64 {
-	n_samples := len(X)
-	y := make([]float64, n_samples)
-	for i, x := range X {
-		y[i] = Predict(lr.model, x)
+func (lr *LogisticRegression) Predict(X *base.Instances) *base.Instances {
+	ret := X.GeneratePredictionVector()
+	row := make([]float64, X.Cols-1)
+	for i := 0; i < X.Rows; i++ {
+		rowCounter := 0
+		for j := 0; j < X.Cols; j++ {
+			if j != X.ClassIndex {
+				row[rowCounter] = X.Get(i, j)
+				rowCounter++
+			}
+		}
+		fmt.Println(Predict(lr.model, row), row)
+		ret.Set(i, 0, Predict(lr.model, row))
 	}
-	return y
+	return ret
 }
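The heart of the change is the pair of helpers above: they flatten a *base.Instances back into the [][]float64 feature matrix and []float64 label slice that NewProblem expects, skipping the class column. As a rough standalone sketch (not golearn code, and not part of this commit), the same split applied to the rows of train.csv added below looks like this, assuming the class value sits in the last column:

// Standalone sketch: mirrors the column-skipping split performed by
// convertInstancesToProblemVec and convertInstancesToLabelVec on the
// rows that train.csv contains, with the class label in the last column.
package main

import "fmt"

func main() {
	// Each row: four feature values followed by the class label.
	raw := [][]float64{
		{0, 0, 0, 1, -1},
		{0, 0, 1, 0, -1},
		{0, 1, 0, 0, 1},
		{1, 0, 0, 0, 1},
	}
	classIndex := len(raw[0]) - 1

	features := make([][]float64, len(raw))
	labels := make([]float64, len(raw))
	for i, row := range raw {
		labels[i] = row[classIndex]
		for j, v := range row {
			if j == classIndex {
				continue
			}
			features[i] = append(features[i], v)
		}
	}
	// features and labels now match the X and y literals the old test hard-coded.
	fmt.Println(features, labels)
}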

linear_models/test.csv Normal file

@@ -0,0 +1,2 @@
+1.0,1.0,0.0,0.0,1.0
+0.0,0.0,1.0,1.0,-1.0

linear_models/train.csv Normal file

@@ -0,0 +1,4 @@
+0.0, 0.0, 0.0, 1.0, -1.0
+0.0, 0.0, 1.0, 0.0, -1.0
+0.0, 1.0, 0.0, 0.0, 1.0
+1.0, 0.0, 0.0, 0.0, 1.0
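Taken together, the new CSV fixtures and the Instances-based signatures give the package a CSV-in, Instances-out workflow. A hedged usage sketch, based only on the calls visible in this diff (base.ParseCSVToInstances, NewLogisticRegression, Fit, Predict, Get); the import paths and behaviour outside this commit are assumptions:

package main

import (
	"fmt"

	"github.com/sjwhitworth/golearn/base"
	"github.com/sjwhitworth/golearn/linear_models"
)

func main() {
	// Load the training and test sets the same way the updated test does;
	// false is taken to mean the CSVs have no header row.
	train, err := base.ParseCSVToInstances("train.csv", false)
	if err != nil {
		panic(err)
	}
	test, err := base.ParseCSVToInstances("test.csv", false)
	if err != nil {
		panic(err)
	}

	// Same hyperparameters as the test: L2 penalty, C = 1.0, eps = 1e-6.
	lr := linear_models.NewLogisticRegression("l2", 1.0, 1e-6)
	lr.Fit(train)

	// Predict now returns a *base.Instances with one predicted label per row.
	Z := lr.Predict(test)
	for i := 0; i < test.Rows; i++ {
		fmt.Println(Z.Get(i, 0))
	}
}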