1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00
golearn/meta/one_v_all_test.go
2014-10-30 23:28:26 +00:00

50 lines
1.2 KiB
Go

package meta
import (
"fmt"
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/evaluation"
"github.com/sjwhitworth/golearn/linear_models"
. "github.com/smartystreets/goconvey/convey"
"testing"
)
// TestOneVsAllModel fits a OneVsAllModel built from LinearSVC base
// classifiers on a random train/test split of the iris dataset, then checks
// the per-class state the model builds and that it can predict on the
// held-out partition.
func TestOneVsAllModel(t *testing.T) {
	// newBaseClassifier is the factory the one-vs-all model calls to create
	// each base classifier. A construction error here is a bug in the test
	// setup, so it panics; goconvey reports panics inside a Convey block as
	// failures.
	newBaseClassifier := func(_ string) base.Classifier {
		svc, err := linear_models.NewLinearSVC("l1", "l2", true, 1.0, 1e-4)
		if err != nil {
			panic(err)
		}
		return svc
	}

	Convey("Given data", t, func() {
		inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldBeNil)

		// NOTE(review): presumably 0.4 is the fraction held out for testing —
		// confirm against base.InstancesTrainTestSplit.
		trainData, testData := base.InstancesTrainTestSplit(inst, 0.4)

		// goconvey re-runs this setup for every nested Convey below, so each
		// leaf gets its own split and freshly fitted model.
		m := NewOneVsAllModel(newBaseClassifier)
		m.Fit(trainData)

		Convey("The maximum class index should be 2", func() {
			So(m.maxClassVal, ShouldEqual, 2)
		})

		Convey("There should be three of everything...", func() {
			// One filter and one classifier per class value.
			So(len(m.filters), ShouldEqual, 3)
			So(len(m.classifiers), ShouldEqual, 3)
		})

		Convey("Predictions should work...", func() {
			predictions, err := m.Predict(testData)
			So(err, ShouldBeNil)

			cf, err := evaluation.GetConfusionMatrix(testData, predictions)
			So(err, ShouldBeNil)

			// Informational output only; accuracy depends on the random split,
			// so it is not asserted against a threshold.
			fmt.Println(evaluation.GetAccuracy(cf))
			fmt.Println(evaluation.GetSummary(cf))
		})
	})
}