From 12ace9def5ea0c74d5548c56d467e09c17b6b30f Mon Sep 17 00:00:00 2001 From: Richard Townsend Date: Sat, 17 May 2014 20:37:19 +0100 Subject: [PATCH] Identified source of the low accuracy --- ensemble/randomforest_test.go | 4 ++-- trees/entropy.go | 1 - trees/id3.go | 34 +++++++++++++++++++++++++++------- trees/random.go | 13 ++++++++++++- trees/tree_test.go | 26 +++++++++++++++++++++++++- 5 files changed, 66 insertions(+), 12 deletions(-) diff --git a/ensemble/randomforest_test.go b/ensemble/randomforest_test.go index 9ea8338..e82ec0b 100644 --- a/ensemble/randomforest_test.go +++ b/ensemble/randomforest_test.go @@ -13,13 +13,13 @@ func TestRandomForest1(testEnv *testing.T) { if err != nil { panic(err) } - insts := base.InstancesTrainTestSplit(inst, 0.6) + insts := base.InstancesTrainTestSplit(inst, 0.80) filt := filters.NewChiMergeFilter(insts[0], 0.90) filt.AddAllNumericAttributes() filt.Build() filt.Run(insts[1]) filt.Run(insts[0]) - rf := NewRandomForest(15, 2) + rf := NewRandomForest(10, 3) rf.Fit(insts[0]) predictions := rf.Predict(insts[1]) fmt.Println(predictions) diff --git a/trees/entropy.go b/trees/entropy.go index a2d4109..e7287fc 100644 --- a/trees/entropy.go +++ b/trees/entropy.go @@ -55,7 +55,6 @@ func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(considered } // Pick the one which maximises IG - return f.GetAttr(selectedAttribute) } diff --git a/trees/id3.go b/trees/id3.go index cd17e61..bfcc351 100644 --- a/trees/id3.go +++ b/trees/id3.go @@ -5,6 +5,7 @@ import ( "fmt" base "github.com/sjwhitworth/golearn/base" eval "github.com/sjwhitworth/golearn/evaluation" + "sort" ) // NodeType determines whether a DecisionTreeNode is a leaf or not @@ -66,9 +67,9 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode { } } - // If there are no more attribute left to split on, + // If there are no more Attributes left to split on, // return a DecisionTreeLeaf with the majority class - if from.GetAttributeCount() == 1 { + if from.GetAttributeCount() == 2 { ret := &DecisionTreeNode{ LeafNode, nil, @@ -82,7 +83,7 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode { ret := &DecisionTreeNode{ RuleNode, - make(map[string]*DecisionTreeNode), + nil, nil, classes, maxClass, @@ -92,9 +93,14 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode { // Generate a return structure // Generate the splitting attribute splitOnAttribute := with.GenerateSplitAttribute(from) + if splitOnAttribute == nil { + // Can't determine, just return what we have + return ret + } // Split the attributes based on this attribute's value splitInstances := from.DecomposeOnAttributeValues(splitOnAttribute) // Create new children from these attributes + ret.Children = make(map[string]*DecisionTreeNode) for k := range splitInstances { newInstances := splitInstances[k] ret.Children[k] = InferID3Tree(newInstances, with) @@ -114,7 +120,12 @@ func (d *DecisionTreeNode) getNestedString(level int) string { buf.WriteString(fmt.Sprintf("Leaf(%s)", d.Class)) } else { buf.WriteString(fmt.Sprintf("Rule(%s)", d.SplitAttr.GetName())) + keys := make([]string, 0) for k := range d.Children { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { buf.WriteString("\n") buf.WriteString(tmp.String()) buf.WriteString("\t") @@ -171,16 +182,25 @@ func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances { predictions := base.NewInstances(outputAttrs, what.Rows) for i := 0; i < what.Rows; i++ { cur := d - for j := 0; j < what.Cols; j++ { - at := what.GetAttr(j) - classVar := at.GetStringFromSysVal(what.Get(i, j)) + for { if cur.Children == nil { predictions.SetAttrStr(i, 0, cur.Class) + break } else { + at := cur.SplitAttr + j := what.GetAttrIndex(at) + classVar := at.GetStringFromSysVal(what.Get(i, j)) if next, ok := cur.Children[classVar]; ok { cur = next } else { - predictions.SetAttrStr(i, 0, cur.Class) + var bestChild string + for c := range cur.Children { + bestChild = c + if c > classVar { + break + } + } + cur = cur.Children[bestChild] } } } diff --git a/trees/random.go b/trees/random.go index cac0147..02697ca 100644 --- a/trees/random.go +++ b/trees/random.go @@ -20,11 +20,22 @@ func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base consideredAttributes := make([]int, r.Attributes) attrCounter := 0 for { - if attrCounter >= r.Attributes { + if len(consideredAttributes) >= r.Attributes { break } selectedAttribute := rand.Intn(maximumAttribute) + fmt.Println(selectedAttribute, attrCounter, consideredAttributes, len(consideredAttributes)) if selectedAttribute != f.ClassIndex { + matched := false + for _, a := range consideredAttributes { + if a == selectedAttribute { + matched = true + break + } + } + if matched { + continue + } consideredAttributes = append(consideredAttributes, selectedAttribute) attrCounter++ } diff --git a/trees/tree_test.go b/trees/tree_test.go index d8b7451..5cae22e 100644 --- a/trees/tree_test.go +++ b/trees/tree_test.go @@ -56,7 +56,7 @@ func TestRandomTreeClassification2(testEnv *testing.T) { if err != nil { panic(err) } - insts := base.InstancesTrainTestSplit(inst, 0.6) + insts := base.InstancesTrainTestSplit(inst, 0.4) filt := filters.NewChiMergeFilter(inst, 0.90) filt.AddAllNumericAttributes() filt.Build() @@ -170,3 +170,27 @@ func TestID3Inference(testEnv *testing.T) { testEnv.Error(overcastChild) } } + +func TestID3Classification(testEnv *testing.T) { + inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) + if err != nil { + panic(err) + } + filt := filters.NewBinningFilter(inst, 10) + filt.AddAllNumericAttributes() + filt.Build() + filt.Run(inst) + fmt.Println(inst) + insts := base.InstancesTrainTestSplit(inst, 0.70) + // Build the decision tree + rule := new(InformationGainRuleGenerator) + root := InferID3Tree(insts[0], rule) + fmt.Println(root) + predictions := root.Predict(insts[1]) + fmt.Println(predictions) + confusionMat := eval.GetConfusionMatrix(insts[1], predictions) + fmt.Println(confusionMat) + fmt.Println(eval.GetMacroPrecision(confusionMat)) + fmt.Println(eval.GetMacroRecall(confusionMat)) + fmt.Println(eval.GetSummary(confusionMat)) +}