mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
Reduced-error pruning
This commit is contained in:
parent
c516907b13
commit
13c0dc3eba
@ -13,19 +13,17 @@ func TestRandomForest1(testEnv *testing.T) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
insts := base.InstancesTrainTestSplit(inst, 0.4)
|
||||
filt := filters.NewBinningFilter(insts[0], 10)
|
||||
insts := base.InstancesTrainTestSplit(inst, 0.6)
|
||||
filt := filters.NewChiMergeFilter(insts[0], 0.90)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
filt.Run(insts[1])
|
||||
filt.Run(insts[0])
|
||||
rf := NewRandomForest(10, 2)
|
||||
rf := NewRandomForest(15, 2)
|
||||
rf.Fit(insts[0])
|
||||
predictions := rf.Predict(insts[1])
|
||||
fmt.Println(predictions)
|
||||
confusionMat := eval.GetConfusionMatrix(insts[1], predictions)
|
||||
fmt.Println(confusionMat)
|
||||
fmt.Println(eval.GetMacroPrecision(confusionMat))
|
||||
fmt.Println(eval.GetMacroRecall(confusionMat))
|
||||
fmt.Println(eval.GetSummary(confusionMat))
|
||||
}
|
||||
|
34
trees/id3.go
34
trees/id3.go
@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
base "github.com/sjwhitworth/golearn/base"
|
||||
eval "github.com/sjwhitworth/golearn/evaluation"
|
||||
)
|
||||
|
||||
// NodeType determines whether a DecisionTreeNode is a leaf or not
|
||||
@ -131,6 +132,39 @@ func (d *DecisionTreeNode) String() string {
|
||||
return d.getNestedString(0)
|
||||
}
|
||||
|
||||
func computeAccuracy(predictions *base.Instances, from *base.Instances) float64 {
|
||||
cf := eval.GetConfusionMatrix(from, predictions)
|
||||
return eval.GetAccuracy(cf)
|
||||
}
|
||||
|
||||
// Prune eliminates branches which hurt accuracy
|
||||
func (d *DecisionTreeNode) Prune(using *base.Instances) {
|
||||
// If you're a leaf, you're already pruned
|
||||
if d.Children == nil {
|
||||
return
|
||||
} else {
|
||||
// Recursively prune children of this node
|
||||
sub := using.DecomposeOnAttributeValues(d.SplitAttr)
|
||||
for k := range d.Children {
|
||||
d.Children[k].Prune(sub[k])
|
||||
}
|
||||
}
|
||||
|
||||
// Get a baseline accuracy
|
||||
baselineAccuracy := computeAccuracy(d.Predict(using), using)
|
||||
|
||||
// Speculatively remove the children and re-evaluate
|
||||
tmpChildren := d.Children
|
||||
d.Children = nil
|
||||
newAccuracy := computeAccuracy(d.Predict(using), using)
|
||||
|
||||
// Keep the children removed if better, else restore
|
||||
if newAccuracy < baselineAccuracy {
|
||||
d.Children = tmpChildren
|
||||
}
|
||||
}
|
||||
|
||||
// Predict outputs a base.Instances containing predictions from this tree
|
||||
func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances {
|
||||
outputAttrs := make([]base.Attribute, 1)
|
||||
outputAttrs[0] = what.GetClassAttr()
|
||||
|
@ -63,3 +63,7 @@ func (rt *RandomTree) Predict(from *base.Instances) *base.Instances {
|
||||
func (rt *RandomTree) String() string {
|
||||
return fmt.Sprintf("RandomTree(%s)", rt.Root)
|
||||
}
|
||||
|
||||
func (rt *RandomTree) Prune(with *base.Instances) {
|
||||
rt.Root.Prune(with)
|
||||
}
|
||||
|
@ -75,6 +75,32 @@ func TestRandomTreeClassification2(testEnv *testing.T) {
|
||||
fmt.Println(eval.GetSummary(confusionMat))
|
||||
}
|
||||
|
||||
func TestPruning(testEnv *testing.T) {
|
||||
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
insts := base.InstancesTrainTestSplit(inst, 0.6)
|
||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
fmt.Println(insts[1])
|
||||
filt.Run(insts[1])
|
||||
filt.Run(insts[0])
|
||||
root := NewRandomTree(2)
|
||||
fitInsts := base.InstancesTrainTestSplit(insts[0], 0.6)
|
||||
root.Fit(fitInsts[0])
|
||||
root.Prune(fitInsts[1])
|
||||
fmt.Println(root)
|
||||
predictions := root.Predict(insts[1])
|
||||
fmt.Println(predictions)
|
||||
confusionMat := eval.GetConfusionMatrix(insts[1], predictions)
|
||||
fmt.Println(confusionMat)
|
||||
fmt.Println(eval.GetMacroPrecision(confusionMat))
|
||||
fmt.Println(eval.GetMacroRecall(confusionMat))
|
||||
fmt.Println(eval.GetSummary(confusionMat))
|
||||
}
|
||||
|
||||
func TestInformationGain(testEnv *testing.T) {
|
||||
outlook := make(map[string]map[string]int)
|
||||
outlook["sunny"] = make(map[string]int)
|
||||
|
Loading…
x
Reference in New Issue
Block a user