mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-28 13:48:56 +08:00
linear_models: More idiomatic
This commit is contained in:
parent
521844cbb2
commit
477438d972
@ -1,34 +1,30 @@
|
|||||||
package linear_models
|
package linear_models
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
. "github.com/smartystreets/goconvey/convey"
|
. "github.com/smartystreets/goconvey/convey"
|
||||||
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestLogisticRegression(t *testing.T) {
|
func TestLogisticRegression(t *testing.T) {
|
||||||
Convey("Given labels, a classifier and data", t, func() {
|
Convey("Given labels, a classifier and data", t, func() {
|
||||||
X := [][]float64{
|
X, err := base.ParseCSVToInstances("train.csv", false)
|
||||||
{0, 0, 0, 1},
|
So(err, ShouldEqual, nil)
|
||||||
{0, 0, 1, 0},
|
Y, err := base.ParseCSVToInstances("test.csv", false)
|
||||||
{0, 1, 0, 0},
|
So(err, ShouldEqual, nil)
|
||||||
{1, 0, 0, 0},
|
|
||||||
}
|
|
||||||
y := []float64{-1, -1, 1, 1}
|
|
||||||
lr := NewLogisticRegression("l2", 1.0, 1e-6)
|
lr := NewLogisticRegression("l2", 1.0, 1e-6)
|
||||||
lr.Fit(X,y)
|
lr.Fit(X)
|
||||||
|
|
||||||
Convey("When predicting the label of first vector", func() {
|
Convey("When predicting the label of first vector", func() {
|
||||||
pred_x := [][]float64{ {1,1,0,0} }
|
Z := lr.Predict(Y)
|
||||||
pred_y := lr.Predict(pred_x)
|
|
||||||
Convey("The result should be 1", func() {
|
Convey("The result should be 1", func() {
|
||||||
So(pred_y[0], ShouldEqual, 1.0)
|
So(Z.Get(0, 0), ShouldEqual, 1.0)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
Convey("When predicting the label of second vector", func() {
|
Convey("When predicting the label of second vector", func() {
|
||||||
pred_x := [][]float64{ {0,0,1,1} }
|
Z := lr.Predict(Y)
|
||||||
pred_y := lr.Predict(pred_x)
|
|
||||||
Convey("The result should be -1", func() {
|
Convey("The result should be -1", func() {
|
||||||
So(pred_y[0], ShouldEqual, -1.0)
|
So(Z.Get(1, 0), ShouldEqual, -1.0)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
package linear_models
|
package linear_models
|
||||||
|
|
||||||
import "fmt"
|
import (
|
||||||
|
"fmt"
|
||||||
|
base "github.com/sjwhitworth/golearn/base"
|
||||||
|
)
|
||||||
|
|
||||||
type LogisticRegression struct {
|
type LogisticRegression struct {
|
||||||
param *Parameter
|
param *Parameter
|
||||||
@ -24,16 +27,51 @@ func NewLogisticRegression(penalty string, C float64, eps float64) *LogisticRegr
|
|||||||
return &lr
|
return &lr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lr *LogisticRegression) Fit(X [][]float64, y []float64) {
|
func convertInstancesToProblemVec(X *base.Instances) [][]float64 {
|
||||||
prob := NewProblem(X, y, 0)
|
problemVec := make([][]float64, X.Rows)
|
||||||
|
for i := 0; i < X.Rows; i++ {
|
||||||
|
problemVecCounter := 0
|
||||||
|
problemVec[i] = make([]float64, X.Cols-1)
|
||||||
|
for j := 0; j < X.Cols; j++ {
|
||||||
|
if j == X.ClassIndex {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
problemVec[i][problemVecCounter] = X.Get(i, j)
|
||||||
|
problemVecCounter++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Println(problemVec, X)
|
||||||
|
return problemVec
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertInstancesToLabelVec(X *base.Instances) []float64 {
|
||||||
|
labelVec := make([]float64, X.Rows)
|
||||||
|
for i := 0; i < X.Rows; i++ {
|
||||||
|
labelVec[i] = X.Get(i, X.ClassIndex)
|
||||||
|
}
|
||||||
|
return labelVec
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lr *LogisticRegression) Fit(X *base.Instances) {
|
||||||
|
problemVec := convertInstancesToProblemVec(X)
|
||||||
|
labelVec := convertInstancesToLabelVec(X)
|
||||||
|
prob := NewProblem(problemVec, labelVec, 0)
|
||||||
lr.model = Train(prob, lr.param)
|
lr.model = Train(prob, lr.param)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lr *LogisticRegression) Predict(X [][]float64) []float64 {
|
func (lr *LogisticRegression) Predict(X *base.Instances) *base.Instances {
|
||||||
n_samples := len(X)
|
ret := X.GeneratePredictionVector()
|
||||||
y := make([]float64, n_samples)
|
row := make([]float64, X.Cols-1)
|
||||||
for i, x := range X {
|
for i := 0; i < X.Rows; i++ {
|
||||||
y[i] = Predict(lr.model, x)
|
rowCounter := 0
|
||||||
|
for j := 0; j < X.Cols; j++ {
|
||||||
|
if j != X.ClassIndex {
|
||||||
|
row[rowCounter] = X.Get(i, j)
|
||||||
|
rowCounter++
|
||||||
}
|
}
|
||||||
return y
|
}
|
||||||
|
fmt.Println(Predict(lr.model, row), row)
|
||||||
|
ret.Set(i, 0, Predict(lr.model, row))
|
||||||
|
}
|
||||||
|
return ret
|
||||||
}
|
}
|
||||||
|
2
linear_models/test.csv
Normal file
2
linear_models/test.csv
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
1.0,1.0,0.0,0.0,1.0
|
||||||
|
0.0,0.0,1.0,1.0,-1.0
|
|
4
linear_models/train.csv
Normal file
4
linear_models/train.csv
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
0.0, 0.0, 0.0, 1.0, -1.0
|
||||||
|
0.0, 0.0, 1.0, 0.0, -1.0
|
||||||
|
0.0, 1.0, 0.0, 0.0, 1.0
|
||||||
|
1.0, 0.0, 0.0, 0.0, 1.0
|
|
Loading…
x
Reference in New Issue
Block a user