mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
linear_models: More idiomatic
This commit is contained in:
parent
521844cbb2
commit
477438d972
@ -1,34 +1,30 @@
|
||||
package linear_models
|
||||
package linear_models
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLogisticRegression(t *testing.T) {
|
||||
Convey("Given labels, a classifier and data", t, func() {
|
||||
X := [][]float64{
|
||||
{0, 0, 0, 1},
|
||||
{0, 0, 1, 0},
|
||||
{0, 1, 0, 0},
|
||||
{1, 0, 0, 0},
|
||||
}
|
||||
y := []float64{-1, -1, 1, 1}
|
||||
X, err := base.ParseCSVToInstances("train.csv", false)
|
||||
So(err, ShouldEqual, nil)
|
||||
Y, err := base.ParseCSVToInstances("test.csv", false)
|
||||
So(err, ShouldEqual, nil)
|
||||
lr := NewLogisticRegression("l2", 1.0, 1e-6)
|
||||
lr.Fit(X,y)
|
||||
lr.Fit(X)
|
||||
|
||||
Convey("When predicting the label of first vector", func() {
|
||||
pred_x := [][]float64{ {1,1,0,0} }
|
||||
pred_y := lr.Predict(pred_x)
|
||||
Z := lr.Predict(Y)
|
||||
Convey("The result should be 1", func() {
|
||||
So(pred_y[0], ShouldEqual, 1.0)
|
||||
So(Z.Get(0, 0), ShouldEqual, 1.0)
|
||||
})
|
||||
})
|
||||
Convey("When predicting the label of second vector", func() {
|
||||
pred_x := [][]float64{ {0,0,1,1} }
|
||||
pred_y := lr.Predict(pred_x)
|
||||
Z := lr.Predict(Y)
|
||||
Convey("The result should be -1", func() {
|
||||
So(pred_y[0], ShouldEqual, -1.0)
|
||||
So(Z.Get(1, 0), ShouldEqual, -1.0)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
@ -1,6 +1,9 @@
|
||||
package linear_models
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
base "github.com/sjwhitworth/golearn/base"
|
||||
)
|
||||
|
||||
type LogisticRegression struct {
|
||||
param *Parameter
|
||||
@ -24,16 +27,51 @@ func NewLogisticRegression(penalty string, C float64, eps float64) *LogisticRegr
|
||||
return &lr
|
||||
}
|
||||
|
||||
func (lr *LogisticRegression) Fit(X [][]float64, y []float64) {
|
||||
prob := NewProblem(X, y, 0)
|
||||
func convertInstancesToProblemVec(X *base.Instances) [][]float64 {
|
||||
problemVec := make([][]float64, X.Rows)
|
||||
for i := 0; i < X.Rows; i++ {
|
||||
problemVecCounter := 0
|
||||
problemVec[i] = make([]float64, X.Cols-1)
|
||||
for j := 0; j < X.Cols; j++ {
|
||||
if j == X.ClassIndex {
|
||||
continue
|
||||
}
|
||||
problemVec[i][problemVecCounter] = X.Get(i, j)
|
||||
problemVecCounter++
|
||||
}
|
||||
}
|
||||
fmt.Println(problemVec, X)
|
||||
return problemVec
|
||||
}
|
||||
|
||||
func convertInstancesToLabelVec(X *base.Instances) []float64 {
|
||||
labelVec := make([]float64, X.Rows)
|
||||
for i := 0; i < X.Rows; i++ {
|
||||
labelVec[i] = X.Get(i, X.ClassIndex)
|
||||
}
|
||||
return labelVec
|
||||
}
|
||||
|
||||
func (lr *LogisticRegression) Fit(X *base.Instances) {
|
||||
problemVec := convertInstancesToProblemVec(X)
|
||||
labelVec := convertInstancesToLabelVec(X)
|
||||
prob := NewProblem(problemVec, labelVec, 0)
|
||||
lr.model = Train(prob, lr.param)
|
||||
}
|
||||
|
||||
func (lr *LogisticRegression) Predict(X [][]float64) []float64 {
|
||||
n_samples := len(X)
|
||||
y := make([]float64, n_samples)
|
||||
for i, x := range X {
|
||||
y[i] = Predict(lr.model, x)
|
||||
func (lr *LogisticRegression) Predict(X *base.Instances) *base.Instances {
|
||||
ret := X.GeneratePredictionVector()
|
||||
row := make([]float64, X.Cols-1)
|
||||
for i := 0; i < X.Rows; i++ {
|
||||
rowCounter := 0
|
||||
for j := 0; j < X.Cols; j++ {
|
||||
if j != X.ClassIndex {
|
||||
row[rowCounter] = X.Get(i, j)
|
||||
rowCounter++
|
||||
}
|
||||
}
|
||||
fmt.Println(Predict(lr.model, row), row)
|
||||
ret.Set(i, 0, Predict(lr.model, row))
|
||||
}
|
||||
return y
|
||||
return ret
|
||||
}
|
||||
|
2
linear_models/test.csv
Normal file
2
linear_models/test.csv
Normal file
@ -0,0 +1,2 @@
|
||||
1.0,1.0,0.0,0.0,1.0
|
||||
0.0,0.0,1.0,1.0,-1.0
|
|
4
linear_models/train.csv
Normal file
4
linear_models/train.csv
Normal file
@ -0,0 +1,4 @@
|
||||
0.0, 0.0, 0.0, 1.0, -1.0
|
||||
0.0, 0.0, 1.0, 0.0, -1.0
|
||||
0.0, 1.0, 0.0, 0.0, 1.0
|
||||
1.0, 0.0, 0.0, 0.0, 1.0
|
|
Loading…
x
Reference in New Issue
Block a user