1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-28 13:48:56 +08:00
golearn/ensemble/randomforest_test.go

30 lines
792 B
Go
Raw Normal View History

package ensemble
import (
"fmt"
base "github.com/sjwhitworth/golearn/base"
eval "github.com/sjwhitworth/golearn/evaluation"
filters "github.com/sjwhitworth/golearn/filters"
"testing"
)
func TestRandomForest1(testEnv *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
panic(err)
}
insts := base.InstancesTrainTestSplit(inst, 0.60)
2014-05-17 18:06:01 +01:00
filt := filters.NewChiMergeFilter(insts[0], 0.90)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(insts[1])
filt.Run(insts[0])
2014-05-17 20:37:19 +01:00
rf := NewRandomForest(10, 3)
2014-05-17 17:35:10 +01:00
rf.Fit(insts[0])
predictions := rf.Predict(insts[1])
fmt.Println(predictions)
confusionMat := eval.GetConfusionMatrix(insts[1], predictions)
fmt.Println(confusionMat)
fmt.Println(eval.GetSummary(confusionMat))
}