package trees

import (
	"github.com/sjwhitworth/golearn/base"
	"math"
)

//
// Information gain rule generator
//
type InformationGainRuleGenerator struct {
}

// GenerateSplitAttribute returns the non-class Attribute which maximises the
// information gain.
//
// IMPORTANT: passing a base.Instances with no Attributes other than the class
// variable will panic()
func (r *InformationGainRuleGenerator) GenerateSplitAttribute(f base.FixedDataGrid) base.Attribute {
	attrs := f.AllAttributes()
	classAttrs := f.AllClassAttributes()
	candidates := base.AttributeDifferenceReferences(attrs, classAttrs)
	return r.GetSplitAttributeFromSelection(candidates, f)
}
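
// Illustrative usage (a sketch, not part of the original file): given a
// base.FixedDataGrid of training instances (the variable `instances` below
// is hypothetical), a caller might select the best split like this:
//
//	r := new(InformationGainRuleGenerator)
//	split := r.GenerateSplitAttribute(instances)
//	fmt.Println("splitting on", split.GetName()) // requires importing "fmt"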

// GetSplitAttributeFromSelection returns the Attribute amongst
// consideredAttributes which maximises the information gain.
//
// IMPORTANT: passing a zero-length consideredAttributes parameter will panic()
func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(consideredAttributes []base.Attribute, f base.FixedDataGrid) base.Attribute {
	var selectedAttribute base.Attribute

	// Parameter check
	if len(consideredAttributes) == 0 {
		panic("More Attributes should be considered")
	}

	// Next, compute the information gain at this node for each
	// candidate attribute, and pick the one which maximises it
	maxGain := math.Inf(-1)

	// Compute the base entropy
	classDist := base.GetClassDistribution(f)
	baseEntropy := getBaseEntropy(classDist)

	// Compute the information gain for each attribute
	for _, s := range consideredAttributes {
		proposedClassDist := base.GetClassDistributionAfterSplit(f, s)
		localEntropy := getSplitEntropy(proposedClassDist)
		informationGain := baseEntropy - localEntropy
		if informationGain > maxGain {
			maxGain = informationGain
			selectedAttribute = s
		}
	}

	// Return the attribute which maximises the information gain
	return selectedAttribute
}
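
// For intuition (an explanatory note, not part of the original file):
// the loop above implements
//
//	gain(A) = H(C) - H(C|A)
//
// where H(C) is the entropy of the class distribution before the split
// (baseEntropy) and H(C|A) is the weighted entropy remaining after
// splitting on A (localEntropy), so the selected attribute is the one
// whose split leaves the class labels most predictable.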

//
// Entropy functions
//

// getSplitEntropy determines the entropy of the target
// class distribution after splitting on a base.Attribute
func getSplitEntropy(s map[string]map[string]int) float64 {
	ret := 0.0
	count := 0
	// Count the total number of instances across all branches
	for a := range s {
		for c := range s[a] {
			count += s[a][c]
		}
	}
	// Conditional entropy H(C|A) = H(A,C) - H(A): subtract the joint
	// entropy terms, then add back the entropy of the split attribute
	for a := range s {
		total := 0.0
		for c := range s[a] {
			total += float64(s[a][c])
		}
		for c := range s[a] {
			ret -= float64(s[a][c]) / float64(count) * math.Log(float64(s[a][c])/float64(count)) / math.Log(2)
		}
		ret += total / float64(count) * math.Log(total/float64(count)) / math.Log(2)
	}
	return ret
}
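
// Worked example (illustrative, not part of the original file): splitting the
// classic "play tennis" data on the windy attribute,
//
//	getSplitEntropy(map[string]map[string]int{
//		"false": {"yes": 6, "no": 2},
//		"true":  {"yes": 3, "no": 3},
//	})
//
// returns approximately 0.892 bits: 8/14 of the instances fall in the "false"
// branch (entropy ≈ 0.811) and 6/14 in the "true" branch (entropy 1.0).
// Note that a zero count for a class within a branch would make the joint
// term evaluate 0 * log(0) = NaN here, so the maps are assumed to contain
// only observed counts.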

// getBaseEntropy determines the entropy of the target
// class distribution before splitting on a base.Attribute
func getBaseEntropy(s map[string]int) float64 {
	ret := 0.0
	count := 0
	// Count the total number of instances
	for k := range s {
		count += s[k]
	}
	// Shannon entropy in bits: -sum over classes of p * log2(p)
	for k := range s {
		ret -= float64(s[k]) / float64(count) * math.Log(float64(s[k])/float64(count)) / math.Log(2)
	}
	return ret
}
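
// Worked example (illustrative, not part of the original file): for a class
// distribution of 9 positive and 5 negative instances,
//
//	getBaseEntropy(map[string]int{"yes": 9, "no": 5})
//
// returns approximately 0.940 bits, since
// -(9/14)*log2(9/14) - (5/14)*log2(5/14) ≈ 0.940.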