From 77b8faa22cac573b10377bcfcb87d9fb7c410504 Mon Sep 17 00:00:00 2001 From: yenck Date: Fri, 8 Jun 2018 04:46:27 -0700 Subject: [PATCH] C0 --- linear_models/linearsvc.go | 1 - linear_models/linearsvc_test.go | 57 ++++++++++++++++++++++++++++++++ linear_models/logistic_test.go | 38 +++++++++++++++++++++ linear_models/tmp | Bin 0 -> 484 bytes 4 files changed, 95 insertions(+), 1 deletion(-) create mode 100644 linear_models/linearsvc_test.go create mode 100644 linear_models/logistic_test.go create mode 100644 linear_models/tmp diff --git a/linear_models/linearsvc.go b/linear_models/linearsvc.go index 5016218..f059e83 100644 --- a/linear_models/linearsvc.go +++ b/linear_models/linearsvc.go @@ -264,7 +264,6 @@ func (lr *LinearSVC) SaveWithPrefix(writer *base.ClassifierSerializer, prefix st defer func() { f.Close() }() - err = Export(lr.model, f.Name()) if err != nil { return base.DescribeError("Error exporting model", err) diff --git a/linear_models/linearsvc_test.go b/linear_models/linearsvc_test.go new file mode 100644 index 0000000..a76825a --- /dev/null +++ b/linear_models/linearsvc_test.go @@ -0,0 +1,57 @@ +package linear_models + +import ( + //"github.com/sjwhitworth/golearn/base" + . "github.com/smartystreets/goconvey/convey" + "testing" +) + +func TestLinearSVC(t *testing.T) { + Convey("Doing a LinearSVC test", t, func(){ + _, err := NewLinearSVC("l1", "l1", false, 1.0, -1e6) + So(err, ShouldNotBeNil) + _, err = NewLinearSVC("l0", "l1", false, 1.0, -1e6) + So(err, ShouldNotBeNil) + _, err = NewLinearSVC("l1", "l0", false, 1.0, -1e6) + So(err, ShouldNotBeNil) + _, err = NewLinearSVC("l1", "l2", false, 1.0, -1e6) + So(err, ShouldNotBeNil) + SVC, err := NewLinearSVC("l1", "l2", true, 1.0, -1e6) + So(SVC, ShouldNotBeNil) + So(err, ShouldBeNil) + _, err = NewLinearSVC("l2", "l2", false, 1.0, -1e6) + So(err, ShouldBeNil) + _, err = NewLinearSVC("l2", "l2", true, 1.0, -1e6) + So(err, ShouldBeNil) + _, err = NewLinearSVC("l2", "l1", false, 1.0, -1e6) + So(err, ShouldBeNil) + _, err = NewLinearSVC("l2", "l1", true, 1.0, -1e6) + So(err, ShouldNotBeNil) + So(func(){ SVC.GetMetadata() } , ShouldNotPanic) + //err = SVC.Save("tmp") + + params := &LinearSVCParams{0, []float64{0.0}, 1.0, -1e6, false, false} + params = params.Copy() + So(params, ShouldNotBeNil) + + g := [][]float64{{1, 2}, {1, 2}, {1, 2}} + v := []float64{1, 2} + var bias float64 + problem := NewProblem(g[:], v[:], bias) + param := NewParameter(0, 1.0, -1e6) + model := Train(problem, param) + So(model, ShouldNotBeNil) + err = Export(model, "tmp") + So(err, ShouldBeNil) + err = Load(model, "tmp") + So(err, ShouldBeNil) + + SVC.model = model + err = SVC.Save("tmp") + So(err, ShouldBeNil) + err = SVC.Load("tmp") + So(err, ShouldBeNil) + s := SVC.String() + So(s, ShouldEqual, "LogisticSVC") + }) +} diff --git a/linear_models/logistic_test.go b/linear_models/logistic_test.go new file mode 100644 index 0000000..1656fed --- /dev/null +++ b/linear_models/logistic_test.go @@ -0,0 +1,38 @@ +package linear_models + +import ( + "github.com/sjwhitworth/golearn/base" + . "github.com/smartystreets/goconvey/convey" + "testing" +) + +func TestLogistic(t *testing.T) { + Convey("Doing a logistic test", t, func(){ + X, err := base.ParseCSVToInstances("train.csv", false) + So(err, ShouldEqual, nil) + Y, err := base.ParseCSVToInstances("test.csv", false) + So(err, ShouldEqual, nil) + _, err = NewLogisticRegression("l0", 1.0, 1e-6) + So(err, ShouldNotBeNil) + lr, err := NewLogisticRegression("l1", 1.0, 1e-6) + So(err, ShouldBeNil) + + lr.Fit(X) + + Convey("When predicting the label of first vector", func() { + Z, err := lr.Predict(Y) + So(err, ShouldEqual, nil) + Convey("The result should be 1", func() { + So(Z.RowString(0), ShouldEqual, "-1.0") + }) + }) + Convey("When predicting the label of second vector", func() { + Z, err := lr.Predict(Y) + So(err, ShouldEqual, nil) + Convey("The result should be -1", func() { + So(Z.RowString(1), ShouldEqual, "-1.0") + }) + }) + So((*lr).String(), ShouldEqual, "LogisticRegression") + }) +} diff --git a/linear_models/tmp b/linear_models/tmp new file mode 100644 index 0000000000000000000000000000000000000000..4560a48d7ba89141002965bdd54fed1b020adaaf GIT binary patch literal 484 zcmV`}MMWHrvPG zp_|THfkj6aI=GGnvYm(ir-sQ1ln&Hr=()#$kmTV93aVzgXJFk2$$ z@zESQNC5x<0RR8;^>=agp|Kr=>whEE{BK}xW@b8?|LFmW!MQ)aq_QAY!N({l-Y1AF zuP7c|n=2S|Gh<|r5{7%3Qlgws+JOG=AU6^yx(G82mxbPc)6xpWQm3=PbUj7`i; zEi4TUEG>;q6^IoBX*M!6F)_0=HZ->|GPJZXH&h^2jLSd)1i1z}1cx>N00030{~E=k aU=)m800000|Nj910RR7Vw`K|e6aWAL0^^JT literal 0 HcmV?d00001