1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00
This commit is contained in:
yenck 2018-06-08 04:46:27 -07:00 committed by ss8651twtw
parent 6338cc3656
commit 77b8faa22c
4 changed files with 95 additions and 1 deletions

View File

@ -264,7 +264,6 @@ func (lr *LinearSVC) SaveWithPrefix(writer *base.ClassifierSerializer, prefix st
defer func() {
f.Close()
}()
err = Export(lr.model, f.Name())
if err != nil {
return base.DescribeError("Error exporting model", err)

View File

@ -0,0 +1,57 @@
package linear_models
import (
//"github.com/sjwhitworth/golearn/base"
. "github.com/smartystreets/goconvey/convey"
"testing"
)
func TestLinearSVC(t *testing.T) {
Convey("Doing a LinearSVC test", t, func(){
_, err := NewLinearSVC("l1", "l1", false, 1.0, -1e6)
So(err, ShouldNotBeNil)
_, err = NewLinearSVC("l0", "l1", false, 1.0, -1e6)
So(err, ShouldNotBeNil)
_, err = NewLinearSVC("l1", "l0", false, 1.0, -1e6)
So(err, ShouldNotBeNil)
_, err = NewLinearSVC("l1", "l2", false, 1.0, -1e6)
So(err, ShouldNotBeNil)
SVC, err := NewLinearSVC("l1", "l2", true, 1.0, -1e6)
So(SVC, ShouldNotBeNil)
So(err, ShouldBeNil)
_, err = NewLinearSVC("l2", "l2", false, 1.0, -1e6)
So(err, ShouldBeNil)
_, err = NewLinearSVC("l2", "l2", true, 1.0, -1e6)
So(err, ShouldBeNil)
_, err = NewLinearSVC("l2", "l1", false, 1.0, -1e6)
So(err, ShouldBeNil)
_, err = NewLinearSVC("l2", "l1", true, 1.0, -1e6)
So(err, ShouldNotBeNil)
So(func(){ SVC.GetMetadata() } , ShouldNotPanic)
//err = SVC.Save("tmp")
params := &LinearSVCParams{0, []float64{0.0}, 1.0, -1e6, false, false}
params = params.Copy()
So(params, ShouldNotBeNil)
g := [][]float64{{1, 2}, {1, 2}, {1, 2}}
v := []float64{1, 2}
var bias float64
problem := NewProblem(g[:], v[:], bias)
param := NewParameter(0, 1.0, -1e6)
model := Train(problem, param)
So(model, ShouldNotBeNil)
err = Export(model, "tmp")
So(err, ShouldBeNil)
err = Load(model, "tmp")
So(err, ShouldBeNil)
SVC.model = model
err = SVC.Save("tmp")
So(err, ShouldBeNil)
err = SVC.Load("tmp")
So(err, ShouldBeNil)
s := SVC.String()
So(s, ShouldEqual, "LogisticSVC")
})
}

View File

@ -0,0 +1,38 @@
package linear_models
import (
"github.com/sjwhitworth/golearn/base"
. "github.com/smartystreets/goconvey/convey"
"testing"
)
func TestLogistic(t *testing.T) {
Convey("Doing a logistic test", t, func(){
X, err := base.ParseCSVToInstances("train.csv", false)
So(err, ShouldEqual, nil)
Y, err := base.ParseCSVToInstances("test.csv", false)
So(err, ShouldEqual, nil)
_, err = NewLogisticRegression("l0", 1.0, 1e-6)
So(err, ShouldNotBeNil)
lr, err := NewLogisticRegression("l1", 1.0, 1e-6)
So(err, ShouldBeNil)
lr.Fit(X)
Convey("When predicting the label of first vector", func() {
Z, err := lr.Predict(Y)
So(err, ShouldEqual, nil)
Convey("The result should be 1", func() {
So(Z.RowString(0), ShouldEqual, "-1.0")
})
})
Convey("When predicting the label of second vector", func() {
Z, err := lr.Predict(Y)
So(err, ShouldEqual, nil)
Convey("The result should be -1", func() {
So(Z.RowString(1), ShouldEqual, "-1.0")
})
})
So((*lr).String(), ShouldEqual, "LogisticRegression")
})
}

BIN
linear_models/tmp Normal file

Binary file not shown.