1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-28 13:48:56 +08:00

75 lines
1.7 KiB
Go
Raw Normal View History

// Demonstrates decision tree classification
package main
import (
"fmt"
base "github.com/sjwhitworth/golearn/base"
eval "github.com/sjwhitworth/golearn/evaluation"
filters "github.com/sjwhitworth/golearn/filters"
ensemble "github.com/sjwhitworth/golearn/ensemble"
trees "github.com/sjwhitworth/golearn/trees"
"math/rand"
"time"
)
func main () {
var tree base.Classifier
rand.Seed(time.Now().UTC().UnixNano())
// Load in the iris dataset
iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
if err != nil {
panic(err)
}
// Discretise the iris dataset with Chi-Merge
filt := filters.NewChiMergeFilter(iris, 0.99)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(iris)
// Create a 60-40 training-test split
insts := base.InstancesTrainTestSplit(iris, 0.60)
//
// First up, use ID3
//
tree = trees.NewID3DecisionTree(0.6)
// (Parameter controls train-prune split.)
// Train the ID3 tree
tree.Fit(insts[0])
// Generate predictions
predictions := tree.Predict(insts[1])
// Evaluate
fmt.Println("ID3 Performance")
cf := eval.GetConfusionMatrix(insts[1], predictions)
fmt.Println(eval.GetSummary(cf))
//
// Next up, Random Trees
//
// Consider two randomly-chosen attributes
tree = trees.NewRandomTree(2)
tree.Fit(insts[0])
predictions = tree.Predict(insts[1])
fmt.Println("RandomTree Performance")
cf = eval.GetConfusionMatrix(insts[1], predictions)
fmt.Println(eval.GetSummary(cf))
//
// Finally, Random Forests
//
tree = ensemble.NewRandomForest(100, 3)
tree.Fit(insts[0])
predictions = tree.Predict(insts[1])
fmt.Println("RandomForest Performance")
cf = eval.GetConfusionMatrix(insts[1], predictions)
fmt.Println(eval.GetSummary(cf))
}