1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-28 13:48:56 +08:00
golearn/meta/bagging_test.go
2016-06-14 00:56:47 +01:00

104 lines
2.7 KiB
Go

package meta
import (
"math/rand"
"testing"
"time"
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/evaluation"
"github.com/sjwhitworth/golearn/filters"
"github.com/sjwhitworth/golearn/trees"
. "github.com/smartystreets/goconvey/convey"
)
// BenchmarkBaggingRandomForestFit measures how long it takes to fit a
// BaggedModel made of 10 random trees on the ChiMerge-filtered iris dataset.
//
// Loading the CSV, training the filter and building the ensemble are all
// done before the timer is reset, so only the calls to Fit are timed.
func BenchmarkBaggingRandomForestFit(t *testing.B) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		t.Fatalf("Unable to parse CSV to instances: %s", err.Error())
	}

	rand.Seed(time.Now().UnixNano())

	// Discretise every non-class float attribute with a ChiMerge filter.
	filt := filters.NewChiMergeFilter(inst, 0.90)
	for _, a := range base.NonClassFloatAttributes(inst) {
		filt.AddAttribute(a)
	}
	filt.Train()
	instf := base.NewLazilyFilteredInstances(inst, filt)

	// Assemble the ensemble: 10 random trees in one bagged model.
	rf := new(BaggedModel)
	for i := 0; i < 10; i++ {
		rf.AddModel(trees.NewRandomTree(2))
	}

	t.ResetTimer()
	// The loop must run exactly t.N times: the framework divides the elapsed
	// time by t.N to report ns/op, so a fixed iteration count would produce
	// meaningless figures.
	for i := 0; i < t.N; i++ {
		rf.Fit(instf)
	}
}
// BenchmarkBaggingRandomForestPredict measures how long it takes an already
// fitted BaggedModel of 10 random trees to predict over the ChiMerge-filtered
// iris dataset.
//
// Loading the CSV, training the filter, building the ensemble and fitting it
// are all done before the timer is reset, so only the calls to Predict are
// timed.
func BenchmarkBaggingRandomForestPredict(t *testing.B) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		t.Fatalf("Unable to parse CSV to instances: %s", err.Error())
	}

	rand.Seed(time.Now().UnixNano())

	// Discretise every non-class float attribute with a ChiMerge filter.
	filt := filters.NewChiMergeFilter(inst, 0.90)
	for _, a := range base.NonClassFloatAttributes(inst) {
		filt.AddAttribute(a)
	}
	filt.Train()
	instf := base.NewLazilyFilteredInstances(inst, filt)

	// Assemble the ensemble and fit it once, outside the timed region.
	rf := new(BaggedModel)
	for i := 0; i < 10; i++ {
		rf.AddModel(trees.NewRandomTree(2))
	}
	rf.Fit(instf)

	t.ResetTimer()
	// The loop must run exactly t.N times: the framework divides the elapsed
	// time by t.N to report ns/op, so a fixed iteration count would produce
	// meaningless figures.
	for i := 0; i < t.N; i++ {
		rf.Predict(instf)
	}
}
// TestBaggedModelRandomForest is an end-to-end check of BaggedModel: it loads
// the iris dataset, splits it into training and test portions, applies a
// ChiMerge filter, fits a bag of 10 random trees on the training portion and
// asserts that accuracy on the test portion is above 0.5.
func TestBaggedModelRandomForest(t *testing.T) {
	Convey("Given data", t, func() {
		iris, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldBeNil)

		Convey("Splitting the data into training and test data", func() {
			trainSet, testSet := base.InstancesTrainTestSplit(iris, 0.6)

			Convey("Filtering the split datasets", func() {
				rand.Seed(time.Now().UnixNano())

				// The filter is built and trained on the full dataset, then
				// applied lazily to each half of the split.
				chiMerge := filters.NewChiMergeFilter(iris, 0.90)
				floatAttrs := base.NonClassFloatAttributes(iris)
				for idx := range floatAttrs {
					chiMerge.AddAttribute(floatAttrs[idx])
				}
				chiMerge.Train()

				filteredTrain := base.NewLazilyFilteredInstances(trainSet, chiMerge)
				filteredTest := base.NewLazilyFilteredInstances(testSet, chiMerge)

				Convey("Fitting and Predicting with a Bagged Model of 10 Random Trees", func() {
					const treeCount = 10

					forest := new(BaggedModel)
					for added := 0; added < treeCount; added++ {
						forest.AddModel(trees.NewRandomTree(2))
					}

					forest.Fit(filteredTrain)
					predicted := forest.Predict(filteredTest)

					matrix, matrixErr := evaluation.GetConfusionMatrix(filteredTest, predicted)
					So(matrixErr, ShouldBeNil)

					Convey("Predictions are somewhat accurate", func() {
						accuracy := evaluation.GetAccuracy(matrix)
						So(accuracy, ShouldBeGreaterThan, 0.5)
					})
				})
			})
		})
	})
}