From db3ac3c695929f68528c3b04f555727aa716d630 Mon Sep 17 00:00:00 2001
From: Richard Townsend
Date: Sat, 17 May 2014 17:28:51 +0100
Subject: [PATCH] ID3 algorithm working

---
 trees/entropy.go          | 101 ++++++++++++++++++++++++++++++++++++++
 trees/{tree.go => id3.go} |  44 +++++++++++------
 trees/random.go           |  59 ++--------------------
 trees/tennis.csv          |  15 ++++++
 trees/tree_test.go        |  57 ++++++++++++++++++++-
 5 files changed, 205 insertions(+), 71 deletions(-)
 create mode 100644 trees/entropy.go
 rename trees/{tree.go => id3.go} (77%)
 create mode 100644 trees/tennis.csv

diff --git a/trees/entropy.go b/trees/entropy.go
new file mode 100644
index 0000000..a2d4109
--- /dev/null
+++ b/trees/entropy.go
@@ -0,0 +1,101 @@
+package trees
+
+import (
+	base "github.com/sjwhitworth/golearn/base"
+	"math"
+)
+
+//
+// Information gain rule generator
+//
+
+type InformationGainRuleGenerator struct {
+}
+
+// GenerateSplitAttribute returns the non-class Attribute which maximises the
+// information gain.
+//
+// IMPORTANT: passing a base.Instances with no Attributes other than the class
+// variable will panic()
+func (r *InformationGainRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
+	allAttributes := make([]int, 0)
+	for i := 0; i < f.Cols; i++ {
+		if i != f.ClassIndex {
+			allAttributes = append(allAttributes, i)
+		}
+	}
+	return r.GetSplitAttributeFromSelection(allAttributes, f)
+}
+
+// GetSplitAttributeFromSelection returns the Attribute which maximises the
+// information gain amongst consideredAttributes
+//
+// IMPORTANT: passing a zero-length consideredAttributes parameter will panic()
+func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(consideredAttributes []int, f *base.Instances) base.Attribute {
+
+	// Compute the information gain at this node for each candidate
+	// attribute, and pick the one which maximises it
+	maxGain := math.Inf(-1)
+	selectedAttribute := -1
+
+	// Compute the base entropy
+	classDist := f.GetClassDistribution()
+	baseEntropy := getBaseEntropy(classDist)
+
+	// Compute the information gain for each attribute
+	for _, s := range consideredAttributes {
+		proposedClassDist := f.GetClassDistributionAfterSplit(f.GetAttr(s))
+		localEntropy := getSplitEntropy(proposedClassDist)
+		informationGain := baseEntropy - localEntropy
+		if informationGain > maxGain {
+			maxGain = informationGain
+			selectedAttribute = s
+		}
+	}
+
+	// Return the attribute with the highest information gain
+	return f.GetAttr(selectedAttribute)
+}
+
+//
+// Entropy functions
+//
+
+// getSplitEntropy determines the entropy of the target
+// class distribution after splitting on a base.Attribute
+func getSplitEntropy(s map[string]map[string]int) float64 {
+	ret := 0.0
+	count := 0
+	for a := range s {
+		for c := range s[a] {
+			count += s[a][c]
+		}
+	}
+	for a := range s {
+		total := 0.0
+		for c := range s[a] {
+			total += float64(s[a][c])
+		}
+		for c := range s[a] {
+			ret -= float64(s[a][c]) / float64(count) * math.Log(float64(s[a][c])/float64(count)) / math.Log(2)
+		}
+		ret += total / float64(count) * math.Log(total/float64(count)) / math.Log(2)
+	}
+	return ret
+}
+
+// getBaseEntropy determines the entropy of the target
+// class distribution before splitting on a base.Attribute
+func getBaseEntropy(s map[string]int) float64 {
+	ret := 0.0
+	count := 0
+	for k := range s {
+		count += s[k]
+	}
+	for k := range s {
+		ret -= float64(s[k]) / float64(count) * math.Log(float64(s[k])/float64(count)) / math.Log(2)
+	}
+	return ret
+}
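
As a cross-check on the two helpers above: getBaseEntropy computes the Shannon
entropy H(S) = -sum_c p(c) * log2 p(c) of the class distribution, and
getSplitEntropy computes the weighted post-split entropy
H(S|A) = sum_a (|S_a|/|S|) * H(S_a), so their difference is the information
gain. A hand check against the class counts in the tennis.csv file added later
in this patch (a hypothetical _test.go sketch, not part of the change itself):

	package trees

	import (
		"fmt"
		"testing"
	)

	func TestEntropyByHand(t *testing.T) {
		// 9 "yes" vs 5 "no" gives H(play) ≈ 0.940 bits
		h := getBaseEntropy(map[string]int{"yes": 9, "no": 5})

		// Class counts after splitting on "outlook"; zero counts must be
		// omitted, since 0*log(0) evaluates to NaN in this implementation
		hSplit := getSplitEntropy(map[string]map[string]int{
			"sunny":    {"yes": 2, "no": 3},
			"overcast": {"yes": 4},
			"rainy":    {"yes": 3, "no": 2},
		})

		// Gain(outlook) = 0.940 - 0.694 ≈ 0.247, the largest of any
		// attribute, which is why "outlook" ends up at the root
		fmt.Printf("H=%.3f H|outlook=%.3f gain=%.3f\n", h, hSplit, h-hSplit)
	}
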
diff --git a/trees/tree.go b/trees/id3.go
similarity index 77%
rename from trees/tree.go
rename to trees/id3.go
index ac0a7b4..6cbe662 100644
--- a/trees/tree.go
+++ b/trees/id3.go
@@ -1,9 +1,9 @@
 package trees
 
 import (
+	"bytes"
 	"fmt"
 	base "github.com/sjwhitworth/golearn/base"
-	"strings"
 )
 
 // NodeType determines whether a DecisionTreeNode is a leaf or not
@@ -32,9 +32,9 @@ type DecisionTreeNode struct {
 	ClassAttr *base.Attribute
 }
 
-// InferDecisionTree builds a decision tree using a RuleGenerator
-// from a set of Instances
-func InferDecisionTree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
+// InferID3Tree builds a decision tree using a RuleGenerator
+// from a set of Instances (implements the ID3 algorithm)
+func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
 	// Count the number of classes at this node
 	classes := from.CountClassValues()
 	// If there's only one class, return a DecisionTreeLeaf with
@@ -96,25 +96,39 @@ func InferDecisionTree(from *base.Instances, with RuleGenerator) *DecisionTreeNo
 	// Create new children from these attributes
 	for k := range splitInstances {
 		newInstances := splitInstances[k]
-		ret.Children[k] = InferDecisionTree(newInstances, with)
+		ret.Children[k] = InferID3Tree(newInstances, with)
 	}
 	ret.SplitAttr = splitOnAttribute
 	return ret
 }
 
+func (d *DecisionTreeNode) getNestedString(level int) string {
+	buf := bytes.NewBuffer(nil)
+	tmp := bytes.NewBuffer(nil)
+	for i := 0; i < level; i++ {
+		tmp.WriteString("\t")
+	}
+	buf.WriteString(tmp.String())
+	if d.Children == nil {
+		buf.WriteString(fmt.Sprintf("Leaf(%s)", d.Class))
+	} else {
+		buf.WriteString(fmt.Sprintf("Rule(%s)", d.SplitAttr.GetName()))
+		for k := range d.Children {
+			buf.WriteString("\n")
+			buf.WriteString(tmp.String())
+			buf.WriteString("\t")
+			buf.WriteString(k)
+			buf.WriteString("\n")
+			buf.WriteString(d.Children[k].getNestedString(level + 1))
+		}
+	}
+	return buf.String()
+}
+
 // String returns a human-readable representation of a given node
-// and its children
+// and its children
 func (d *DecisionTreeNode) String() string {
-	children := make([]string, 0)
-	if d.Children != nil {
-		for k := range d.Children {
-			childStr := fmt.Sprintf("Rule(%s -> %s)", k, d.Children[k])
-			children = append(children, childStr)
-		}
-		return fmt.Sprintf("(%s(%s))", d.SplitAttr, strings.Join(children, "\n\t"))
-	}
-
-	return fmt.Sprintf("Leaf(%s (%s))", d.Class, d.ClassDist)
+	return d.getNestedString(0)
 }
 
 func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances {
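
A minimal caller of the renamed entry point might look like the sketch below
(hypothetical main package; it assumes the tennis.csv added by this patch sits
in the working directory, and that the second ParseCSVToInstances argument
flags a header row):

	package main

	import (
		"fmt"

		base "github.com/sjwhitworth/golearn/base"
		"github.com/sjwhitworth/golearn/trees"
	)

	func main() {
		// Load the dataset (true: the CSV has a header row)
		inst, err := base.ParseCSVToInstances("tennis.csv", true)
		if err != nil {
			panic(err)
		}

		// Grow the full ID3 tree, choosing splits by information gain
		root := trees.InferID3Tree(inst, new(trees.InformationGainRuleGenerator))

		// String() renders the nested Rule(...)/Leaf(...) structure
		fmt.Println(root)
	}
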
diff --git a/trees/random.go b/trees/random.go
index 40097e5..0ab9539 100644
--- a/trees/random.go
+++ b/trees/random.go
@@ -3,47 +3,18 @@ package trees
 import (
 	"fmt"
 	base "github.com/sjwhitworth/golearn/base"
-	"math"
 	"math/rand"
 )
 
 type RandomTreeRuleGenerator struct {
-	Attributes int
-}
-
-func getSplitEntropy(s map[string]map[string]int) float64 {
-	ret := 0.0
-	count := 0
-	for a := range s {
-		total := 0.0
-		for c := range s[a] {
-			ret -= float64(s[a][c]) * math.Log(float64(s[a][c])) / math.Log(2)
-			total += float64(s[a][c])
-			count += s[a][c]
-		}
-		ret += total * math.Log(total) / math.Log(2)
-	}
-	return ret / float64(count)
-}
-
-func getBaseEntropy(s map[string]int) float64 {
-	ret := 0.0
-	count := 0
-	for k := range s {
-		count += s[k]
-	}
-	for k := range s {
-		ret -= float64(s[k]) / float64(count) * math.Log(float64(s[k])/float64(count)) / math.Log(2)
-	}
-	return ret
+	Attributes   int
+	internalRule InformationGainRuleGenerator
 }
 
 // So WEKA returns a couple of possible attributes and evaluates
 // the split criteria on each
 func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
 
-	fmt.Println("GenerateSplitAttribute", r.Attributes)
-
 	// First step is to generate the random attributes that we'll consider
 	maximumAttribute := f.GetAttributeCount()
 	consideredAttributes := make([]int, r.Attributes)
@@ -59,28 +30,7 @@ func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base
 		}
 	}
 
-	// Next step is to compute the information gain at this node
-	// for each randomly chosen attribute, and pick the one
-	// which maximises it
-	maxGain := math.Inf(-1)
-	selectedAttribute := -1
-
-	// Compute the base entropy
-	classDist := f.GetClassDistribution()
-	baseEntropy := getBaseEntropy(classDist)
-
-	for _, s := range consideredAttributes {
-		proposedClassDist := f.GetClassDistributionAfterSplit(f.GetAttr(s))
-		localEntropy := getSplitEntropy(proposedClassDist)
-		informationGain := baseEntropy - localEntropy
-		if informationGain > maxGain {
-			maxGain = localEntropy
-			selectedAttribute = s
-			fmt.Printf("Gain: %.4f, selectedAttribute: %s\n", informationGain, f.GetAttr(selectedAttribute))
-		}
-	}
-
-	return f.GetAttr(selectedAttribute)
+	return r.internalRule.GetSplitAttributeFromSelection(consideredAttributes, f)
 }
 
 type RandomTree struct {
@@ -95,13 +45,14 @@ func NewRandomTree(attrs int) *RandomTree {
 		nil,
 		RandomTreeRuleGenerator{
 			attrs,
+			InformationGainRuleGenerator{},
 		},
 	}
 }
 
 // Fit builds a RandomTree suitable for prediction
 func (rt *RandomTree) Fit(from *base.Instances) {
-	rt.Root = InferDecisionTree(from, &rt.Rule)
+	rt.Root = InferID3Tree(from, &rt.Rule)
 }
 
 // Predict returns a set of Instances containing predictions
diff --git a/trees/tennis.csv b/trees/tennis.csv
new file mode 100644
index 0000000..f83b6b8
--- /dev/null
+++ b/trees/tennis.csv
@@ -0,0 +1,15 @@
+outlook,temp,humidity,windy,play
+sunny,hot,high,false,no
+sunny,hot,high,true,no
+overcast,hot,high,false,yes
+rainy,mild,high,false,yes
+rainy,cool,normal,false,yes
+rainy,cool,normal,true,no
+overcast,cool,normal,true,yes
+sunny,mild,high,false,no
+sunny,cool,normal,false,yes
+rainy,mild,normal,false,yes
+sunny,mild,normal,true,yes
+overcast,mild,high,true,yes
+overcast,hot,normal,false,yes
+rainy,mild,high,true,no
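
For reference, ID3 with information gain on this dataset recovers Quinlan's
classic PlayTennis tree, which the new test below verifies node by node. In
the getNestedString format introduced above it renders roughly as follows
(child order varies, since Go randomises map iteration):

	Rule(outlook)
		sunny
		Rule(humidity)
			high
			Leaf(no)
			normal
			Leaf(yes)
		overcast
		Leaf(yes)
		rainy
		Rule(windy)
			true
			Leaf(no)
			false
			Leaf(yes)
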
diff --git a/trees/tree_test.go b/trees/tree_test.go
index a743151..e12495b 100644
--- a/trees/tree_test.go
+++ b/trees/tree_test.go
@@ -22,7 +22,7 @@ func TestRandomTree(testEnv *testing.T) {
 	fmt.Println(inst)
 	r := new(RandomTreeRuleGenerator)
 	r.Attributes = 2
-	root := InferDecisionTree(inst, r)
+	root := InferID3Tree(inst, r)
 	fmt.Println(root)
 }
 
@@ -40,7 +40,7 @@ func TestRandomTreeClassification(testEnv *testing.T) {
 	fmt.Println(inst)
 	r := new(RandomTreeRuleGenerator)
 	r.Attributes = 2
-	root := InferDecisionTree(insts[0], r)
+	root := InferID3Tree(insts[0], r)
 	fmt.Println(root)
 	predictions := root.Predict(insts[1])
 	fmt.Println(predictions)
@@ -91,3 +91,56 @@ func TestInformationGain(testEnv *testing.T) {
 		testEnv.Error(entropy)
 	}
 }
+
+func TestID3Inference(testEnv *testing.T) {
+
+	// Import the "PlayTennis" dataset
+	inst, err := base.ParseCSVToInstances("./tennis.csv", true)
+	if err != nil {
+		panic(err)
+	}
+
+	// Build the decision tree
+	rule := new(InformationGainRuleGenerator)
+	root := InferID3Tree(inst, rule)
+
+	// Verify the tree
+	// First attribute should be "outlook"
+	if root.SplitAttr.GetName() != "outlook" {
+		testEnv.Error(root)
+	}
+	sunnyChild := root.Children["sunny"]
+	overcastChild := root.Children["overcast"]
+	rainyChild := root.Children["rainy"]
+	if sunnyChild.SplitAttr.GetName() != "humidity" {
+		testEnv.Error(sunnyChild)
+	}
+	if rainyChild.SplitAttr.GetName() != "windy" {
+		testEnv.Error(rainyChild)
+	}
+	if overcastChild.SplitAttr != nil {
+		testEnv.Error(overcastChild)
+	}
+
+	sunnyLeafHigh := sunnyChild.Children["high"]
+	sunnyLeafNormal := sunnyChild.Children["normal"]
+	if sunnyLeafHigh.Class != "no" {
+		testEnv.Error(sunnyLeafHigh)
+	}
+	if sunnyLeafNormal.Class != "yes" {
+		testEnv.Error(sunnyLeafNormal)
+	}
+
+	windyLeafFalse := rainyChild.Children["false"]
+	windyLeafTrue := rainyChild.Children["true"]
+	if windyLeafFalse.Class != "yes" {
+		testEnv.Error(windyLeafFalse)
+	}
+	if windyLeafTrue.Class != "no" {
+		testEnv.Error(windyLeafTrue)
+	}
+
+	if overcastChild.Class != "yes" {
+		testEnv.Error(overcastChild)
+	}
+}
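
With the refactor, the randomised learner reuses the same machinery; a sketch
of the end-to-end flow, assuming trainInst and testInst are a pre-split pair
of *base.Instances loaded as in the earlier example:

	// Consider 2 randomly chosen attributes at each split, WEKA-style
	rt := trees.NewRandomTree(2)
	rt.Fit(trainInst)

	// The fitted tree hangs off rt.Root, a *DecisionTreeNode, so
	// printing and prediction work exactly as for the plain ID3 tree
	fmt.Println(rt.Root)
	predictions := rt.Root.Predict(testInst)
	fmt.Println(predictions)

go test ./trees exercises the new TestID3Inference alongside the existing
random-tree tests.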