mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
149 lines
3.4 KiB
Go
149 lines
3.4 KiB
Go
// Demonstrates decision tree classification
|
|
|
|
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"github.com/sjwhitworth/golearn/base"
|
|
"github.com/sjwhitworth/golearn/ensemble"
|
|
"github.com/sjwhitworth/golearn/evaluation"
|
|
"github.com/sjwhitworth/golearn/filters"
|
|
"github.com/sjwhitworth/golearn/trees"
|
|
"math/rand"
|
|
)
|
|
|
|
func main() {
|
|
|
|
var tree base.Classifier
|
|
|
|
rand.Seed(44111342)
|
|
|
|
// Load in the iris dataset
|
|
iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
// Discretise the iris dataset with Chi-Merge
|
|
filt := filters.NewChiMergeFilter(iris, 0.999)
|
|
for _, a := range base.NonClassFloatAttributes(iris) {
|
|
filt.AddAttribute(a)
|
|
}
|
|
filt.Train()
|
|
irisf := base.NewLazilyFilteredInstances(iris, filt)
|
|
|
|
// Create a 60-40 training-test split
|
|
trainData, testData := base.InstancesTrainTestSplit(irisf, 0.60)
|
|
|
|
//
|
|
// First up, use ID3
|
|
//
|
|
tree = trees.NewID3DecisionTree(0.6)
|
|
// (Parameter controls train-prune split.)
|
|
|
|
// Train the ID3 tree
|
|
err = tree.Fit(trainData)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
// Generate predictions
|
|
predictions, err := tree.Predict(testData)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
// Evaluate
|
|
fmt.Println("ID3 Performance (information gain)")
|
|
cf, err := evaluation.GetConfusionMatrix(testData, predictions)
|
|
if err != nil {
|
|
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
|
|
}
|
|
fmt.Println(evaluation.GetSummary(cf))
|
|
|
|
tree = trees.NewID3DecisionTreeFromRule(0.6, new(trees.InformationGainRatioRuleGenerator))
|
|
// (Parameter controls train-prune split.)
|
|
|
|
// Train the ID3 tree
|
|
err = tree.Fit(trainData)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
// Generate predictions
|
|
predictions, err = tree.Predict(testData)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
// Evaluate
|
|
fmt.Println("ID3 Performance (information gain ratio)")
|
|
cf, err = evaluation.GetConfusionMatrix(testData, predictions)
|
|
if err != nil {
|
|
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
|
|
}
|
|
fmt.Println(evaluation.GetSummary(cf))
|
|
|
|
tree = trees.NewID3DecisionTreeFromRule(0.6, new(trees.GiniCoefficientRuleGenerator))
|
|
// (Parameter controls train-prune split.)
|
|
|
|
// Train the ID3 tree
|
|
err = tree.Fit(trainData)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
// Generate predictions
|
|
predictions, err = tree.Predict(testData)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
// Evaluate
|
|
fmt.Println("ID3 Performance (gini index generator)")
|
|
cf, err = evaluation.GetConfusionMatrix(testData, predictions)
|
|
if err != nil {
|
|
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
|
|
}
|
|
fmt.Println(evaluation.GetSummary(cf))
|
|
//
|
|
// Next up, Random Trees
|
|
//
|
|
|
|
// Consider two randomly-chosen attributes
|
|
tree = trees.NewRandomTree(2)
|
|
err = tree.Fit(trainData)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
predictions, err = tree.Predict(testData)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
fmt.Println("RandomTree Performance")
|
|
cf, err = evaluation.GetConfusionMatrix(testData, predictions)
|
|
if err != nil {
|
|
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
|
|
}
|
|
fmt.Println(evaluation.GetSummary(cf))
|
|
|
|
//
|
|
// Finally, Random Forests
|
|
//
|
|
tree = ensemble.NewRandomForest(70, 3)
|
|
err = tree.Fit(trainData)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
predictions, err = tree.Predict(testData)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
fmt.Println("RandomForest Performance")
|
|
cf, err = evaluation.GetConfusionMatrix(testData, predictions)
|
|
if err != nil {
|
|
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
|
|
}
|
|
fmt.Println(evaluation.GetSummary(cf))
|
|
}
|