mirror of https://github.com/sjwhitworth/golearn.git (synced 2025-04-26 13:49:14 +08:00)

ID3 algorithm working

parent cf165695c8
commit db3ac3c695

trees/entropy.go (new file, 101 lines)
@@ -0,0 +1,101 @@
+package trees
+
+import (
+	base "github.com/sjwhitworth/golearn/base"
+	"math"
+)
+
+//
+// Information gain rule generator
+//
+
+type InformationGainRuleGenerator struct {
+}
+
+// GenerateSplitAttribute returns the non-class Attribute which maximises the
+// information gain.
+//
+// IMPORTANT: passing a base.Instances with no Attributes other than the class
+// variable will panic()
+func (r *InformationGainRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
+	allAttributes := make([]int, 0)
+	for i := 0; i < f.Cols; i++ {
+		if i != f.ClassIndex {
+			allAttributes = append(allAttributes, i)
+		}
+	}
+	return r.GetSplitAttributeFromSelection(allAttributes, f)
+}
+
+// GetSplitAttributeFromSelection returns the non-class Attribute which
+// maximises the information gain amongst consideredAttributes.
+//
+// IMPORTANT: passing a zero-length consideredAttributes parameter will panic()
+func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(consideredAttributes []int, f *base.Instances) base.Attribute {
+
+	// Compute the information gain at this node for each candidate
+	// attribute, and pick the one which maximises it
+	maxGain := math.Inf(-1)
+	selectedAttribute := -1
+
+	// Compute the base entropy
+	classDist := f.GetClassDistribution()
+	baseEntropy := getBaseEntropy(classDist)
+
+	// Compute the information gain for each attribute
+	for _, s := range consideredAttributes {
+		proposedClassDist := f.GetClassDistributionAfterSplit(f.GetAttr(s))
+		localEntropy := getSplitEntropy(proposedClassDist)
+		informationGain := baseEntropy - localEntropy
+		if informationGain > maxGain {
+			maxGain = informationGain
+			selectedAttribute = s
+		}
+	}
+
+	// Return the Attribute which maximises the information gain
+	return f.GetAttr(selectedAttribute)
+}
+
+//
+// Entropy functions
+//
+
+// getSplitEntropy determines the entropy of the target
+// class distribution after splitting on a base.Attribute
+func getSplitEntropy(s map[string]map[string]int) float64 {
+	ret := 0.0
+	count := 0
+	for a := range s {
+		for c := range s[a] {
+			count += s[a][c]
+		}
+	}
+	for a := range s {
+		total := 0.0
+		for c := range s[a] {
+			total += float64(s[a][c])
+		}
+		for c := range s[a] {
+			ret -= float64(s[a][c]) / float64(count) * math.Log(float64(s[a][c])/float64(count)) / math.Log(2)
+		}
+		ret += total / float64(count) * math.Log(total/float64(count)) / math.Log(2)
+	}
+	return ret
+}
+
+// getBaseEntropy determines the entropy of the target
+// class distribution before splitting on a base.Attribute
+func getBaseEntropy(s map[string]int) float64 {
+	ret := 0.0
+	count := 0
+	for k := range s {
+		count += s[k]
+	}
+	for k := range s {
+		ret -= float64(s[k]) / float64(count) * math.Log(float64(s[k])/float64(count)) / math.Log(2)
+	}
+	return ret
+}
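
Taken together, these two helpers implement the ID3 split criterion: getBaseEntropy computes H(C) = -sum_c p(c) log2 p(c), getSplitEntropy computes the conditional entropy H(C|A) (the weighted average of the per-branch class entropies), and the information gain is their difference. Below is a minimal standalone sketch of the same arithmetic, with class counts hand-tallied from the tennis.csv file added later in this commit; it mirrors the unexported helpers rather than calling them, and uses math.Log2 in place of math.Log(x)/math.Log(2):

package main

import (
	"fmt"
	"math"
)

// entropy mirrors getBaseEntropy: -sum over classes of p*log2(p).
func entropy(counts map[string]int) float64 {
	total := 0.0
	for _, n := range counts {
		total += float64(n)
	}
	ret := 0.0
	for _, n := range counts {
		p := float64(n) / total
		ret -= p * math.Log2(p)
	}
	return ret
}

// splitEntropy mirrors getSplitEntropy: the weighted average of the
// per-branch class entropies after splitting on an attribute.
func splitEntropy(branches map[string]map[string]int) float64 {
	total := 0.0
	for _, counts := range branches {
		for _, n := range counts {
			total += float64(n)
		}
	}
	ret := 0.0
	for _, counts := range branches {
		branchTotal := 0.0
		for _, n := range counts {
			branchTotal += float64(n)
		}
		ret += branchTotal / total * entropy(counts)
	}
	return ret
}

func main() {
	// "play" column of tennis.csv: 9 yes, 5 no
	before := entropy(map[string]int{"yes": 9, "no": 5})

	// Class distribution after splitting on "outlook"
	after := splitEntropy(map[string]map[string]int{
		"sunny":    {"yes": 2, "no": 3},
		"overcast": {"yes": 4},
		"rainy":    {"yes": 3, "no": 2},
	})

	fmt.Printf("H(play) = %.3f\n", before)        // ≈ 0.940
	fmt.Printf("H(play|outlook) = %.3f\n", after) // ≈ 0.694
	fmt.Printf("gain = %.3f\n", before-after)     // ≈ 0.247
}

Outlook's gain (≈ 0.247 bits) beats humidity (≈ 0.152), windy (≈ 0.048) and temp (≈ 0.029) on this dataset, which is why the TestID3Inference test added at the bottom of this commit expects outlook at the root.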

(next file: name not preserved in this view)
@@ -1,9 +1,9 @@
 package trees
 
 import (
+	"bytes"
 	"fmt"
 	base "github.com/sjwhitworth/golearn/base"
-	"strings"
 )
 
 // NodeType determines whether a DecisionTreeNode is a leaf or not
@@ -32,9 +32,9 @@ type DecisionTreeNode struct {
 	ClassAttr *base.Attribute
 }
 
-// InferDecisionTree builds a decision tree using a RuleGenerator
-// from a set of Instances
-func InferDecisionTree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
+// InferID3Tree builds a decision tree using a RuleGenerator
+// from a set of Instances (implements the ID3 algorithm)
+func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
 	// Count the number of classes at this node
 	classes := from.CountClassValues()
 	// If there's only one class, return a DecisionTreeLeaf with
@@ -96,25 +96,39 @@ func InferDecisionTree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
 	// Create new children from these attributes
 	for k := range splitInstances {
 		newInstances := splitInstances[k]
-		ret.Children[k] = InferDecisionTree(newInstances, with)
+		ret.Children[k] = InferID3Tree(newInstances, with)
 	}
 	ret.SplitAttr = splitOnAttribute
 	return ret
 }
 
+func (d *DecisionTreeNode) getNestedString(level int) string {
+	buf := bytes.NewBuffer(nil)
+	tmp := bytes.NewBuffer(nil)
+	for i := 0; i < level; i++ {
+		tmp.WriteString("\t")
+	}
+	buf.WriteString(tmp.String())
+	if d.Children == nil {
+		buf.WriteString(fmt.Sprintf("Leaf(%s)", d.Class))
+	} else {
+		buf.WriteString(fmt.Sprintf("Rule(%s)", d.SplitAttr.GetName()))
+		for k := range d.Children {
+			buf.WriteString("\n")
+			buf.WriteString(tmp.String())
+			buf.WriteString("\t")
+			buf.WriteString(k)
+			buf.WriteString("\n")
+			buf.WriteString(d.Children[k].getNestedString(level + 1))
+		}
+	}
+	return buf.String()
+}
+
 // String returns a human-readable representation of a given node
 // and its children
 func (d *DecisionTreeNode) String() string {
-	children := make([]string, 0)
-	if d.Children != nil {
-		for k := range d.Children {
-			childStr := fmt.Sprintf("Rule(%s -> %s)", k, d.Children[k])
-			children = append(children, childStr)
-		}
-		return fmt.Sprintf("(%s(%s))", d.SplitAttr, strings.Join(children, "\n\t"))
-	}
-	return fmt.Sprintf("Leaf(%s (%s))", d.Class, d.ClassDist)
+	return d.getNestedString(0)
 }
 
 func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances {
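
With the rename in place, end-to-end usage looks roughly like the following. This is a sketch: the exported names match what this commit's tests call, but the import path for the trees package is an assumption, since the diff does not show it.

package main

import (
	"fmt"

	base "github.com/sjwhitworth/golearn/base"
	"github.com/sjwhitworth/golearn/trees" // assumed import path
)

func main() {
	// Load the PlayTennis dataset added in this commit.
	inst, err := base.ParseCSVToInstances("trees/tennis.csv", true)
	if err != nil {
		panic(err)
	}

	// Infer an ID3 tree with the information-gain rule generator and
	// print it via the getNestedString-backed String() method above.
	rule := new(trees.InformationGainRuleGenerator)
	root := trees.InferID3Tree(inst, rule)
	fmt.Println(root)
}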
(next file: name not preserved in this view)
@@ -3,47 +3,18 @@ package trees
 import (
 	"fmt"
 	base "github.com/sjwhitworth/golearn/base"
-	"math"
 	"math/rand"
 )
 
 type RandomTreeRuleGenerator struct {
 	Attributes int
+	internalRule InformationGainRuleGenerator
 }
-
-func getSplitEntropy(s map[string]map[string]int) float64 {
-	ret := 0.0
-	count := 0
-	for a := range s {
-		total := 0.0
-		for c := range s[a] {
-			ret -= float64(s[a][c]) * math.Log(float64(s[a][c])) / math.Log(2)
-			total += float64(s[a][c])
-			count += s[a][c]
-		}
-		ret += total * math.Log(total) / math.Log(2)
-	}
-	return ret / float64(count)
-}
-
-func getBaseEntropy(s map[string]int) float64 {
-	ret := 0.0
-	count := 0
-	for k := range s {
-		count += s[k]
-	}
-	for k := range s {
-		ret -= float64(s[k]) / float64(count) * math.Log(float64(s[k])/float64(count)) / math.Log(2)
-	}
-	return ret
-}
 
 // So WEKA returns a couple of possible attributes and evaluates
 // the split criteria on each
 func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
-
-	fmt.Println("GenerateSplitAttribute", r.Attributes)
-
 	// First step is to generate the random attributes that we'll consider
 	maximumAttribute := f.GetAttributeCount()
 	consideredAttributes := make([]int, r.Attributes)
@@ -59,28 +30,7 @@ func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
 		}
 	}
 
-	// Next step is to compute the information gain at this node
-	// for each randomly chosen attribute, and pick the one
-	// which maximises it
-	maxGain := math.Inf(-1)
-	selectedAttribute := -1
-
-	// Compute the base entropy
-	classDist := f.GetClassDistribution()
-	baseEntropy := getBaseEntropy(classDist)
-
-	for _, s := range consideredAttributes {
-		proposedClassDist := f.GetClassDistributionAfterSplit(f.GetAttr(s))
-		localEntropy := getSplitEntropy(proposedClassDist)
-		informationGain := baseEntropy - localEntropy
-		if informationGain > maxGain {
-			maxGain = localEntropy
-			selectedAttribute = s
-			fmt.Printf("Gain: %.4f, selectedAttribute: %s\n", informationGain, f.GetAttr(selectedAttribute))
-		}
-	}
-
-	return f.GetAttr(selectedAttribute)
+	return r.internalRule.GetSplitAttributeFromSelection(consideredAttributes, f)
 }
 
 type RandomTree struct {
@@ -95,13 +45,14 @@ func NewRandomTree(attrs int) *RandomTree {
 		nil,
 		RandomTreeRuleGenerator{
 			attrs,
+			InformationGainRuleGenerator{},
 		},
 	}
 }
 
 // Fit builds a RandomTree suitable for prediction
 func (rt *RandomTree) Fit(from *base.Instances) {
-	rt.Root = InferDecisionTree(from, &rt.Rule)
+	rt.Root = InferID3Tree(from, &rt.Rule)
 }
 
 // Predict returns a set of Instances containing predictions
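
The RandomTree wrapper now reuses the same machinery: it samples r.Attributes candidate attributes at each split and delegates the gain comparison to the embedded InformationGainRuleGenerator. A usage sketch under the same import-path assumption as above:

package main

import (
	"fmt"

	base "github.com/sjwhitworth/golearn/base"
	"github.com/sjwhitworth/golearn/trees" // assumed import path
)

func main() {
	inst, err := base.ParseCSVToInstances("trees/tennis.csv", true)
	if err != nil {
		panic(err)
	}

	// Consider 2 randomly chosen attributes at each split; the embedded
	// information-gain rule picks whichever of them maximises the gain.
	rt := trees.NewRandomTree(2)
	rt.Fit(inst)
	fmt.Println(rt.Root)
}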
trees/tennis.csv (new file, 15 lines)
@@ -0,0 +1,15 @@
+outlook,temp,humidity,windy,play
+sunny,hot,high,false,no
+sunny,hot,high,true,no
+overcast,hot,high,false,yes
+rainy,mild,high,false,yes
+rainy,cool,normal,false,yes
+rainy,cool,normal,true,no
+overcast,cool,normal,true,yes
+sunny,mild,high,false,no
+sunny,cool,normal,false,yes
+rainy,mild,normal,false,yes
+sunny,mild,normal,true,yes
+overcast,mild,high,true,yes
+overcast,hot,normal,false,yes
+rainy,mild,high,true,no
(test file: name not preserved in this view)
@@ -22,7 +22,7 @@ func TestRandomTree(testEnv *testing.T) {
 	fmt.Println(inst)
 	r := new(RandomTreeRuleGenerator)
 	r.Attributes = 2
-	root := InferDecisionTree(inst, r)
+	root := InferID3Tree(inst, r)
 	fmt.Println(root)
 }
 
@@ -40,7 +40,7 @@ func TestRandomTreeClassification(testEnv *testing.T) {
 	fmt.Println(inst)
 	r := new(RandomTreeRuleGenerator)
 	r.Attributes = 2
-	root := InferDecisionTree(insts[0], r)
+	root := InferID3Tree(insts[0], r)
 	fmt.Println(root)
 	predictions := root.Predict(insts[1])
 	fmt.Println(predictions)
@@ -91,3 +91,56 @@ func TestInformationGain(testEnv *testing.T) {
 		testEnv.Error(entropy)
 	}
 }
+
+func TestID3Inference(testEnv *testing.T) {
+
+	// Import the "PlayTennis" dataset
+	inst, err := base.ParseCSVToInstances("./tennis.csv", true)
+	if err != nil {
+		panic(err)
+	}
+
+	// Build the decision tree
+	rule := new(InformationGainRuleGenerator)
+	root := InferID3Tree(inst, rule)
+
+	// Verify the tree
+	// First attribute should be "outlook"
+	if root.SplitAttr.GetName() != "outlook" {
+		testEnv.Error(root)
+	}
+	sunnyChild := root.Children["sunny"]
+	overcastChild := root.Children["overcast"]
+	rainyChild := root.Children["rainy"]
+	if sunnyChild.SplitAttr.GetName() != "humidity" {
+		testEnv.Error(sunnyChild)
+	}
+	if rainyChild.SplitAttr.GetName() != "windy" {
+		testEnv.Error(rainyChild)
+	}
+	if overcastChild.SplitAttr != nil {
+		testEnv.Error(overcastChild)
+	}
+
+	sunnyLeafHigh := sunnyChild.Children["high"]
+	sunnyLeafNormal := sunnyChild.Children["normal"]
+	if sunnyLeafHigh.Class != "no" {
+		testEnv.Error(sunnyLeafHigh)
+	}
+	if sunnyLeafNormal.Class != "yes" {
+		testEnv.Error(sunnyLeafNormal)
+	}
+
+	windyLeafFalse := rainyChild.Children["false"]
+	windyLeafTrue := rainyChild.Children["true"]
+	if windyLeafFalse.Class != "yes" {
+		testEnv.Error(windyLeafFalse)
+	}
+	if windyLeafTrue.Class != "no" {
+		testEnv.Error(windyLeafTrue)
+	}
+
+	if overcastChild.Class != "yes" {
+		testEnv.Error(overcastChild)
+	}
+}
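
For reference, the tree this test pins down, rendered roughly as the new getNestedString would print it (a sketch only: Go map iteration makes the sibling ordering nondeterministic):

Rule(outlook)
	sunny
	Rule(humidity)
		high
		Leaf(no)
		normal
		Leaf(yes)
	overcast
	Leaf(yes)
	rainy
	Rule(windy)
		false
		Leaf(yes)
		true
		Leaf(no)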