mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
C0
This commit is contained in:
parent
6338cc3656
commit
77b8faa22c
@ -264,7 +264,6 @@ func (lr *LinearSVC) SaveWithPrefix(writer *base.ClassifierSerializer, prefix st
|
|||||||
defer func() {
|
defer func() {
|
||||||
f.Close()
|
f.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = Export(lr.model, f.Name())
|
err = Export(lr.model, f.Name())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return base.DescribeError("Error exporting model", err)
|
return base.DescribeError("Error exporting model", err)
|
||||||
|
57
linear_models/linearsvc_test.go
Normal file
57
linear_models/linearsvc_test.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
package linear_models
|
||||||
|
|
||||||
|
import (
|
||||||
|
//"github.com/sjwhitworth/golearn/base"
|
||||||
|
. "github.com/smartystreets/goconvey/convey"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLinearSVC(t *testing.T) {
|
||||||
|
Convey("Doing a LinearSVC test", t, func(){
|
||||||
|
_, err := NewLinearSVC("l1", "l1", false, 1.0, -1e6)
|
||||||
|
So(err, ShouldNotBeNil)
|
||||||
|
_, err = NewLinearSVC("l0", "l1", false, 1.0, -1e6)
|
||||||
|
So(err, ShouldNotBeNil)
|
||||||
|
_, err = NewLinearSVC("l1", "l0", false, 1.0, -1e6)
|
||||||
|
So(err, ShouldNotBeNil)
|
||||||
|
_, err = NewLinearSVC("l1", "l2", false, 1.0, -1e6)
|
||||||
|
So(err, ShouldNotBeNil)
|
||||||
|
SVC, err := NewLinearSVC("l1", "l2", true, 1.0, -1e6)
|
||||||
|
So(SVC, ShouldNotBeNil)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
_, err = NewLinearSVC("l2", "l2", false, 1.0, -1e6)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
_, err = NewLinearSVC("l2", "l2", true, 1.0, -1e6)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
_, err = NewLinearSVC("l2", "l1", false, 1.0, -1e6)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
_, err = NewLinearSVC("l2", "l1", true, 1.0, -1e6)
|
||||||
|
So(err, ShouldNotBeNil)
|
||||||
|
So(func(){ SVC.GetMetadata() } , ShouldNotPanic)
|
||||||
|
//err = SVC.Save("tmp")
|
||||||
|
|
||||||
|
params := &LinearSVCParams{0, []float64{0.0}, 1.0, -1e6, false, false}
|
||||||
|
params = params.Copy()
|
||||||
|
So(params, ShouldNotBeNil)
|
||||||
|
|
||||||
|
g := [][]float64{{1, 2}, {1, 2}, {1, 2}}
|
||||||
|
v := []float64{1, 2}
|
||||||
|
var bias float64
|
||||||
|
problem := NewProblem(g[:], v[:], bias)
|
||||||
|
param := NewParameter(0, 1.0, -1e6)
|
||||||
|
model := Train(problem, param)
|
||||||
|
So(model, ShouldNotBeNil)
|
||||||
|
err = Export(model, "tmp")
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
err = Load(model, "tmp")
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
|
||||||
|
SVC.model = model
|
||||||
|
err = SVC.Save("tmp")
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
err = SVC.Load("tmp")
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
s := SVC.String()
|
||||||
|
So(s, ShouldEqual, "LogisticSVC")
|
||||||
|
})
|
||||||
|
}
|
38
linear_models/logistic_test.go
Normal file
38
linear_models/logistic_test.go
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
package linear_models
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/sjwhitworth/golearn/base"
|
||||||
|
. "github.com/smartystreets/goconvey/convey"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLogistic(t *testing.T) {
|
||||||
|
Convey("Doing a logistic test", t, func(){
|
||||||
|
X, err := base.ParseCSVToInstances("train.csv", false)
|
||||||
|
So(err, ShouldEqual, nil)
|
||||||
|
Y, err := base.ParseCSVToInstances("test.csv", false)
|
||||||
|
So(err, ShouldEqual, nil)
|
||||||
|
_, err = NewLogisticRegression("l0", 1.0, 1e-6)
|
||||||
|
So(err, ShouldNotBeNil)
|
||||||
|
lr, err := NewLogisticRegression("l1", 1.0, 1e-6)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
|
||||||
|
lr.Fit(X)
|
||||||
|
|
||||||
|
Convey("When predicting the label of first vector", func() {
|
||||||
|
Z, err := lr.Predict(Y)
|
||||||
|
So(err, ShouldEqual, nil)
|
||||||
|
Convey("The result should be 1", func() {
|
||||||
|
So(Z.RowString(0), ShouldEqual, "-1.0")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
Convey("When predicting the label of second vector", func() {
|
||||||
|
Z, err := lr.Predict(Y)
|
||||||
|
So(err, ShouldEqual, nil)
|
||||||
|
Convey("The result should be -1", func() {
|
||||||
|
So(Z.RowString(1), ShouldEqual, "-1.0")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
So((*lr).String(), ShouldEqual, "LogisticRegression")
|
||||||
|
})
|
||||||
|
}
|
BIN
linear_models/tmp
Normal file
BIN
linear_models/tmp
Normal file
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user