2014-05-14 14:00:22 +01:00
|
|
|
package ensemble
|
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
|
|
|
base "github.com/sjwhitworth/golearn/base"
|
|
|
|
eval "github.com/sjwhitworth/golearn/evaluation"
|
|
|
|
filters "github.com/sjwhitworth/golearn/filters"
|
|
|
|
"testing"
|
|
|
|
)
|
|
|
|
|
|
|
|
func TestRandomForest1(testEnv *testing.T) {
|
|
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
|
|
|
if err != nil {
|
|
|
|
panic(err)
|
|
|
|
}
|
2014-05-18 11:23:32 +01:00
|
|
|
insts := base.InstancesTrainTestSplit(inst, 0.60)
|
2014-05-17 18:06:01 +01:00
|
|
|
filt := filters.NewChiMergeFilter(insts[0], 0.90)
|
2014-05-14 14:00:22 +01:00
|
|
|
filt.AddAllNumericAttributes()
|
|
|
|
filt.Build()
|
|
|
|
filt.Run(insts[1])
|
|
|
|
filt.Run(insts[0])
|
2014-05-17 20:37:19 +01:00
|
|
|
rf := NewRandomForest(10, 3)
|
2014-05-17 17:35:10 +01:00
|
|
|
rf.Fit(insts[0])
|
2014-05-14 14:00:22 +01:00
|
|
|
predictions := rf.Predict(insts[1])
|
|
|
|
fmt.Println(predictions)
|
|
|
|
confusionMat := eval.GetConfusionMatrix(insts[1], predictions)
|
|
|
|
fmt.Println(confusionMat)
|
|
|
|
fmt.Println(eval.GetSummary(confusionMat))
|
|
|
|
}
|