1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-28 13:48:56 +08:00
golearn/meta/bagging_test.go

90 lines
2.3 KiB
Go
Raw Normal View History

package meta
import (
"fmt"
2014-08-22 07:21:24 +00:00
"github.com/sjwhitworth/golearn/base"
eval "github.com/sjwhitworth/golearn/evaluation"
2014-08-22 07:21:24 +00:00
"github.com/sjwhitworth/golearn/filters"
"github.com/sjwhitworth/golearn/trees"
"math/rand"
"testing"
"time"
)
2014-05-23 11:56:23 +01:00
func BenchmarkBaggingRandomForestFit(testEnv *testing.B) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
panic(err)
}
rand.Seed(time.Now().UnixNano())
filt := filters.NewChiMergeFilter(inst, 0.90)
2014-08-02 16:22:15 +01:00
for _, a := range base.NonClassFloatAttributes(inst) {
filt.AddAttribute(a)
}
filt.Train()
instf := base.NewLazilyFilteredInstances(inst, filt)
2014-05-23 11:56:23 +01:00
rf := new(BaggedModel)
for i := 0; i < 10; i++ {
rf.AddModel(trees.NewRandomTree(2))
}
testEnv.ResetTimer()
for i := 0; i < 20; i++ {
2014-08-02 16:22:15 +01:00
rf.Fit(instf)
2014-05-23 11:56:23 +01:00
}
}
func BenchmarkBaggingRandomForestPredict(testEnv *testing.B) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
panic(err)
}
rand.Seed(time.Now().UnixNano())
filt := filters.NewChiMergeFilter(inst, 0.90)
2014-08-02 16:22:15 +01:00
for _, a := range base.NonClassFloatAttributes(inst) {
filt.AddAttribute(a)
}
filt.Train()
instf := base.NewLazilyFilteredInstances(inst, filt)
2014-05-23 11:56:23 +01:00
rf := new(BaggedModel)
for i := 0; i < 10; i++ {
rf.AddModel(trees.NewRandomTree(2))
}
2014-08-02 16:22:15 +01:00
rf.Fit(instf)
2014-05-23 11:56:23 +01:00
testEnv.ResetTimer()
for i := 0; i < 20; i++ {
2014-08-02 16:22:15 +01:00
rf.Predict(instf)
2014-05-23 11:56:23 +01:00
}
}
func TestRandomForest1(testEnv *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
panic(err)
}
rand.Seed(time.Now().UnixNano())
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
2014-05-17 16:20:56 +01:00
filt := filters.NewChiMergeFilter(inst, 0.90)
2014-08-02 16:22:15 +01:00
for _, a := range base.NonClassFloatAttributes(inst) {
filt.AddAttribute(a)
}
filt.Train()
trainDataf := base.NewLazilyFilteredInstances(trainData, filt)
testDataf := base.NewLazilyFilteredInstances(testData, filt)
rf := new(BaggedModel)
for i := 0; i < 10; i++ {
rf.AddModel(trees.NewRandomTree(2))
}
2014-08-02 16:22:15 +01:00
rf.Fit(trainDataf)
fmt.Println(rf)
2014-08-02 16:22:15 +01:00
predictions := rf.Predict(testDataf)
fmt.Println(predictions)
2014-08-02 16:22:15 +01:00
confusionMat := eval.GetConfusionMatrix(testDataf, predictions)
fmt.Println(confusionMat)
fmt.Println(eval.GetMacroPrecision(confusionMat))
fmt.Println(eval.GetMacroRecall(confusionMat))
fmt.Println(eval.GetSummary(confusionMat))
}