diff --git a/ensemble/randomforest_test.go b/ensemble/randomforest_test.go
index b8dac72..9ea8338 100644
--- a/ensemble/randomforest_test.go
+++ b/ensemble/randomforest_test.go
@@ -13,19 +13,17 @@ func TestRandomForest1(testEnv *testing.T) {
 	if err != nil {
 		panic(err)
 	}
-	insts := base.InstancesTrainTestSplit(inst, 0.4)
-	filt := filters.NewBinningFilter(insts[0], 10)
+	insts := base.InstancesTrainTestSplit(inst, 0.6)
+	filt := filters.NewChiMergeFilter(insts[0], 0.90)
 	filt.AddAllNumericAttributes()
 	filt.Build()
 	filt.Run(insts[1])
 	filt.Run(insts[0])
-	rf := NewRandomForest(10, 2)
+	rf := NewRandomForest(15, 2)
 	rf.Fit(insts[0])
 	predictions := rf.Predict(insts[1])
 	fmt.Println(predictions)
 	confusionMat := eval.GetConfusionMatrix(insts[1], predictions)
 	fmt.Println(confusionMat)
-	fmt.Println(eval.GetMacroPrecision(confusionMat))
-	fmt.Println(eval.GetMacroRecall(confusionMat))
 	fmt.Println(eval.GetSummary(confusionMat))
 }
diff --git a/trees/id3.go b/trees/id3.go
index 6cbe662..cd17e61 100644
--- a/trees/id3.go
+++ b/trees/id3.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"fmt"
 	base "github.com/sjwhitworth/golearn/base"
+	eval "github.com/sjwhitworth/golearn/evaluation"
 )
 
 // NodeType determines whether a DecisionTreeNode is a leaf or not
@@ -131,6 +132,39 @@ func (d *DecisionTreeNode) String() string {
 	return d.getNestedString(0)
 }
 
+func computeAccuracy(predictions *base.Instances, from *base.Instances) float64 {
+	cf := eval.GetConfusionMatrix(from, predictions)
+	return eval.GetAccuracy(cf)
+}
+
+// Prune eliminates branches which hurt accuracy
+func (d *DecisionTreeNode) Prune(using *base.Instances) {
+	// If you're a leaf, you're already pruned
+	if d.Children == nil {
+		return
+	} else {
+		// Recursively prune children of this node
+		sub := using.DecomposeOnAttributeValues(d.SplitAttr)
+		for k := range d.Children {
+			d.Children[k].Prune(sub[k])
+		}
+	}
+
+	// Get a baseline accuracy
+	baselineAccuracy := computeAccuracy(d.Predict(using), using)
+
+	// Speculatively remove the children and re-evaluate
+	tmpChildren := d.Children
+	d.Children = nil
+	newAccuracy := computeAccuracy(d.Predict(using), using)
+
+	// Keep the children removed if better, else restore
+	if newAccuracy < baselineAccuracy {
+		d.Children = tmpChildren
+	}
+}
+
+// Predict outputs a base.Instances containing predictions from this tree
 func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances {
 	outputAttrs := make([]base.Attribute, 1)
 	outputAttrs[0] = what.GetClassAttr()
diff --git a/trees/random.go b/trees/random.go
index 0c8c2bf..cac0147 100644
--- a/trees/random.go
+++ b/trees/random.go
@@ -63,3 +63,7 @@ func (rt *RandomTree) Predict(from *base.Instances) *base.Instances {
 func (rt *RandomTree) String() string {
 	return fmt.Sprintf("RandomTree(%s)", rt.Root)
 }
+
+func (rt *RandomTree) Prune(with *base.Instances) {
+	rt.Root.Prune(with)
+}
diff --git a/trees/tree_test.go b/trees/tree_test.go
index e12495b..d8b7451 100644
--- a/trees/tree_test.go
+++ b/trees/tree_test.go
@@ -75,6 +75,32 @@ func TestRandomTreeClassification2(testEnv *testing.T) {
 	fmt.Println(eval.GetSummary(confusionMat))
 }
 
+func TestPruning(testEnv *testing.T) {
+	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
+	if err != nil {
+		panic(err)
+	}
+	insts := base.InstancesTrainTestSplit(inst, 0.6)
+	filt := filters.NewChiMergeFilter(inst, 0.90)
+	filt.AddAllNumericAttributes()
+	filt.Build()
+	fmt.Println(insts[1])
+	filt.Run(insts[1])
+	filt.Run(insts[0])
+	root := NewRandomTree(2)
+	fitInsts := base.InstancesTrainTestSplit(insts[0], 0.6)
+	root.Fit(fitInsts[0])
+	root.Prune(fitInsts[1])
+	fmt.Println(root)
+	predictions := root.Predict(insts[1])
+	fmt.Println(predictions)
+	confusionMat := eval.GetConfusionMatrix(insts[1], predictions)
+	fmt.Println(confusionMat)
+	fmt.Println(eval.GetMacroPrecision(confusionMat))
+	fmt.Println(eval.GetMacroRecall(confusionMat))
+	fmt.Println(eval.GetSummary(confusionMat))
+}
+
 func TestInformationGain(testEnv *testing.T) {
 	outlook := make(map[string]map[string]int)
 	outlook["sunny"] = make(map[string]int)
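
The Prune added to trees/id3.go is reduced-error pruning: children are pruned bottom-up, then each subtree is speculatively collapsed to a leaf, and the collapse is kept unless hold-out accuracy strictly drops. Since the restore condition is newAccuracy < baselineAccuracy, an equally accurate subtree is discarded, so ties favour the smaller tree. Below is a minimal sketch of the intended fit/prune/evaluate flow outside the test harness, assuming only the golearn API visible in this patch; the 0.6 splits and the 0.90 ChiMerge significance are copied from TestPruning rather than tuned, and the dataset path assumes the program runs from the repository root.

package main

import (
	"fmt"

	base "github.com/sjwhitworth/golearn/base"
	eval "github.com/sjwhitworth/golearn/evaluation"
	filters "github.com/sjwhitworth/golearn/filters"
	trees "github.com/sjwhitworth/golearn/trees"
)

func main() {
	inst, err := base.ParseCSVToInstances("examples/datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}

	// Hold out a test set first
	insts := base.InstancesTrainTestSplit(inst, 0.6)

	// Discretise the numeric attributes, as TestPruning does
	filt := filters.NewChiMergeFilter(insts[0], 0.90)
	filt.AddAllNumericAttributes()
	filt.Build()
	filt.Run(insts[0])
	filt.Run(insts[1])

	// Carve a pruning set out of the training half: the tree never sees
	// fitInsts[1] during fitting, which is what makes the accuracy
	// comparison inside Prune an honest estimate of generalisation.
	fitInsts := base.InstancesTrainTestSplit(insts[0], 0.6)

	root := trees.NewRandomTree(2)
	root.Fit(fitInsts[0])
	root.Prune(fitInsts[1]) // reduced-error pruning against the hold-out set

	// Evaluate on data untouched by both fitting and pruning
	predictions := root.Predict(insts[1])
	cf := eval.GetConfusionMatrix(insts[1], predictions)
	fmt.Println(eval.GetSummary(cf))
}

Keeping the pruning set disjoint from both the fitting set and the final test set matters: pruning against the fitting data would never remove a branch (every split looks helpful on the data that produced it), and pruning against the test set would leak information into the reported accuracy.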