Mirror of https://github.com/sjwhitworth/golearn.git (synced 2025-04-26 13:49:14 +08:00)

This patch adds:

* Gini index and information gain ratio as DecisionTree split options;
* handling for numeric Attributes (split point chosen naïvely by minimising post-split entropy, i.e. maximising information gain);
* a couple of additional utility functions in base/;
* a new dataset (see sources.txt) for testing.

Performance on Iris improves markedly without discretisation.
185 lines · 4.9 KiB · Go
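
For orientation, a minimal usage sketch. It assumes golearn's trees.InferID3Tree helper and DecisionTreeNode.Predict with their current signatures, plus trainData/testData grids loaded elsewhere; everything outside this file is an assumption, not part of the patch:

	// Grow an ID3-style tree whose splits maximise information gain.
	rules := new(trees.InformationGainRuleGenerator)
	root := trees.InferID3Tree(trainData, rules) // trainData: a base.FixedDataGrid
	predictions, err := root.Predict(testData)   // assumed (grid, error) return
	if err != nil {
		panic(err)
	}
	_ = predictions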
package trees

import (
	"math"
	"sort"

	"github.com/sjwhitworth/golearn/base"
)

//
// Information gain rule generator
//

// InformationGainRuleGenerator generates DecisionTreeRules which
// maximise information gain at each node.
type InformationGainRuleGenerator struct{}

// GenerateSplitRule returns a DecisionTreeRule based on the non-class
// Attribute which maximises the information gain.
//
// IMPORTANT: passing a base.FixedDataGrid with no Attributes other than
// the class variable will panic()
func (r *InformationGainRuleGenerator) GenerateSplitRule(f base.FixedDataGrid) *DecisionTreeRule {

	// Consider every Attribute which is not a class Attribute
	attrs := f.AllAttributes()
	classAttrs := f.AllClassAttributes()
	candidates := base.AttributeDifferenceReferences(attrs, classAttrs)

	return r.GetSplitRuleFromSelection(candidates, f)
}

// GetSplitRuleFromSelection returns a DecisionTreeRule which maximises
// the information gain amongst the considered Attributes.
//
// IMPORTANT: passing a zero-length consideredAttributes parameter will panic()
func (r *InformationGainRuleGenerator) GetSplitRuleFromSelection(consideredAttributes []base.Attribute, f base.FixedDataGrid) *DecisionTreeRule {

	var selectedAttribute base.Attribute

	// Parameter check
	if len(consideredAttributes) == 0 {
		panic("More Attributes should be considered")
	}

	// Compute the information gain at this node for each considered
	// Attribute, and pick the one which maximises it
	maxGain := math.Inf(-1)
	selectedVal := math.Inf(1)

	// Compute the base entropy
	classDist := base.GetClassDistribution(f)
	baseEntropy := getBaseEntropy(classDist)

	// Compute the information gain for each Attribute
	for _, s := range consideredAttributes {
		var informationGain float64
		var splitVal float64
		if fAttr, ok := s.(*base.FloatAttribute); ok {
			// Numeric Attributes split on the threshold which
			// minimises the post-split entropy
			var attributeEntropy float64
			attributeEntropy, splitVal = getNumericAttributeEntropy(f, fAttr)
			informationGain = baseEntropy - attributeEntropy
		} else {
			// Categorical Attributes split on each distinct value
			proposedClassDist := base.GetClassDistributionAfterSplit(f, s)
			localEntropy := getSplitEntropy(proposedClassDist)
			informationGain = baseEntropy - localEntropy
		}

		// Retain the Attribute with the highest gain seen so far
		if informationGain > maxGain {
			maxGain = informationGain
			selectedAttribute = s
			selectedVal = splitVal
		}
	}

	// Return the Attribute (and, for numeric Attributes, the split
	// threshold) which maximises the information gain
	return &DecisionTreeRule{selectedAttribute, selectedVal}
}
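
// As a worked example of the gain computation above (hypothetical
// figures): with 8 instances split 4/4 between two classes, the base
// entropy is 1.0 bit. If a candidate split yields branches {3 a, 1 b}
// and {1 a, 3 b}, each branch has entropy ≈ 0.811 bits and weight 4/8,
// so the split entropy is ≈ 0.811 and the information gain ≈ 0.189 bits.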

//
// Entropy functions
//

// numericSplitRef pairs a numeric Attribute value with the class
// label of the row it came from
type numericSplitRef struct {
	val   float64
	class string
}

// splitVec implements sort.Interface, ordering numericSplitRefs
// by ascending value
type splitVec []numericSplitRef

func (a splitVec) Len() int           { return len(a) }
func (a splitVec) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
func (a splitVec) Less(i, j int) bool { return a[i].val < a[j].val }

// getNumericAttributeEntropy returns the minimum achievable split
// entropy for a numeric Attribute, along with the threshold which
// achieves it
func getNumericAttributeEntropy(f base.FixedDataGrid, attr *base.FloatAttribute) (float64, float64) {

	// Resolve Attribute
	attrSpec, err := f.GetAttribute(attr)
	if err != nil {
		panic(err)
	}

	// Build a sortable vector of (value, class) pairs
	_, rows := f.Size()
	refs := make([]numericSplitRef, rows)
	err = f.MapOverRows([]base.AttributeSpec{attrSpec}, func(val [][]byte, row int) (bool, error) {
		cls := base.GetClass(f, row)
		v := base.UnpackBytesToFloat(val[0])
		refs[row] = numericSplitRef{v, cls}
		return true, nil
	})
	if err != nil {
		panic(err)
	}
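
	// refs now holds one (value, class) pair per row; sorting by value
	// (below) lets every candidate threshold be taken as the midpoint
	// between two adjacent values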

	// Sort by Attribute value
	sort.Sort(splitVec(refs))

	// generateCandidateSplitDistribution builds the two-way class
	// distribution produced by splitting at a candidate threshold:
	// key "0" maps to the class counts for rows with val < threshold,
	// key "1" to the counts for the remainder
	generateCandidateSplitDistribution := func(val float64) map[string]map[string]int {
		presplit := make(map[string]int)
		postsplit := make(map[string]int)
		for _, i := range refs {
			if i.val < val {
				presplit[i.class]++
			} else {
				postsplit[i.class]++
			}
		}
		ret := make(map[string]map[string]int)
		ret["0"] = presplit
		ret["1"] = postsplit
		return ret
	}
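
	// For example, with two classes "a" and "b" and an imperfect
	// threshold, the returned distribution might look like:
	//   {"0": {"a": 3, "b": 1}, "1": {"a": 1, "b": 3}}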

	minSplitEntropy := math.Inf(1)
	minSplitVal := math.Inf(1)

	// Consider each candidate split point: the midpoint between every
	// pair of adjacent sorted values. Each candidate costs a full pass
	// over refs, so this scan is O(n²) in the number of rows.
	for i := 0; i < len(refs)-1; i++ {
		val := refs[i].val + refs[i+1].val
		val /= 2
		splitDist := generateCandidateSplitDistribution(val)
		splitEntropy := getSplitEntropy(splitDist)
		if splitEntropy < minSplitEntropy {
			minSplitEntropy = splitEntropy
			minSplitVal = val
		}
	}

	return minSplitEntropy, minSplitVal
}

// getSplitEntropy determines the entropy of the target class
// distribution after splitting on a base.Attribute: the entropy of
// each branch, weighted by the fraction of instances which reach it
func getSplitEntropy(s map[string]map[string]int) float64 {
	ret := 0.0
	count := 0
	for a := range s {
		for c := range s[a] {
			count += s[a][c]
		}
	}
	for a := range s {
		total := 0.0
		for c := range s[a] {
			total += float64(s[a][c])
		}
		for c := range s[a] {
			ret -= float64(s[a][c]) / float64(count) * math.Log(float64(s[a][c])/float64(count)) / math.Log(2)
		}
		ret += total / float64(count) * math.Log(total/float64(count)) / math.Log(2)
	}
	return ret
}

// getBaseEntropy determines the entropy of the target
// class distribution before splitting on a base.Attribute
func getBaseEntropy(s map[string]int) float64 {
	ret := 0.0
	count := 0
	for k := range s {
		count += s[k]
	}
	for k := range s {
		ret -= float64(s[k]) / float64(count) * math.Log(float64(s[k])/float64(count)) / math.Log(2)
	}
	return ret
}
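
// Sanity checks for getBaseEntropy (hypothetical figures, worked by hand):
//
//	getBaseEntropy(map[string]int{"a": 2, "b": 2}) // == 1.0: two equally likely classes
//	getBaseEntropy(map[string]int{"a": 4})         // == 0.0: a pure node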