mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-25 13:48:49 +08:00
106 lines
2.8 KiB
Go
106 lines
2.8 KiB
Go
package ensemble
|
|
|
|
import (
|
|
"github.com/sjwhitworth/golearn/base"
|
|
"github.com/sjwhitworth/golearn/evaluation"
|
|
. "github.com/smartystreets/goconvey/convey"
|
|
"io/ioutil"
|
|
"testing"
|
|
)
|
|
|
|
func TestMultiSVMUnweighted(t *testing.T) {
|
|
Convey("Loading data...", t, func() {
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/articles.csv", false)
|
|
So(err, ShouldBeNil)
|
|
X, Y := base.InstancesTrainTestSplit(inst, 0.4)
|
|
|
|
m := NewMultiLinearSVC("l1", "l2", true, 1.0, 1e-4, nil)
|
|
m.Fit(X)
|
|
|
|
Convey("Predictions should work...", func() {
|
|
predictions, err := m.Predict(Y)
|
|
So(err, ShouldEqual, nil)
|
|
cf, err := evaluation.GetConfusionMatrix(Y, predictions)
|
|
So(err, ShouldEqual, nil)
|
|
So(evaluation.GetAccuracy(cf), ShouldBeGreaterThan, 0.70)
|
|
})
|
|
|
|
Convey("Saving should work...", func() {
|
|
f, err := ioutil.TempFile("", "tree")
|
|
So(err, ShouldBeNil)
|
|
err = m.Save(f.Name())
|
|
So(err, ShouldBeNil)
|
|
|
|
Convey("Loading should work...", func() {
|
|
mLoaded := NewMultiLinearSVC("l1", "l2", true, 1.00, 1e-8, nil)
|
|
err := mLoaded.Load(f.Name())
|
|
So(err, ShouldBeNil)
|
|
|
|
Convey("Predictions should be the same...", func() {
|
|
originalPredictions, err := m.Predict(Y)
|
|
So(err, ShouldBeNil)
|
|
newPredictions, err := mLoaded.Predict(Y)
|
|
So(err, ShouldBeNil)
|
|
So(base.InstancesAreEqual(originalPredictions, newPredictions), ShouldBeTrue)
|
|
})
|
|
|
|
})
|
|
|
|
})
|
|
|
|
})
|
|
}
|
|
|
|
func TestMultiSVMWeighted(t *testing.T) {
|
|
Convey("Loading data...", t, func() {
|
|
weights := make(map[string]float64)
|
|
weights["Finance"] = 0.1739
|
|
weights["Tech"] = 0.0750
|
|
weights["Politics"] = 0.4928
|
|
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/articles.csv", false)
|
|
So(err, ShouldBeNil)
|
|
X, Y := base.InstancesTrainTestSplit(inst, 0.4)
|
|
|
|
m := NewMultiLinearSVC("l1", "l2", true, 0.62, 1e-4, weights)
|
|
m.Fit(X)
|
|
|
|
Convey("Predictions should work...", func() {
|
|
predictions, err := m.Predict(Y)
|
|
So(err, ShouldEqual, nil)
|
|
cf, err := evaluation.GetConfusionMatrix(Y, predictions)
|
|
So(err, ShouldEqual, nil)
|
|
So(evaluation.GetAccuracy(cf), ShouldBeGreaterThan, 0.60)
|
|
|
|
Convey("Saving should work...", func() {
|
|
f, err := ioutil.TempFile("", "tree")
|
|
So(err, ShouldBeNil)
|
|
err = m.Save(f.Name())
|
|
So(err, ShouldBeNil)
|
|
|
|
Convey("Loading should work...", func() {
|
|
mLoaded := NewMultiLinearSVC("l1", "l2", true, 1.00, 1e-8, weights)
|
|
err := mLoaded.Load(f.Name())
|
|
So(err, ShouldBeNil)
|
|
|
|
Convey("Predictions should be the same...", func() {
|
|
originalPredictions, err := m.Predict(Y)
|
|
So(err, ShouldBeNil)
|
|
newPredictions, err := mLoaded.Predict(Y)
|
|
So(err, ShouldBeNil)
|
|
So(base.InstancesAreEqual(originalPredictions, newPredictions), ShouldBeTrue)
|
|
})
|
|
|
|
})
|
|
})
|
|
|
|
})
|
|
})
|
|
}
|
|
|
|
func TestMultiSVMSaved(t *testing.T) {
|
|
Convey("Loading data...", t, func() {
|
|
|
|
})
|
|
}
|