2014-05-14 14:00:22 +01:00
|
|
|
package meta
|
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
2014-08-22 07:21:24 +00:00
|
|
|
"github.com/sjwhitworth/golearn/base"
|
2014-05-14 14:00:22 +01:00
|
|
|
eval "github.com/sjwhitworth/golearn/evaluation"
|
2014-08-22 07:21:24 +00:00
|
|
|
"github.com/sjwhitworth/golearn/filters"
|
|
|
|
"github.com/sjwhitworth/golearn/trees"
|
2014-05-14 14:00:22 +01:00
|
|
|
"math/rand"
|
|
|
|
"testing"
|
|
|
|
"time"
|
|
|
|
)
|
|
|
|
|
2014-05-23 11:56:23 +01:00
|
|
|
func BenchmarkBaggingRandomForestFit(testEnv *testing.B) {
|
|
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
|
|
|
if err != nil {
|
|
|
|
panic(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
rand.Seed(time.Now().UnixNano())
|
|
|
|
filt := filters.NewChiMergeFilter(inst, 0.90)
|
2014-08-02 16:22:15 +01:00
|
|
|
for _, a := range base.NonClassFloatAttributes(inst) {
|
|
|
|
filt.AddAttribute(a)
|
|
|
|
}
|
|
|
|
filt.Train()
|
|
|
|
instf := base.NewLazilyFilteredInstances(inst, filt)
|
2014-05-23 11:56:23 +01:00
|
|
|
rf := new(BaggedModel)
|
|
|
|
for i := 0; i < 10; i++ {
|
|
|
|
rf.AddModel(trees.NewRandomTree(2))
|
|
|
|
}
|
|
|
|
testEnv.ResetTimer()
|
|
|
|
for i := 0; i < 20; i++ {
|
2014-08-02 16:22:15 +01:00
|
|
|
rf.Fit(instf)
|
2014-05-23 11:56:23 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func BenchmarkBaggingRandomForestPredict(testEnv *testing.B) {
|
|
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
|
|
|
if err != nil {
|
|
|
|
panic(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
rand.Seed(time.Now().UnixNano())
|
|
|
|
filt := filters.NewChiMergeFilter(inst, 0.90)
|
2014-08-02 16:22:15 +01:00
|
|
|
for _, a := range base.NonClassFloatAttributes(inst) {
|
|
|
|
filt.AddAttribute(a)
|
|
|
|
}
|
|
|
|
filt.Train()
|
|
|
|
instf := base.NewLazilyFilteredInstances(inst, filt)
|
2014-05-23 11:56:23 +01:00
|
|
|
rf := new(BaggedModel)
|
|
|
|
for i := 0; i < 10; i++ {
|
|
|
|
rf.AddModel(trees.NewRandomTree(2))
|
|
|
|
}
|
2014-08-02 16:22:15 +01:00
|
|
|
rf.Fit(instf)
|
2014-05-23 11:56:23 +01:00
|
|
|
testEnv.ResetTimer()
|
|
|
|
for i := 0; i < 20; i++ {
|
2014-08-02 16:22:15 +01:00
|
|
|
rf.Predict(instf)
|
2014-05-23 11:56:23 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2014-05-14 14:00:22 +01:00
|
|
|
func TestRandomForest1(testEnv *testing.T) {
|
|
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
|
|
|
if err != nil {
|
|
|
|
panic(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
rand.Seed(time.Now().UnixNano())
|
2014-06-06 20:30:24 +02:00
|
|
|
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
|
2014-05-17 16:20:56 +01:00
|
|
|
filt := filters.NewChiMergeFilter(inst, 0.90)
|
2014-08-02 16:22:15 +01:00
|
|
|
for _, a := range base.NonClassFloatAttributes(inst) {
|
|
|
|
filt.AddAttribute(a)
|
|
|
|
}
|
|
|
|
filt.Train()
|
|
|
|
trainDataf := base.NewLazilyFilteredInstances(trainData, filt)
|
|
|
|
testDataf := base.NewLazilyFilteredInstances(testData, filt)
|
2014-05-14 14:00:22 +01:00
|
|
|
rf := new(BaggedModel)
|
|
|
|
for i := 0; i < 10; i++ {
|
|
|
|
rf.AddModel(trees.NewRandomTree(2))
|
|
|
|
}
|
2014-08-02 16:22:15 +01:00
|
|
|
rf.Fit(trainDataf)
|
2014-05-14 14:00:22 +01:00
|
|
|
fmt.Println(rf)
|
2014-08-02 16:22:15 +01:00
|
|
|
predictions := rf.Predict(testDataf)
|
2014-05-14 14:00:22 +01:00
|
|
|
fmt.Println(predictions)
|
2014-08-02 16:22:15 +01:00
|
|
|
confusionMat := eval.GetConfusionMatrix(testDataf, predictions)
|
2014-05-14 14:00:22 +01:00
|
|
|
fmt.Println(confusionMat)
|
|
|
|
fmt.Println(eval.GetMacroPrecision(confusionMat))
|
|
|
|
fmt.Println(eval.GetMacroRecall(confusionMat))
|
|
|
|
fmt.Println(eval.GetSummary(confusionMat))
|
|
|
|
}
|