Mirror of https://github.com/sjwhitworth/golearn.git, synced 2025-04-26 13:49:14 +08:00

Avoid quadratic loop in getNumericAttributeEntropy. We don't need to recalculate the whole class distribution for each candidate split; it is enough to move only the values that changed sides. Also use an array of slices instead of a map of maps keyed by strings, to avoid map overhead. For our use case I see time reductions from 100+ hours to 50 minutes.

I've added a benchmark with synthetic data (iris.csv repeated 100 times), and it also shows a nice improvement:

name               old time/op  new time/op  delta
RandomForestFit-8  117s ± 4%    0s ± 1%      -99.61%  (p=0.001 n=5+10)

The 0 is a rounding quirk of benchstat; it should be closer to 0.5s:

name               time/op
RandomForestFit-8  460ms ± 1%
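For reference, a minimal sketch of the kind of benchmark the commit describes, assuming a fixture built by repeating iris.csv 100 times; the fixture path ("iris_x100.csv") and the forest parameters are illustrative assumptions, not values taken from the commit:

package ensemble

import (
	"testing"

	"github.com/sjwhitworth/golearn/base"
)

// BenchmarkRandomForestFit fits a forest on synthetic data.
// "iris_x100.csv" is a hypothetical fixture (iris.csv repeated
// 100 times); 10 trees over 4 features is an illustrative
// configuration, chosen here only for the sketch.
func BenchmarkRandomForestFit(b *testing.B) {
	inst, err := base.ParseCSVToInstances("iris_x100.csv", true)
	if err != nil {
		b.Fatal(err)
	}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		forest := NewRandomForest(10, 4)
		if err := forest.Fit(inst); err != nil {
			b.Fatal(err)
		}
	}
}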
224 lines · 5.8 KiB · Go
package trees

import (
	"github.com/sjwhitworth/golearn/base"
	"math"
	"sort"
)

//
// Information gain rule generator
//

// InformationGainRuleGenerator generates DecisionTreeRules which
// maximize information gain at each node.
type InformationGainRuleGenerator struct {
}

// GenerateSplitRule returns a DecisionTreeNode based on a non-class Attribute
// which maximises the information gain.
//
// IMPORTANT: passing a base.Instances with no Attributes other than the class
// variable will panic()
func (r *InformationGainRuleGenerator) GenerateSplitRule(f base.FixedDataGrid) *DecisionTreeRule {

	attrs := f.AllAttributes()
	classAttrs := f.AllClassAttributes()
	candidates := base.AttributeDifferenceReferences(attrs, classAttrs)

	return r.GetSplitRuleFromSelection(candidates, f)
}

// GetSplitRuleFromSelection returns a DecisionTreeRule which maximises
// the information gain amongst the considered Attributes.
//
// IMPORTANT: passing a zero-length consideredAttributes parameter will panic()
func (r *InformationGainRuleGenerator) GetSplitRuleFromSelection(consideredAttributes []base.Attribute, f base.FixedDataGrid) *DecisionTreeRule {

	var selectedAttribute base.Attribute

	// Parameter check
	if len(consideredAttributes) == 0 {
		panic("More Attributes should be considered")
	}

	// Next step is to compute the information gain at this node
	// for each randomly chosen attribute, and pick the one
	// which maximises it
	maxGain := math.Inf(-1)
	selectedVal := math.Inf(1)

	// Compute the base entropy
	classDist := base.GetClassDistribution(f)
	baseEntropy := getBaseEntropy(classDist)

	// Compute the information gain for each attribute
	for _, s := range consideredAttributes {
		var informationGain float64
		var splitVal float64
		if fAttr, ok := s.(*base.FloatAttribute); ok {
			var attributeEntropy float64
			attributeEntropy, splitVal = getNumericAttributeEntropy(f, fAttr)
			informationGain = baseEntropy - attributeEntropy
		} else {
			proposedClassDist := base.GetClassDistributionAfterSplit(f, s)
			localEntropy := getSplitEntropy(proposedClassDist)
			informationGain = baseEntropy - localEntropy
		}

		if informationGain > maxGain {
			maxGain = informationGain
			selectedAttribute = s
			selectedVal = splitVal
		}
	}

	// Pick the one which maximises IG
	return &DecisionTreeRule{selectedAttribute, selectedVal}
}
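
// For intuition, a worked example of the selection above: with a
// 10-row dataset split 5/5 between two classes, the base entropy is
// -(0.5)*log2(0.5) - (0.5)*log2(0.5) = 1 bit. If splitting on some
// attribute yields one pure 5-row branch per class, the split entropy
// is 0 and the information gain is 1 - 0 = 1 bit, the maximum
// possible, so that attribute would be selected.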

//
// Entropy functions
//

type numericSplitRef struct {
	val   float64
	class int
}

type splitVec []numericSplitRef

func (a splitVec) Len() int           { return len(a) }
func (a splitVec) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
func (a splitVec) Less(i, j int) bool { return a[i].val < a[j].val }

func getNumericAttributeEntropy(f base.FixedDataGrid, attr *base.FloatAttribute) (float64, float64) {

	// Resolve Attribute
	attrSpec, err := f.GetAttribute(attr)
	if err != nil {
		panic(err)
	}

	// Build sortable vector
	_, rows := f.Size()
	refs := make([]numericSplitRef, rows)
	numClasses := 0
	class2Int := make(map[string]int)
	f.MapOverRows([]base.AttributeSpec{attrSpec}, func(val [][]byte, row int) (bool, error) {
		cls := base.GetClass(f, row)
		i, ok := class2Int[cls]
		if !ok {
			i = numClasses
			class2Int[cls] = i
			numClasses++
		}
		v := base.UnpackBytesToFloat(val[0])
		refs[row] = numericSplitRef{v, i}
		return true, nil
	})

	sort.Sort(splitVec(refs))

	minSplitEntropy := math.Inf(1)
	minSplitVal := math.Inf(1)
	prevVal := math.NaN()
	prevInd := 0

	splitDist := [2][]int{make([]int, numClasses), make([]int, numClasses)}
	// Before the first split, every row is on the not-smaller side
	for _, x := range refs {
		splitDist[1][x.class]++
	}

	// Consider each possible split point
	for i := 0; i < len(refs)-1; {
		val := refs[i].val + refs[i+1].val
		val /= 2
		if val == prevVal {
			i++
			continue
		}
		// refs is sorted, so we only need to move the rows whose
		// values are greater than prevVal but lower than val
		for j := prevInd; j < len(refs) && refs[j].val < val; j++ {
			splitDist[0][refs[j].class]++
			splitDist[1][refs[j].class]--
			i++
			prevInd++
		}
		prevVal = val
		splitEntropy := getSplitEntropyFast(splitDist)
		if splitEntropy < minSplitEntropy {
			minSplitEntropy = splitEntropy
			minSplitVal = val
		}
	}

	return minSplitEntropy, minSplitVal
}
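
// To see why the scan above is linear, consider sorted values
// [1.0, 2.0, 3.0] with classes [A, A, B]. Initially splitDist[1]
// holds {A: 2, B: 1} and splitDist[0] is empty. The first candidate
// split at 1.5 moves one A from side 1 to side 0; the next, at 2.5,
// moves the second A. Each row crosses the split point exactly once,
// so the whole pass is O(rows) instead of rebuilding the full
// distribution for every candidate split.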

// getSplitEntropyFast determines the entropy of the target
// class distribution after splitting on a base.Attribute.
// It is similar to getSplitEntropy, but accepts an array of slices
// to avoid map access overhead.
func getSplitEntropyFast(s [2][]int) float64 {
	ret := 0.0
	count := 0
	for a := range s {
		for c := range s[a] {
			count += s[a][c]
		}
	}
	for a := range s {
		total := 0.0
		for c := range s[a] {
			total += float64(s[a][c])
		}
		for c := range s[a] {
			if s[a][c] != 0 {
				ret -= float64(s[a][c]) / float64(count) * math.Log(float64(s[a][c])/float64(count)) / math.Log(2)
			}
		}
		ret += total / float64(count) * math.Log(total/float64(count)) / math.Log(2)
	}
	return ret
}
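
// Equivalently, getSplitEntropyFast computes the row-count-weighted
// average entropy of the two sides:
//
//	H_split = (n0/N)*H(side 0) + (n1/N)*H(side 1)
//
// where n0 and n1 are the row counts on each side and N = n0 + n1.
// The loop above expands this into per-cell -p*log2(p) terms plus a
// per-side correction term.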

// getSplitEntropy determines the entropy of the target
// class distribution after splitting on a base.Attribute
func getSplitEntropy(s map[string]map[string]int) float64 {
	ret := 0.0
	count := 0
	for a := range s {
		for c := range s[a] {
			count += s[a][c]
		}
	}
	for a := range s {
		total := 0.0
		for c := range s[a] {
			total += float64(s[a][c])
		}
		for c := range s[a] {
			ret -= float64(s[a][c]) / float64(count) * math.Log(float64(s[a][c])/float64(count)) / math.Log(2)
		}
		ret += total / float64(count) * math.Log(total/float64(count)) / math.Log(2)
	}
	return ret
}

// getBaseEntropy determines the entropy of the target
// class distribution before splitting on a base.Attribute
func getBaseEntropy(s map[string]int) float64 {
	ret := 0.0
	count := 0
	for k := range s {
		count += s[k]
	}
	for k := range s {
		ret -= float64(s[k]) / float64(count) * math.Log(float64(s[k])/float64(count)) / math.Log(2)
	}
	return ret
}
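
// Example: for a class distribution {"yes": 3, "no": 1}, count is 4
// and the entropy is -(3/4)*log2(3/4) - (1/4)*log2(1/4) ≈ 0.811 bits.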