// Mirror of https://github.com/sjwhitworth/golearn.git (synced 2025-04-26 13:49:14 +08:00).
package meta
|
|
|
|
import (
|
|
"fmt"
|
|
"github.com/sjwhitworth/golearn/base"
|
|
"github.com/sjwhitworth/golearn/evaluation"
|
|
"github.com/sjwhitworth/golearn/linear_models"
|
|
. "github.com/smartystreets/goconvey/convey"
|
|
"io/ioutil"
|
|
"os"
|
|
"testing"
|
|
)
|
|
|
|
func TestOneVsAllModel(t *testing.T) {
|
|
|
|
classifierFunc := func(c string) base.Classifier {
|
|
m, err := linear_models.NewLinearSVC("l1", "l2", true, 1.0, 1e-4)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return m
|
|
}
|
|
|
|
Convey("Given data", t, func() {
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
|
So(err, ShouldBeNil)
|
|
|
|
X, Y := base.InstancesTrainTestSplit(inst, 0.4)
|
|
|
|
m := NewOneVsAllModel(classifierFunc)
|
|
m.Fit(X)
|
|
|
|
Convey("The maximum class index should be 2", func() {
|
|
So(m.maxClassVal, ShouldEqual, 2)
|
|
})
|
|
|
|
Convey("There should be three of everything...", func() {
|
|
So(len(m.filters), ShouldEqual, 3)
|
|
So(len(m.classifiers), ShouldEqual, 3)
|
|
})
|
|
|
|
Convey("Predictions should work...", func() {
|
|
predictions, err := m.Predict(Y)
|
|
So(err, ShouldEqual, nil)
|
|
cf, err := evaluation.GetConfusionMatrix(Y, predictions)
|
|
So(err, ShouldEqual, nil)
|
|
fmt.Println(evaluation.GetAccuracy(cf))
|
|
fmt.Println(evaluation.GetSummary(cf))
|
|
})
|
|
|
|
Convey("Saving and reloading should work...", func() {
|
|
predictions, err := m.Predict(Y)
|
|
So(err, ShouldEqual, nil)
|
|
f, err := ioutil.TempFile(os.TempDir(), "tmpCls")
|
|
defer func() {
|
|
f.Close()
|
|
}()
|
|
err = m.Save(f.Name())
|
|
So(err, ShouldBeNil)
|
|
Convey("Reloaded classifier should output the same predictions", func() {
|
|
m := NewOneVsAllModel(classifierFunc)
|
|
err := m.Load(f.Name())
|
|
So(err, ShouldBeNil)
|
|
newPredictions, err := m.Predict(Y)
|
|
So(err, ShouldBeNil)
|
|
So(base.InstancesAreEqual(predictions, newPredictions), ShouldBeTrue)
|
|
})
|
|
})
|
|
})
|
|
}
|