1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-28 13:48:56 +08:00
golearn/ensemble/randomforest_test.go

31 lines
833 B
Go
Raw Normal View History

package ensemble
import (
2014-08-22 07:21:24 +00:00
"github.com/sjwhitworth/golearn/base"
eval "github.com/sjwhitworth/golearn/evaluation"
2014-08-22 07:21:24 +00:00
"github.com/sjwhitworth/golearn/filters"
"testing"
)
func TestRandomForest1(t *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
}
2014-08-02 16:22:14 +01:00
filt := filters.NewChiMergeFilter(inst, 0.90)
for _, a := range base.NonClassFloatAttributes(inst) {
filt.AddAttribute(a)
}
filt.Train()
instf := base.NewLazilyFilteredInstances(inst, filt)
trainData, testData := base.InstancesTrainTestSplit(instf, 0.60)
2014-05-17 20:37:19 +01:00
rf := NewRandomForest(10, 3)
rf.Fit(trainData)
predictions := rf.Predict(testData)
confusionMat := eval.GetConfusionMatrix(testData, predictions)
_ = eval.GetSummary(confusionMat)
}