mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
251 lines
6.7 KiB
Go
251 lines
6.7 KiB
Go
package trees
|
|
|
|
import (
|
|
"fmt"
|
|
base "github.com/sjwhitworth/golearn/base"
|
|
eval "github.com/sjwhitworth/golearn/evaluation"
|
|
filters "github.com/sjwhitworth/golearn/filters"
|
|
"math"
|
|
"testing"
|
|
)
|
|
|
|
func TestRandomTree(testEnv *testing.T) {
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
filt := filters.NewChiMergeFilter(inst, 0.90)
|
|
filt.AddAllNumericAttributes()
|
|
filt.Build()
|
|
filt.Run(inst)
|
|
fmt.Println(inst)
|
|
r := new(RandomTreeRuleGenerator)
|
|
r.Attributes = 2
|
|
root := InferID3Tree(inst, r)
|
|
fmt.Println(root)
|
|
}
|
|
|
|
func TestRandomTreeClassification(testEnv *testing.T) {
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
insts := base.InstancesTrainTestSplit(inst, 0.6)
|
|
filt := filters.NewChiMergeFilter(inst, 0.90)
|
|
filt.AddAllNumericAttributes()
|
|
filt.Build()
|
|
filt.Run(insts[0])
|
|
filt.Run(insts[1])
|
|
fmt.Println(inst)
|
|
r := new(RandomTreeRuleGenerator)
|
|
r.Attributes = 2
|
|
root := InferID3Tree(insts[0], r)
|
|
fmt.Println(root)
|
|
predictions := root.Predict(insts[1])
|
|
fmt.Println(predictions)
|
|
confusionMat := eval.GetConfusionMatrix(insts[1], predictions)
|
|
fmt.Println(confusionMat)
|
|
fmt.Println(eval.GetMacroPrecision(confusionMat))
|
|
fmt.Println(eval.GetMacroRecall(confusionMat))
|
|
fmt.Println(eval.GetSummary(confusionMat))
|
|
}
|
|
|
|
func TestRandomTreeClassification2(testEnv *testing.T) {
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
insts := base.InstancesTrainTestSplit(inst, 0.4)
|
|
filt := filters.NewChiMergeFilter(inst, 0.90)
|
|
filt.AddAllNumericAttributes()
|
|
filt.Build()
|
|
fmt.Println(insts[1])
|
|
filt.Run(insts[1])
|
|
filt.Run(insts[0])
|
|
root := NewRandomTree(2)
|
|
root.Fit(insts[0])
|
|
fmt.Println(root)
|
|
predictions := root.Predict(insts[1])
|
|
fmt.Println(predictions)
|
|
confusionMat := eval.GetConfusionMatrix(insts[1], predictions)
|
|
fmt.Println(confusionMat)
|
|
fmt.Println(eval.GetMacroPrecision(confusionMat))
|
|
fmt.Println(eval.GetMacroRecall(confusionMat))
|
|
fmt.Println(eval.GetSummary(confusionMat))
|
|
}
|
|
|
|
func TestPruning(testEnv *testing.T) {
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
insts := base.InstancesTrainTestSplit(inst, 0.6)
|
|
filt := filters.NewChiMergeFilter(inst, 0.90)
|
|
filt.AddAllNumericAttributes()
|
|
filt.Build()
|
|
fmt.Println(insts[1])
|
|
filt.Run(insts[1])
|
|
filt.Run(insts[0])
|
|
root := NewRandomTree(2)
|
|
fitInsts := base.InstancesTrainTestSplit(insts[0], 0.6)
|
|
root.Fit(fitInsts[0])
|
|
root.Prune(fitInsts[1])
|
|
fmt.Println(root)
|
|
predictions := root.Predict(insts[1])
|
|
fmt.Println(predictions)
|
|
confusionMat := eval.GetConfusionMatrix(insts[1], predictions)
|
|
fmt.Println(confusionMat)
|
|
fmt.Println(eval.GetMacroPrecision(confusionMat))
|
|
fmt.Println(eval.GetMacroRecall(confusionMat))
|
|
fmt.Println(eval.GetSummary(confusionMat))
|
|
}
|
|
|
|
func TestInformationGain(testEnv *testing.T) {
|
|
outlook := make(map[string]map[string]int)
|
|
outlook["sunny"] = make(map[string]int)
|
|
outlook["overcast"] = make(map[string]int)
|
|
outlook["rain"] = make(map[string]int)
|
|
outlook["sunny"]["play"] = 2
|
|
outlook["sunny"]["noplay"] = 3
|
|
outlook["overcast"]["play"] = 4
|
|
outlook["rain"]["play"] = 3
|
|
outlook["rain"]["noplay"] = 2
|
|
|
|
entropy := getSplitEntropy(outlook)
|
|
if math.Abs(entropy-0.694) > 0.001 {
|
|
testEnv.Error(entropy)
|
|
}
|
|
}
|
|
|
|
func TestID3Inference(testEnv *testing.T) {
|
|
|
|
// Import the "PlayTennis" dataset
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
// Build the decision tree
|
|
rule := new(InformationGainRuleGenerator)
|
|
root := InferID3Tree(inst, rule)
|
|
|
|
// Verify the tree
|
|
// First attribute should be "outlook"
|
|
if root.SplitAttr.GetName() != "outlook" {
|
|
testEnv.Error(root)
|
|
}
|
|
sunnyChild := root.Children["sunny"]
|
|
overcastChild := root.Children["overcast"]
|
|
rainyChild := root.Children["rainy"]
|
|
if sunnyChild.SplitAttr.GetName() != "humidity" {
|
|
testEnv.Error(sunnyChild)
|
|
}
|
|
if rainyChild.SplitAttr.GetName() != "windy" {
|
|
testEnv.Error(rainyChild)
|
|
}
|
|
if overcastChild.SplitAttr != nil {
|
|
testEnv.Error(overcastChild)
|
|
}
|
|
|
|
sunnyLeafHigh := sunnyChild.Children["high"]
|
|
sunnyLeafNormal := sunnyChild.Children["normal"]
|
|
if sunnyLeafHigh.Class != "no" {
|
|
testEnv.Error(sunnyLeafHigh)
|
|
}
|
|
if sunnyLeafNormal.Class != "yes" {
|
|
testEnv.Error(sunnyLeafNormal)
|
|
}
|
|
|
|
windyLeafFalse := rainyChild.Children["false"]
|
|
windyLeafTrue := rainyChild.Children["true"]
|
|
if windyLeafFalse.Class != "yes" {
|
|
testEnv.Error(windyLeafFalse)
|
|
}
|
|
if windyLeafTrue.Class != "no" {
|
|
testEnv.Error(windyLeafTrue)
|
|
}
|
|
|
|
if overcastChild.Class != "yes" {
|
|
testEnv.Error(overcastChild)
|
|
}
|
|
}
|
|
|
|
func TestID3Classification(testEnv *testing.T) {
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
filt := filters.NewBinningFilter(inst, 10)
|
|
filt.AddAllNumericAttributes()
|
|
filt.Build()
|
|
filt.Run(inst)
|
|
fmt.Println(inst)
|
|
insts := base.InstancesTrainTestSplit(inst, 0.70)
|
|
// Build the decision tree
|
|
rule := new(InformationGainRuleGenerator)
|
|
root := InferID3Tree(insts[0], rule)
|
|
fmt.Println(root)
|
|
predictions := root.Predict(insts[1])
|
|
fmt.Println(predictions)
|
|
confusionMat := eval.GetConfusionMatrix(insts[1], predictions)
|
|
fmt.Println(confusionMat)
|
|
fmt.Println(eval.GetMacroPrecision(confusionMat))
|
|
fmt.Println(eval.GetMacroRecall(confusionMat))
|
|
fmt.Println(eval.GetSummary(confusionMat))
|
|
}
|
|
|
|
func TestID3(testEnv *testing.T) {
|
|
|
|
// Import the "PlayTennis" dataset
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
// Build the decision tree
|
|
tree := NewID3DecisionTree(0.0)
|
|
tree.Fit(inst)
|
|
root := tree.Root
|
|
|
|
// Verify the tree
|
|
// First attribute should be "outlook"
|
|
if root.SplitAttr.GetName() != "outlook" {
|
|
testEnv.Error(root)
|
|
}
|
|
sunnyChild := root.Children["sunny"]
|
|
overcastChild := root.Children["overcast"]
|
|
rainyChild := root.Children["rainy"]
|
|
if sunnyChild.SplitAttr.GetName() != "humidity" {
|
|
testEnv.Error(sunnyChild)
|
|
}
|
|
if rainyChild.SplitAttr.GetName() != "windy" {
|
|
testEnv.Error(rainyChild)
|
|
}
|
|
if overcastChild.SplitAttr != nil {
|
|
testEnv.Error(overcastChild)
|
|
}
|
|
|
|
sunnyLeafHigh := sunnyChild.Children["high"]
|
|
sunnyLeafNormal := sunnyChild.Children["normal"]
|
|
if sunnyLeafHigh.Class != "no" {
|
|
testEnv.Error(sunnyLeafHigh)
|
|
}
|
|
if sunnyLeafNormal.Class != "yes" {
|
|
testEnv.Error(sunnyLeafNormal)
|
|
}
|
|
|
|
windyLeafFalse := rainyChild.Children["false"]
|
|
windyLeafTrue := rainyChild.Children["true"]
|
|
if windyLeafFalse.Class != "yes" {
|
|
testEnv.Error(windyLeafFalse)
|
|
}
|
|
if windyLeafTrue.Class != "no" {
|
|
testEnv.Error(windyLeafTrue)
|
|
}
|
|
|
|
if overcastChild.Class != "yes" {
|
|
testEnv.Error(overcastChild)
|
|
}
|
|
}
|