ID3 algorithm working

Commit db3ac3c695 (parent cf165695c8)
Mirror of https://github.com/sjwhitworth/golearn.git

trees/entropy.go (new file, 101 lines)
@@ -0,0 +1,101 @@
package trees

import (
	base "github.com/sjwhitworth/golearn/base"
	"math"
)

//
// Information gain rule generator
//

type InformationGainRuleGenerator struct {
}

// GenerateSplitAttribute returns the non-class Attribute which maximises the
// information gain.
//
// IMPORTANT: passing a base.Instances with no Attributes other than the class
// variable will panic()
func (r *InformationGainRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
	allAttributes := make([]int, 0)
	for i := 0; i < f.Cols; i++ {
		if i != f.ClassIndex {
			allAttributes = append(allAttributes, i)
		}
	}
	return r.GetSplitAttributeFromSelection(allAttributes, f)
}

// GetSplitAttributeFromSelection returns the non-class Attribute which
// maximises the information gain amongst consideredAttributes
//
// IMPORTANT: passing a zero-length consideredAttributes parameter will panic()
func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(consideredAttributes []int, f *base.Instances) base.Attribute {

	// Compute the information gain at this node for each candidate
	// attribute, and pick the one which maximises it
	maxGain := math.Inf(-1)
	selectedAttribute := -1

	// Compute the base entropy
	classDist := f.GetClassDistribution()
	baseEntropy := getBaseEntropy(classDist)

	// Compute the information gain for each attribute
	for _, s := range consideredAttributes {
		proposedClassDist := f.GetClassDistributionAfterSplit(f.GetAttr(s))
		localEntropy := getSplitEntropy(proposedClassDist)
		informationGain := baseEntropy - localEntropy
		if informationGain > maxGain {
			maxGain = informationGain
			selectedAttribute = s
		}
	}

	// Return the attribute which maximises the information gain
	return f.GetAttr(selectedAttribute)
}

//
// Entropy functions
//

// getSplitEntropy determines the entropy of the target
// class distribution after splitting on a base.Attribute,
// i.e. the weighted sum of per-branch class entropies
func getSplitEntropy(s map[string]map[string]int) float64 {
	ret := 0.0
	count := 0
	for a := range s {
		for c := range s[a] {
			count += s[a][c]
		}
	}
	for a := range s {
		total := 0.0
		for c := range s[a] {
			total += float64(s[a][c])
		}
		for c := range s[a] {
			ret -= float64(s[a][c]) / float64(count) * math.Log(float64(s[a][c])/float64(count)) / math.Log(2)
		}
		ret += total / float64(count) * math.Log(total/float64(count)) / math.Log(2)
	}
	return ret
}

// getBaseEntropy determines the entropy of the target
// class distribution before splitting on a base.Attribute
func getBaseEntropy(s map[string]int) float64 {
	ret := 0.0
	count := 0
	for k := range s {
		count += s[k]
	}
	for k := range s {
		ret -= float64(s[k]) / float64(count) * math.Log(float64(s[k])/float64(count)) / math.Log(2)
	}
	return ret
}
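As a sanity check of the two helpers, here is a minimal sketch (not part of the commit; it assumes placement in the trees package so the unexported functions are reachable, and entropyDemo is a hypothetical name) that reproduces the textbook figures for the PlayTennis dataset added below:

package trees

import "fmt"

// entropyDemo reproduces the classic PlayTennis numbers: H(class) of
// about 0.940 bits at the root, a conditional entropy of about 0.694
// bits after splitting on "outlook", so a gain of about 0.247 bits.
func entropyDemo() {
	// 9 "yes" and 5 "no" instances at the root
	h0 := getBaseEntropy(map[string]int{"yes": 9, "no": 5})

	// Class counts within each value of "outlook"
	hSplit := getSplitEntropy(map[string]map[string]int{
		"sunny":    {"yes": 2, "no": 3},
		"overcast": {"yes": 4},
		"rainy":    {"yes": 3, "no": 2},
	})

	fmt.Printf("base=%.3f split=%.3f gain=%.3f\n", h0, hSplit, h0-hSplit)
	// Prints: base=0.940 split=0.694 gain=0.247
}

That gain is the largest of the four attributes, which is why "outlook" ends up at the root of the tree the new test below expects.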
(changes to the decision-tree implementation in the trees package)
@@ -1,9 +1,9 @@
 package trees
 
 import (
+	"bytes"
 	"fmt"
 	base "github.com/sjwhitworth/golearn/base"
 	"strings"
 )
 
 // NodeType determines whether a DecisionTreeNode is a leaf or not
@@ -32,9 +32,9 @@ type DecisionTreeNode struct {
 	ClassAttr *base.Attribute
 }
 
-// InferDecisionTree builds a decision tree using a RuleGenerator
-// from a set of Instances
-func InferDecisionTree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
+// InferID3Tree builds a decision tree using a RuleGenerator
+// from a set of Instances (implements the ID3 algorithm)
+func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
 	// Count the number of classes at this node
 	classes := from.CountClassValues()
 	// If there's only one class, return a DecisionTreeLeaf with
@@ -96,25 +96,39 @@ func InferDecisionTree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
 	// Create new children from these attributes
 	for k := range splitInstances {
 		newInstances := splitInstances[k]
-		ret.Children[k] = InferDecisionTree(newInstances, with)
+		ret.Children[k] = InferID3Tree(newInstances, with)
 	}
 	ret.SplitAttr = splitOnAttribute
 	return ret
 }
 
+func (d *DecisionTreeNode) getNestedString(level int) string {
+	buf := bytes.NewBuffer(nil)
+	tmp := bytes.NewBuffer(nil)
+	for i := 0; i < level; i++ {
+		tmp.WriteString("\t")
+	}
+	buf.WriteString(tmp.String())
+	if d.Children == nil {
+		buf.WriteString(fmt.Sprintf("Leaf(%s)", d.Class))
+	} else {
+		buf.WriteString(fmt.Sprintf("Rule(%s)", d.SplitAttr.GetName()))
+		for k := range d.Children {
+			buf.WriteString("\n")
+			buf.WriteString(tmp.String())
+			buf.WriteString("\t")
+			buf.WriteString(k)
+			buf.WriteString("\n")
+			buf.WriteString(d.Children[k].getNestedString(level + 1))
+		}
+	}
+	return buf.String()
+}
+
 // String returns a human-readable representation of a given node
 // and its children
 func (d *DecisionTreeNode) String() string {
-	children := make([]string, 0)
-	if d.Children != nil {
-		for k := range d.Children {
-			childStr := fmt.Sprintf("Rule(%s -> %s)", k, d.Children[k])
-			children = append(children, childStr)
-		}
-		return fmt.Sprintf("(%s(%s))", d.SplitAttr, strings.Join(children, "\n\t"))
-	}
-	return fmt.Sprintf("Leaf(%s (%s))", d.Class, d.ClassDist)
+	return d.getNestedString(0)
 }
 
 func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances {
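For illustration (not output captured in the commit), getNestedString renders the PlayTennis tree along these lines; Go randomises map iteration order, so siblings may print in a different order from run to run:

Rule(outlook)
	sunny
	Rule(humidity)
		high
		Leaf(no)
		normal
		Leaf(yes)
	overcast
	Leaf(yes)
	rainy
	Rule(windy)
		true
		Leaf(no)
		false
		Leaf(yes)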
(changes to the random-tree implementation)
@@ -3,47 +3,18 @@ package trees
 import (
 	"fmt"
 	base "github.com/sjwhitworth/golearn/base"
-	"math"
 	"math/rand"
 )
 
 type RandomTreeRuleGenerator struct {
-	Attributes int
-}
-
-func getSplitEntropy(s map[string]map[string]int) float64 {
-	ret := 0.0
-	count := 0
-	for a := range s {
-		total := 0.0
-		for c := range s[a] {
-			ret -= float64(s[a][c]) * math.Log(float64(s[a][c])) / math.Log(2)
-			total += float64(s[a][c])
-			count += s[a][c]
-		}
-		ret += total * math.Log(total) / math.Log(2)
-	}
-	return ret / float64(count)
-}
-
-func getBaseEntropy(s map[string]int) float64 {
-	ret := 0.0
-	count := 0
-	for k := range s {
-		count += s[k]
-	}
-	for k := range s {
-		ret -= float64(s[k]) / float64(count) * math.Log(float64(s[k])/float64(count)) / math.Log(2)
-	}
-	return ret
+	Attributes   int
+	internalRule InformationGainRuleGenerator
 }
 
 // So WEKA returns a couple of possible attributes and evaluates
 // the split criteria on each
 func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
 
 	fmt.Println("GenerateSplitAttribute", r.Attributes)
 
 	// First step is to generate the random attributes that we'll consider
 	maximumAttribute := f.GetAttributeCount()
 	consideredAttributes := make([]int, r.Attributes)
@@ -59,28 +30,7 @@ func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
 		}
 	}
 
-	// Next step is to compute the information gain at this node
-	// for each randomly chosen attribute, and pick the one
-	// which maximises it
-	maxGain := math.Inf(-1)
-	selectedAttribute := -1
-
-	// Compute the base entropy
-	classDist := f.GetClassDistribution()
-	baseEntropy := getBaseEntropy(classDist)
-
-	for _, s := range consideredAttributes {
-		proposedClassDist := f.GetClassDistributionAfterSplit(f.GetAttr(s))
-		localEntropy := getSplitEntropy(proposedClassDist)
-		informationGain := baseEntropy - localEntropy
-		if informationGain > maxGain {
-			maxGain = localEntropy
-			selectedAttribute = s
-			fmt.Printf("Gain: %.4f, selectedAttribute: %s\n", informationGain, f.GetAttr(selectedAttribute))
-		}
-	}
-
-	return f.GetAttr(selectedAttribute)
+	return r.internalRule.GetSplitAttributeFromSelection(consideredAttributes, f)
 }
 
 type RandomTree struct {
@@ -95,13 +45,14 @@ func NewRandomTree(attrs int) *RandomTree {
 		nil,
 		RandomTreeRuleGenerator{
 			attrs,
+			InformationGainRuleGenerator{},
 		},
 	}
 }
 
 // Fit builds a RandomTree suitable for prediction
 func (rt *RandomTree) Fit(from *base.Instances) {
-	rt.Root = InferDecisionTree(from, &rt.Rule)
+	rt.Root = InferID3Tree(from, &rt.Rule)
 }
 
 // Predict returns a set of Instances containing predictions
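Two things are worth noting about this refactor. First, the removed loop carried a bug: on an improving split it assigned maxGain = localEntropy rather than the information gain, which the delegation to InformationGainRuleGenerator.GetSplitAttributeFromSelection sidesteps. Second, the random and deterministic generators now compose cleanly; a rough usage sketch (assumed to sit in the trees package, with the data variables prepared elsewhere):

// Hypothetical training-and-prediction flow; trainData and testData
// are assumed to be *base.Instances built elsewhere.
func randomTreeDemo(trainData, testData *base.Instances) *base.Instances {
	rt := NewRandomTree(2) // score 2 randomly chosen attributes per split
	rt.Fit(trainData)      // internally calls InferID3Tree(from, &rt.Rule)
	return rt.Root.Predict(testData)
}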
trees/tennis.csv (new file, 15 lines)
@@ -0,0 +1,15 @@
outlook,temp,humidity,windy,play
sunny,hot,high,false,no
sunny,hot,high,true,no
overcast,hot,high,false,yes
rainy,mild,high,false,yes
rainy,cool,normal,false,yes
rainy,cool,normal,true,no
overcast,cool,normal,true,yes
sunny,mild,high,false,no
sunny,cool,normal,false,yes
rainy,mild,normal,false,yes
sunny,mild,normal,true,yes
overcast,mild,high,true,yes
overcast,hot,normal,false,yes
rainy,mild,high,true,no
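Putting the pieces together, a minimal end-to-end sketch (not part of the commit; the file path and main-package wrapping are assumptions) that grows and prints an ID3 tree from this dataset:

package main

import (
	"fmt"

	base "github.com/sjwhitworth/golearn/base"
	"github.com/sjwhitworth/golearn/trees"
)

func main() {
	// Parse the CSV, treating the first row as attribute names
	inst, err := base.ParseCSVToInstances("trees/tennis.csv", true)
	if err != nil {
		panic(err)
	}

	// Grow an ID3 tree using plain information gain
	rule := new(trees.InformationGainRuleGenerator)
	root := trees.InferID3Tree(inst, rule)

	// fmt picks up the String() method defined on DecisionTreeNode
	fmt.Println(root)
}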
(changes to the trees test file)
@@ -22,7 +22,7 @@ func TestRandomTree(testEnv *testing.T) {
 	fmt.Println(inst)
 	r := new(RandomTreeRuleGenerator)
 	r.Attributes = 2
-	root := InferDecisionTree(inst, r)
+	root := InferID3Tree(inst, r)
 	fmt.Println(root)
 }
 
@@ -40,7 +40,7 @@ func TestRandomTreeClassification(testEnv *testing.T) {
 	fmt.Println(inst)
 	r := new(RandomTreeRuleGenerator)
 	r.Attributes = 2
-	root := InferDecisionTree(insts[0], r)
+	root := InferID3Tree(insts[0], r)
 	fmt.Println(root)
 	predictions := root.Predict(insts[1])
 	fmt.Println(predictions)
@@ -91,3 +91,56 @@ func TestInformationGain(testEnv *testing.T) {
 		testEnv.Error(entropy)
 	}
 }
+
+func TestID3Inference(testEnv *testing.T) {
+
+	// Import the "PlayTennis" dataset
+	inst, err := base.ParseCSVToInstances("./tennis.csv", true)
+	if err != nil {
+		panic(err)
+	}
+
+	// Build the decision tree
+	rule := new(InformationGainRuleGenerator)
+	root := InferID3Tree(inst, rule)
+
+	// Verify the tree
+	// First attribute should be "outlook"
+	if root.SplitAttr.GetName() != "outlook" {
+		testEnv.Error(root)
+	}
+	sunnyChild := root.Children["sunny"]
+	overcastChild := root.Children["overcast"]
+	rainyChild := root.Children["rainy"]
+	if sunnyChild.SplitAttr.GetName() != "humidity" {
+		testEnv.Error(sunnyChild)
+	}
+	if rainyChild.SplitAttr.GetName() != "windy" {
+		testEnv.Error(rainyChild)
+	}
+	if overcastChild.SplitAttr != nil {
+		testEnv.Error(overcastChild)
+	}
+
+	sunnyLeafHigh := sunnyChild.Children["high"]
+	sunnyLeafNormal := sunnyChild.Children["normal"]
+	if sunnyLeafHigh.Class != "no" {
+		testEnv.Error(sunnyLeafHigh)
+	}
+	if sunnyLeafNormal.Class != "yes" {
+		testEnv.Error(sunnyLeafNormal)
+	}
+
+	windyLeafFalse := rainyChild.Children["false"]
+	windyLeafTrue := rainyChild.Children["true"]
+	if windyLeafFalse.Class != "yes" {
+		testEnv.Error(windyLeafFalse)
+	}
+	if windyLeafTrue.Class != "no" {
+		testEnv.Error(windyLeafTrue)
+	}
+
+	if overcastChild.Class != "yes" {
+		testEnv.Error(overcastChild)
+	}
+}
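The new test can be replayed with the standard toolchain (assuming the repository sits on the GOPATH):

go test -run TestID3Inference github.com/sjwhitworth/golearn/trees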