
ID3 algorithm working

This commit is contained in:
Richard Townsend 2014-05-17 17:28:51 +01:00
parent cf165695c8
commit db3ac3c695
5 changed files with 205 additions and 71 deletions

trees/entropy.go Normal file

@@ -0,0 +1,101 @@
package trees
import (
base "github.com/sjwhitworth/golearn/base"
"math"
)
//
// Information gain rule generator
//
type InformationGainRuleGenerator struct {
}
// GenerateSplitAttribute returns the non-class Attribute which maximises the
// information gain.
//
// IMPORTANT: passing a base.Instances with no Attributes other than the class
// variable will panic()
func (r *InformationGainRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
allAttributes := make([]int, 0)
for i := 0; i < f.Cols; i++ {
if i != f.ClassIndex {
allAttributes = append(allAttributes, i)
}
}
return r.GetSplitAttributeFromSelection(allAttributes, f)
}
// GetSplitAttributeFromSelection returns the Attribute amongst
// consideredAttributes which maximises the information gain
//
// IMPORTANT: passing a zero-length consideredAttributes parameter will panic()
func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(consideredAttributes []int, f *base.Instances) base.Attribute {
// Compute the information gain at this node for each
// considered attribute, and pick the one
// which maximises it
maxGain := math.Inf(-1)
selectedAttribute := -1
// Compute the base entropy
classDist := f.GetClassDistribution()
baseEntropy := getBaseEntropy(classDist)
// Compute the information gain for each attribute
for _, s := range consideredAttributes {
proposedClassDist := f.GetClassDistributionAfterSplit(f.GetAttr(s))
localEntropy := getSplitEntropy(proposedClassDist)
informationGain := baseEntropy - localEntropy
if informationGain > maxGain {
maxGain = informationGain
selectedAttribute = s
}
}
// Pick the one which maximises IG
return f.GetAttr(selectedAttribute)
}
//
// Entropy functions
//
// getSplitEntropy determines the entropy of the target
// class distribution after splitting on a base.Attribute
func getSplitEntropy(s map[string]map[string]int) float64 {
ret := 0.0
count := 0
for a := range s {
for c := range s[a] {
count += s[a][c]
}
}
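// Each branch contributes -sum_c p_ac*log2(p_ac) + p_a*log2(p_a)
// (with p measured against the grand total), which sums to the
// weighted average of the per-branch class entropies.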
for a := range s {
total := 0.0
for c := range s[a] {
total += float64(s[a][c])
}
for c := range s[a] {
ret -= float64(s[a][c]) / float64(count) * math.Log(float64(s[a][c])/float64(count)) / math.Log(2)
}
ret += total / float64(count) * math.Log(total/float64(count)) / math.Log(2)
}
return ret
}
// getBaseEntropy determines the entropy of the target
// class distribution before splitting on a base.Attribute
func getBaseEntropy(s map[string]int) float64 {
ret := 0.0
count := 0
for k := range s {
count += s[k]
}
for k := range s {
ret -= float64(s[k]) / float64(count) * math.Log(float64(s[k])/float64(count)) / math.Log(2)
}
return ret
}
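For intuition, here is a minimal, self-contained sketch (illustrative only, independent of the golearn API) that reproduces the same arithmetic by hand for the bundled tennis.csv: splitting on "outlook" should win, with a gain of roughly 0.247 bits.

package main

import (
	"fmt"
	"math"
)

func log2(x float64) float64 {
	return math.Log(x) / math.Log(2)
}

func main() {
	// "play" column of tennis.csv: 9 yes, 5 no (14 rows)
	classDist := map[string]int{"yes": 9, "no": 5}
	// Class distribution after splitting on "outlook" (zero counts omitted)
	splitDist := map[string]map[string]int{
		"sunny":    {"yes": 2, "no": 3},
		"overcast": {"yes": 4},
		"rainy":    {"yes": 3, "no": 2},
	}

	total := 0
	for _, n := range classDist {
		total += n
	}

	// Base entropy: H(play) = -sum p*log2(p) ~ 0.940
	baseEntropy := 0.0
	for _, n := range classDist {
		p := float64(n) / float64(total)
		baseEntropy -= p * log2(p)
	}

	// Split entropy: weighted average of per-branch entropies ~ 0.694
	splitEntropy := 0.0
	for _, branch := range splitDist {
		branchTotal := 0
		for _, n := range branch {
			branchTotal += n
		}
		for _, n := range branch {
			p := float64(n) / float64(branchTotal)
			splitEntropy -= float64(branchTotal) / float64(total) * p * log2(p)
		}
	}

	// Information gain ~ 0.247, the largest of any attribute,
	// so ID3 splits on "outlook" first.
	fmt.Printf("base=%.3f split=%.3f gain=%.3f\n",
		baseEntropy, splitEntropy, baseEntropy-splitEntropy)
}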


@@ -1,9 +1,9 @@
package trees
import (
"bytes"
"fmt"
base "github.com/sjwhitworth/golearn/base"
"strings"
)
// NodeType determines whether a DecisionTreeNode is a leaf or not
@@ -32,9 +32,9 @@ type DecisionTreeNode struct {
ClassAttr *base.Attribute
}
// InferDecisionTree builds a decision tree using a RuleGenerator
// from a set of Instances
func InferDecisionTree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
// InferID3Tree builds a decision tree using a RuleGenerator
// from a set of Instances (implements the ID3 algorithm)
func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
// Count the number of classes at this node
classes := from.CountClassValues()
// If there's only one class, return a DecisionTreeLeaf with
@@ -96,25 +96,39 @@ func InferDecisionTree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
// Create new children from these attributes
for k := range splitInstances {
newInstances := splitInstances[k]
ret.Children[k] = InferDecisionTree(newInstances, with)
ret.Children[k] = InferID3Tree(newInstances, with)
}
ret.SplitAttr = splitOnAttribute
return ret
}
func (d *DecisionTreeNode) getNestedString(level int) string {
buf := bytes.NewBuffer(nil)
tmp := bytes.NewBuffer(nil)
for i := 0; i < level; i++ {
tmp.WriteString("\t")
}
buf.WriteString(tmp.String())
if d.Children == nil {
buf.WriteString(fmt.Sprintf("Leaf(%s)", d.Class))
} else {
buf.WriteString(fmt.Sprintf("Rule(%s)", d.SplitAttr.GetName()))
for k := range d.Children {
buf.WriteString("\n")
buf.WriteString(tmp.String())
buf.WriteString("\t")
buf.WriteString(k)
buf.WriteString("\n")
buf.WriteString(d.Children[k].getNestedString(level + 1))
}
}
return buf.String()
}
// String returns a human-readable representation of a given node
// and its children
func (d *DecisionTreeNode) String() string {
children := make([]string, 0)
if d.Children != nil {
for k := range d.Children {
childStr := fmt.Sprintf("Rule(%s -> %s)", k, d.Children[k])
children = append(children, childStr)
}
return fmt.Sprintf("(%s(%s))", d.SplitAttr, strings.Join(children, "\n\t"))
}
return fmt.Sprintf("Leaf(%s (%s))", d.Class, d.ClassDist)
return d.getNestedString(0)
}
func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances {

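For reference, getNestedString renders the tree inferred from tennis.csv roughly as follows (an illustrative sketch derived from the code above; the branch order depends on Go's randomised map iteration):

Rule(outlook)
	overcast
	Leaf(yes)
	rainy
	Rule(windy)
		false
		Leaf(yes)
		true
		Leaf(no)
	sunny
	Rule(humidity)
		high
		Leaf(no)
		normal
		Leaf(yes)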

@@ -3,47 +3,18 @@ package trees
import (
"fmt"
base "github.com/sjwhitworth/golearn/base"
"math"
"math/rand"
)
type RandomTreeRuleGenerator struct {
Attributes int
}
func getSplitEntropy(s map[string]map[string]int) float64 {
ret := 0.0
count := 0
for a := range s {
total := 0.0
for c := range s[a] {
ret -= float64(s[a][c]) * math.Log(float64(s[a][c])) / math.Log(2)
total += float64(s[a][c])
count += s[a][c]
}
ret += total * math.Log(total) / math.Log(2)
}
return ret / float64(count)
}
func getBaseEntropy(s map[string]int) float64 {
ret := 0.0
count := 0
for k := range s {
count += s[k]
}
for k := range s {
ret -= float64(s[k]) / float64(count) * math.Log(float64(s[k])/float64(count)) / math.Log(2)
}
return ret
Attributes int
internalRule InformationGainRuleGenerator
}
// Following WEKA's RandomTree, GenerateSplitAttribute picks a random
// subset of the available attributes and evaluates the split criterion
// on each of them
func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
fmt.Println("GenerateSplitAttribute", r.Attributes)
// First step is to generate the random attributes that we'll consider
maximumAttribute := f.GetAttributeCount()
consideredAttributes := make([]int, r.Attributes)
@@ -59,28 +30,7 @@ func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
}
}
// Next step is to compute the information gain at this node
// for each randomly chosen attribute, and pick the one
// which maximises it
maxGain := math.Inf(-1)
selectedAttribute := -1
// Compute the base entropy
classDist := f.GetClassDistribution()
baseEntropy := getBaseEntropy(classDist)
for _, s := range consideredAttributes {
proposedClassDist := f.GetClassDistributionAfterSplit(f.GetAttr(s))
localEntropy := getSplitEntropy(proposedClassDist)
informationGain := baseEntropy - localEntropy
if informationGain > maxGain {
maxGain = localEntropy
selectedAttribute = s
fmt.Printf("Gain: %.4f, selectedAttribute: %s\n", informationGain, f.GetAttr(selectedAttribute))
}
}
return f.GetAttr(selectedAttribute)
return r.internalRule.GetSplitAttributeFromSelection(consideredAttributes, f)
}
type RandomTree struct {
@@ -95,13 +45,14 @@ func NewRandomTree(attrs int) *RandomTree {
nil,
RandomTreeRuleGenerator{
attrs,
InformationGainRuleGenerator{},
},
}
}
// Fit builds a RandomTree suitable for prediction
func (rt *RandomTree) Fit(from *base.Instances) {
rt.Root = InferDecisionTree(from, &rt.Rule)
rt.Root = InferID3Tree(from, &rt.Rule)
}
// Predict returns a set of Instances containing predictions
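Taken together, the pieces added in this commit can be exercised like this — a minimal sketch based only on the signatures visible in this diff, with import paths assumed from the package layout:

package main

import (
	"fmt"

	base "github.com/sjwhitworth/golearn/base"
	"github.com/sjwhitworth/golearn/trees"
)

func main() {
	// Load the bundled dataset (true = first row is a header)
	inst, err := base.ParseCSVToInstances("trees/tennis.csv", true)
	if err != nil {
		panic(err)
	}

	// Plain ID3: consider every non-class attribute at each split
	rule := new(trees.InformationGainRuleGenerator)
	root := trees.InferID3Tree(inst, rule)
	fmt.Println(root)

	// RandomTree: consider only 2 randomly chosen attributes per split
	rt := trees.NewRandomTree(2)
	rt.Fit(inst)
	fmt.Println(rt.Root)
}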

trees/tennis.csv Normal file

@@ -0,0 +1,15 @@
outlook,temp,humidity,windy,play
sunny,hot,high,false,no
sunny,hot,high,true,no
overcast,hot,high,false,yes
rainy,mild,high,false,yes
rainy,cool,normal,false,yes
rainy,cool,normal,true,no
overcast,cool,normal,true,yes
sunny,mild,high,false,no
sunny,cool,normal,false,yes
rainy,mild,normal,false,yes
sunny,mild,normal,true,yes
overcast,mild,high,true,yes
overcast,hot,normal,false,yes
rainy,mild,high,true,no


@@ -22,7 +22,7 @@ func TestRandomTree(testEnv *testing.T) {
fmt.Println(inst)
r := new(RandomTreeRuleGenerator)
r.Attributes = 2
root := InferDecisionTree(inst, r)
root := InferID3Tree(inst, r)
fmt.Println(root)
}
@@ -40,7 +40,7 @@ func TestRandomTreeClassification(testEnv *testing.T) {
fmt.Println(inst)
r := new(RandomTreeRuleGenerator)
r.Attributes = 2
root := InferDecisionTree(insts[0], r)
root := InferID3Tree(insts[0], r)
fmt.Println(root)
predictions := root.Predict(insts[1])
fmt.Println(predictions)
@@ -91,3 +91,56 @@ func TestInformationGain(testEnv *testing.T) {
testEnv.Error(entropy)
}
}
func TestID3Inference(testEnv *testing.T) {
// Import the "PlayTennis" dataset
inst, err := base.ParseCSVToInstances("./tennis.csv", true)
if err != nil {
panic(err)
}
// Build the decision tree
rule := new(InformationGainRuleGenerator)
root := InferID3Tree(inst, rule)
// Verify the tree
// First attribute should be "outlook"
if root.SplitAttr.GetName() != "outlook" {
testEnv.Error(root)
}
sunnyChild := root.Children["sunny"]
overcastChild := root.Children["overcast"]
rainyChild := root.Children["rainy"]
if sunnyChild.SplitAttr.GetName() != "humidity" {
testEnv.Error(sunnyChild)
}
if rainyChild.SplitAttr.GetName() != "windy" {
testEnv.Error(rainyChild)
}
if overcastChild.SplitAttr != nil {
testEnv.Error(overcastChild)
}
sunnyLeafHigh := sunnyChild.Children["high"]
sunnyLeafNormal := sunnyChild.Children["normal"]
if sunnyLeafHigh.Class != "no" {
testEnv.Error(sunnyLeafHigh)
}
if sunnyLeafNormal.Class != "yes" {
testEnv.Error(sunnyLeafNormal)
}
windyLeafFalse := rainyChild.Children["false"]
windyLeafTrue := rainyChild.Children["true"]
if windyLeafFalse.Class != "yes" {
testEnv.Error(windyLeafFalse)
}
if windyLeafTrue.Class != "no" {
testEnv.Error(windyLeafTrue)
}
if overcastChild.Class != "yes" {
testEnv.Error(overcastChild)
}
}