From 477438d9721de7f1cf9b3ca0aeb8fb173802e390 Mon Sep 17 00:00:00 2001 From: Richard Townsend Date: Wed, 2 Jul 2014 15:48:35 +0100 Subject: [PATCH] linear_models: More idiomatic --- linear_models/linear_models_test.go | 28 +++++++-------- linear_models/logistic.go | 56 ++++++++++++++++++++++++----- linear_models/test.csv | 2 ++ linear_models/train.csv | 4 +++ 4 files changed, 65 insertions(+), 25 deletions(-) create mode 100644 linear_models/test.csv create mode 100644 linear_models/train.csv diff --git a/linear_models/linear_models_test.go b/linear_models/linear_models_test.go index c3e2191..139db7a 100644 --- a/linear_models/linear_models_test.go +++ b/linear_models/linear_models_test.go @@ -1,34 +1,30 @@ -package linear_models +package linear_models import ( - "testing" + "github.com/sjwhitworth/golearn/base" . "github.com/smartystreets/goconvey/convey" + "testing" ) func TestLogisticRegression(t *testing.T) { Convey("Given labels, a classifier and data", t, func() { - X := [][]float64{ - {0, 0, 0, 1}, - {0, 0, 1, 0}, - {0, 1, 0, 0}, - {1, 0, 0, 0}, - } - y := []float64{-1, -1, 1, 1} + X, err := base.ParseCSVToInstances("train.csv", false) + So(err, ShouldEqual, nil) + Y, err := base.ParseCSVToInstances("test.csv", false) + So(err, ShouldEqual, nil) lr := NewLogisticRegression("l2", 1.0, 1e-6) - lr.Fit(X,y) + lr.Fit(X) Convey("When predicting the label of first vector", func() { - pred_x := [][]float64{ {1,1,0,0} } - pred_y := lr.Predict(pred_x) + Z := lr.Predict(Y) Convey("The result should be 1", func() { - So(pred_y[0], ShouldEqual, 1.0) + So(Z.Get(0, 0), ShouldEqual, 1.0) }) }) Convey("When predicting the label of second vector", func() { - pred_x := [][]float64{ {0,0,1,1} } - pred_y := lr.Predict(pred_x) + Z := lr.Predict(Y) Convey("The result should be -1", func() { - So(pred_y[0], ShouldEqual, -1.0) + So(Z.Get(1, 0), ShouldEqual, -1.0) }) }) }) diff --git a/linear_models/logistic.go b/linear_models/logistic.go index 07d19ec..324c678 100644 --- a/linear_models/logistic.go +++ b/linear_models/logistic.go @@ -1,6 +1,9 @@ package linear_models -import "fmt" +import ( + "fmt" + base "github.com/sjwhitworth/golearn/base" +) type LogisticRegression struct { param *Parameter @@ -24,16 +27,51 @@ func NewLogisticRegression(penalty string, C float64, eps float64) *LogisticRegr return &lr } -func (lr *LogisticRegression) Fit(X [][]float64, y []float64) { - prob := NewProblem(X, y, 0) +func convertInstancesToProblemVec(X *base.Instances) [][]float64 { + problemVec := make([][]float64, X.Rows) + for i := 0; i < X.Rows; i++ { + problemVecCounter := 0 + problemVec[i] = make([]float64, X.Cols-1) + for j := 0; j < X.Cols; j++ { + if j == X.ClassIndex { + continue + } + problemVec[i][problemVecCounter] = X.Get(i, j) + problemVecCounter++ + } + } + fmt.Println(problemVec, X) + return problemVec +} + +func convertInstancesToLabelVec(X *base.Instances) []float64 { + labelVec := make([]float64, X.Rows) + for i := 0; i < X.Rows; i++ { + labelVec[i] = X.Get(i, X.ClassIndex) + } + return labelVec +} + +func (lr *LogisticRegression) Fit(X *base.Instances) { + problemVec := convertInstancesToProblemVec(X) + labelVec := convertInstancesToLabelVec(X) + prob := NewProblem(problemVec, labelVec, 0) lr.model = Train(prob, lr.param) } -func (lr *LogisticRegression) Predict(X [][]float64) []float64 { - n_samples := len(X) - y := make([]float64, n_samples) - for i, x := range X { - y[i] = Predict(lr.model, x) +func (lr *LogisticRegression) Predict(X *base.Instances) *base.Instances { + ret := X.GeneratePredictionVector() + row := make([]float64, X.Cols-1) + for i := 0; i < X.Rows; i++ { + rowCounter := 0 + for j := 0; j < X.Cols; j++ { + if j != X.ClassIndex { + row[rowCounter] = X.Get(i, j) + rowCounter++ + } + } + fmt.Println(Predict(lr.model, row), row) + ret.Set(i, 0, Predict(lr.model, row)) } - return y + return ret } diff --git a/linear_models/test.csv b/linear_models/test.csv new file mode 100644 index 0000000..d2ac5e5 --- /dev/null +++ b/linear_models/test.csv @@ -0,0 +1,2 @@ +1.0,1.0,0.0,0.0,1.0 +0.0,0.0,1.0,1.0,-1.0 \ No newline at end of file diff --git a/linear_models/train.csv b/linear_models/train.csv new file mode 100644 index 0000000..6c0a1f1 --- /dev/null +++ b/linear_models/train.csv @@ -0,0 +1,4 @@ +0.0, 0.0, 0.0, 1.0, -1.0 +0.0, 0.0, 1.0, 0.0, -1.0 +0.0, 1.0, 0.0, 0.0, 1.0 +1.0, 0.0, 0.0, 0.0, 1.0 \ No newline at end of file