mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
C0
This commit is contained in:
parent
6338cc3656
commit
77b8faa22c
@ -264,7 +264,6 @@ func (lr *LinearSVC) SaveWithPrefix(writer *base.ClassifierSerializer, prefix st
|
||||
defer func() {
|
||||
f.Close()
|
||||
}()
|
||||
|
||||
err = Export(lr.model, f.Name())
|
||||
if err != nil {
|
||||
return base.DescribeError("Error exporting model", err)
|
||||
|
57
linear_models/linearsvc_test.go
Normal file
57
linear_models/linearsvc_test.go
Normal file
@ -0,0 +1,57 @@
|
||||
package linear_models
|
||||
|
||||
import (
|
||||
//"github.com/sjwhitworth/golearn/base"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLinearSVC(t *testing.T) {
|
||||
Convey("Doing a LinearSVC test", t, func(){
|
||||
_, err := NewLinearSVC("l1", "l1", false, 1.0, -1e6)
|
||||
So(err, ShouldNotBeNil)
|
||||
_, err = NewLinearSVC("l0", "l1", false, 1.0, -1e6)
|
||||
So(err, ShouldNotBeNil)
|
||||
_, err = NewLinearSVC("l1", "l0", false, 1.0, -1e6)
|
||||
So(err, ShouldNotBeNil)
|
||||
_, err = NewLinearSVC("l1", "l2", false, 1.0, -1e6)
|
||||
So(err, ShouldNotBeNil)
|
||||
SVC, err := NewLinearSVC("l1", "l2", true, 1.0, -1e6)
|
||||
So(SVC, ShouldNotBeNil)
|
||||
So(err, ShouldBeNil)
|
||||
_, err = NewLinearSVC("l2", "l2", false, 1.0, -1e6)
|
||||
So(err, ShouldBeNil)
|
||||
_, err = NewLinearSVC("l2", "l2", true, 1.0, -1e6)
|
||||
So(err, ShouldBeNil)
|
||||
_, err = NewLinearSVC("l2", "l1", false, 1.0, -1e6)
|
||||
So(err, ShouldBeNil)
|
||||
_, err = NewLinearSVC("l2", "l1", true, 1.0, -1e6)
|
||||
So(err, ShouldNotBeNil)
|
||||
So(func(){ SVC.GetMetadata() } , ShouldNotPanic)
|
||||
//err = SVC.Save("tmp")
|
||||
|
||||
params := &LinearSVCParams{0, []float64{0.0}, 1.0, -1e6, false, false}
|
||||
params = params.Copy()
|
||||
So(params, ShouldNotBeNil)
|
||||
|
||||
g := [][]float64{{1, 2}, {1, 2}, {1, 2}}
|
||||
v := []float64{1, 2}
|
||||
var bias float64
|
||||
problem := NewProblem(g[:], v[:], bias)
|
||||
param := NewParameter(0, 1.0, -1e6)
|
||||
model := Train(problem, param)
|
||||
So(model, ShouldNotBeNil)
|
||||
err = Export(model, "tmp")
|
||||
So(err, ShouldBeNil)
|
||||
err = Load(model, "tmp")
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
SVC.model = model
|
||||
err = SVC.Save("tmp")
|
||||
So(err, ShouldBeNil)
|
||||
err = SVC.Load("tmp")
|
||||
So(err, ShouldBeNil)
|
||||
s := SVC.String()
|
||||
So(s, ShouldEqual, "LogisticSVC")
|
||||
})
|
||||
}
|
38
linear_models/logistic_test.go
Normal file
38
linear_models/logistic_test.go
Normal file
@ -0,0 +1,38 @@
|
||||
package linear_models
|
||||
|
||||
import (
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLogistic(t *testing.T) {
|
||||
Convey("Doing a logistic test", t, func(){
|
||||
X, err := base.ParseCSVToInstances("train.csv", false)
|
||||
So(err, ShouldEqual, nil)
|
||||
Y, err := base.ParseCSVToInstances("test.csv", false)
|
||||
So(err, ShouldEqual, nil)
|
||||
_, err = NewLogisticRegression("l0", 1.0, 1e-6)
|
||||
So(err, ShouldNotBeNil)
|
||||
lr, err := NewLogisticRegression("l1", 1.0, 1e-6)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
lr.Fit(X)
|
||||
|
||||
Convey("When predicting the label of first vector", func() {
|
||||
Z, err := lr.Predict(Y)
|
||||
So(err, ShouldEqual, nil)
|
||||
Convey("The result should be 1", func() {
|
||||
So(Z.RowString(0), ShouldEqual, "-1.0")
|
||||
})
|
||||
})
|
||||
Convey("When predicting the label of second vector", func() {
|
||||
Z, err := lr.Predict(Y)
|
||||
So(err, ShouldEqual, nil)
|
||||
Convey("The result should be -1", func() {
|
||||
So(Z.RowString(1), ShouldEqual, "-1.0")
|
||||
})
|
||||
})
|
||||
So((*lr).String(), ShouldEqual, "LogisticRegression")
|
||||
})
|
||||
}
|
BIN
linear_models/tmp
Normal file
BIN
linear_models/tmp
Normal file
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user