
Reduced-error pruning

Richard Townsend 2014-05-17 18:06:01 +01:00
parent c516907b13
commit 13c0dc3eba
4 changed files with 67 additions and 5 deletions

View File

@@ -13,19 +13,17 @@ func TestRandomForest1(testEnv *testing.T) {
 	if err != nil {
 		panic(err)
 	}
-	insts := base.InstancesTrainTestSplit(inst, 0.4)
-	filt := filters.NewBinningFilter(insts[0], 10)
+	insts := base.InstancesTrainTestSplit(inst, 0.6)
+	filt := filters.NewChiMergeFilter(insts[0], 0.90)
 	filt.AddAllNumericAttributes()
 	filt.Build()
 	filt.Run(insts[1])
 	filt.Run(insts[0])
-	rf := NewRandomForest(10, 2)
+	rf := NewRandomForest(15, 2)
 	rf.Fit(insts[0])
 	predictions := rf.Predict(insts[1])
 	fmt.Println(predictions)
 	confusionMat := eval.GetConfusionMatrix(insts[1], predictions)
 	fmt.Println(confusionMat)
-	fmt.Println(eval.GetMacroPrecision(confusionMat))
-	fmt.Println(eval.GetMacroRecall(confusionMat))
 	fmt.Println(eval.GetSummary(confusionMat))
 }

View File

@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"fmt"
 	base "github.com/sjwhitworth/golearn/base"
+	eval "github.com/sjwhitworth/golearn/evaluation"
 )
 
 // NodeType determines whether a DecisionTreeNode is a leaf or not
@@ -131,6 +132,39 @@ func (d *DecisionTreeNode) String() string {
 	return d.getNestedString(0)
 }
+
+func computeAccuracy(predictions *base.Instances, from *base.Instances) float64 {
+	cf := eval.GetConfusionMatrix(from, predictions)
+	return eval.GetAccuracy(cf)
+}
+
+// Prune eliminates branches which hurt accuracy
+func (d *DecisionTreeNode) Prune(using *base.Instances) {
+	// If you're a leaf, you're already pruned
+	if d.Children == nil {
+		return
+	} else {
+		// Recursively prune children of this node
+		sub := using.DecomposeOnAttributeValues(d.SplitAttr)
+		for k := range d.Children {
+			d.Children[k].Prune(sub[k])
+		}
+	}
+
+	// Get a baseline accuracy
+	baselineAccuracy := computeAccuracy(d.Predict(using), using)
+
+	// Speculatively remove the children and re-evaluate
+	tmpChildren := d.Children
+	d.Children = nil
+	newAccuracy := computeAccuracy(d.Predict(using), using)
+
+	// Keep the children removed if better, else restore
+	if newAccuracy < baselineAccuracy {
+		d.Children = tmpChildren
+	}
+}
+
 // Predict outputs a base.Instances containing predictions from this tree
 func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances {
 	outputAttrs := make([]base.Attribute, 1)
 	outputAttrs[0] = what.GetClassAttr()
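Taken together, these additions implement reduced-error pruning: starting from the leaves and working upward, each subtree is speculatively collapsed and kept collapsed only if accuracy on a held-out pruning set does not drop. A minimal sketch of how a caller might drive the new method, using only the API exercised by this commit's tests (variable names are illustrative):

// Sketch only: grow a tree on one slice of the training data,
// then prune it on a slice the tree has never seen.
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
	panic(err)
}
insts := base.InstancesTrainTestSplit(inst, 0.6)        // train/test split
fitInsts := base.InstancesTrainTestSplit(insts[0], 0.6) // grow/prune split
tree := NewRandomTree(2)
tree.Fit(fitInsts[0])   // grow the tree on one part
tree.Prune(fitInsts[1]) // reduced-error pruning on the held-out part
predictions := tree.Predict(insts[1])
fmt.Println(eval.GetSummary(eval.GetConfusionMatrix(insts[1], predictions)))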

View File

@@ -63,3 +63,7 @@ func (rt *RandomTree) Predict(from *base.Instances) *base.Instances {
 func (rt *RandomTree) String() string {
 	return fmt.Sprintf("RandomTree(%s)", rt.Root)
 }
+
+func (rt *RandomTree) Prune(with *base.Instances) {
+	rt.Root.Prune(with)
+}
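RandomTree.Prune simply delegates to the root DecisionTreeNode, so one call prunes the whole tree bottom-up; this is the entry point exercised by the TestPruning case added below.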

View File

@@ -75,6 +75,32 @@ func TestRandomTreeClassification2(testEnv *testing.T) {
 	fmt.Println(eval.GetSummary(confusionMat))
 }
 
+func TestPruning(testEnv *testing.T) {
+	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
+	if err != nil {
+		panic(err)
+	}
+	insts := base.InstancesTrainTestSplit(inst, 0.6)
+	filt := filters.NewChiMergeFilter(inst, 0.90)
+	filt.AddAllNumericAttributes()
+	filt.Build()
+	fmt.Println(insts[1])
+	filt.Run(insts[1])
+	filt.Run(insts[0])
+	root := NewRandomTree(2)
+	fitInsts := base.InstancesTrainTestSplit(insts[0], 0.6)
+	root.Fit(fitInsts[0])
+	root.Prune(fitInsts[1])
+	fmt.Println(root)
+	predictions := root.Predict(insts[1])
+	fmt.Println(predictions)
+	confusionMat := eval.GetConfusionMatrix(insts[1], predictions)
+	fmt.Println(confusionMat)
+	fmt.Println(eval.GetMacroPrecision(confusionMat))
+	fmt.Println(eval.GetMacroRecall(confusionMat))
+	fmt.Println(eval.GetSummary(confusionMat))
+}
+
 func TestInformationGain(testEnv *testing.T) {
 	outlook := make(map[string]map[string]int)
 	outlook["sunny"] = make(map[string]int)