Mirror of https://github.com/sjwhitworth/golearn.git
trees: speed-up training
Avoid the quadratic loop in getNumericAttributeEntropy: we don't need to recalculate the whole class distribution for each candidate split, only move the values that changed sides. Also use an array of slices instead of a map of maps keyed by strings, to avoid map overhead. For our use case this cuts training time from 100+ hours to about 50 minutes.

I've added a benchmark with synthetic data (iris.csv repeated 100 times), and it also shows a nice improvement:

name               old time/op  new time/op  delta
RandomForestFit-8  117s ± 4%    0s ± 1%      -99.61%  (p=0.001 n=5+10)

The 0 is a rounding quirk of benchstat; it should be closer to 0.5s:

name               time/op
RandomForestFit-8  460ms ± 1%
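For reference, a comparison like the one above can be produced with the standard Go benchmarking tools plus benchstat (golang.org/x/perf/cmd/benchstat); the exact flags below are illustrative, not necessarily the invocation used for this commit:

    go test -bench RandomForestFit -count 10 ./trees | tee new.txt
    benchstat old.txt new.txt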
parent 623af61265
commit 676f69a426
trees/benchdata.csv (new file, 15000 lines)
File diff suppressed because it is too large.
@@ -83,7 +83,7 @@ func (r *InformationGainRuleGenerator) GetSplitRuleFromSelection(consideredAttri
 
 type numericSplitRef struct {
     val   float64
-    class string
+    class int
 }
 
 type splitVec []numericSplitRef
@@ -103,40 +103,52 @@ func getNumericAttributeEntropy(f base.FixedDataGrid, attr *base.FloatAttribute)
 
     // Build sortable vector
     _, rows := f.Size()
     refs := make([]numericSplitRef, rows)
+    numClasses := 0
+    class2Int := make(map[string]int)
     f.MapOverRows([]base.AttributeSpec{attrSpec}, func(val [][]byte, row int) (bool, error) {
         cls := base.GetClass(f, row)
+        i, ok := class2Int[cls]
+        if !ok {
+            i = numClasses
+            class2Int[cls] = i
+            numClasses++
+        }
         v := base.UnpackBytesToFloat(val[0])
-        refs[row] = numericSplitRef{v, cls}
+        refs[row] = numericSplitRef{v, i}
         return true, nil
     })
 
     // Sort
     sort.Sort(splitVec(refs))
 
-    generateCandidateSplitDistribution := func(val float64) map[string]map[string]int {
-        presplit := make(map[string]int)
-        postplit := make(map[string]int)
-        for _, i := range refs {
-            if i.val < val {
-                presplit[i.class]++
-            } else {
-                postplit[i.class]++
-            }
-        }
-        ret := make(map[string]map[string]int)
-        ret["0"] = presplit
-        ret["1"] = postplit
-        return ret
-    }
-
     minSplitEntropy := math.Inf(1)
     minSplitVal := math.Inf(1)
+    prevVal := math.NaN()
+    prevInd := 0
+
+    splitDist := [2][]int{make([]int, numClasses), make([]int, numClasses)}
+    // Before first split all refs are not smaller than val
+    for _, x := range refs {
+        splitDist[1][x.class]++
+    }
 
     // Consider each possible function
-    for i := 0; i < len(refs)-1; i++ {
+    for i := 0; i < len(refs)-1; {
         val := refs[i].val + refs[i+1].val
         val /= 2
-        splitDist := generateCandidateSplitDistribution(val)
-        splitEntropy := getSplitEntropy(splitDist)
+        if val == prevVal {
+            i++
+            continue
+        }
+        // refs is sorted, so we only need to update values that are
+        // bigger than prevVal, but are lower than val
+        for j := prevInd; j < len(refs) && refs[j].val < val; j++ {
+            splitDist[0][refs[j].class]++
+            splitDist[1][refs[j].class]--
+            i++
+            prevInd++
+        }
+        prevVal = val
+        splitEntropy := getSplitEntropyFast(splitDist)
         if splitEntropy < minSplitEntropy {
             minSplitEntropy = splitEntropy
             minSplitVal = val
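The heart of the change is easier to see in isolation. Below is a minimal, self-contained sketch of the sweep in the same spirit as the hunk above (all names and data here are illustrative, not from the library): counts start with every row in the right bucket, and each candidate split only moves the rows that crossed it, so scanning all candidate splits is linear after the sort instead of quadratic.

package main

import (
    "fmt"
    "sort"
)

// ref pairs a numeric value with a dense integer class label,
// mirroring the numericSplitRef shape used in the diff above.
type ref struct {
    val   float64
    class int
}

func main() {
    refs := []ref{{2.0, 0}, {1.0, 1}, {3.0, 0}, {2.5, 1}}
    numClasses := 2
    sort.Slice(refs, func(i, j int) bool { return refs[i].val < refs[j].val })

    // Start with every row on the "right" side of the split.
    dist := [2][]int{make([]int, numClasses), make([]int, numClasses)}
    for _, r := range refs {
        dist[1][r.class]++
    }

    prevInd := 0
    for i := 0; i < len(refs)-1; i++ {
        split := (refs[i].val + refs[i+1].val) / 2
        // Only the rows that just crossed the split point move buckets,
        // so the whole sweep is O(n) rather than O(n^2).
        for ; prevInd < len(refs) && refs[prevInd].val < split; prevInd++ {
            dist[0][refs[prevInd].class]++
            dist[1][refs[prevInd].class]--
        }
        fmt.Printf("split %.2f: left=%v right=%v\n", split, dist[0], dist[1])
    }
}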
@@ -146,6 +158,33 @@ func getNumericAttributeEntropy(f base.FixedDataGrid, attr *base.FloatAttribute)
     return minSplitEntropy, minSplitVal
 }
 
+// getSplitEntropyFast determines the entropy of the target
+// class distribution after splitting on an base.Attribute.
+// It is similar to getSplitEntropy, but accepts array of slices,
+// to avoid map access overhead.
+func getSplitEntropyFast(s [2][]int) float64 {
+    ret := 0.0
+    count := 0
+    for a := range s {
+        for c := range s[a] {
+            count += s[a][c]
+        }
+    }
+    for a := range s {
+        total := 0.0
+        for c := range s[a] {
+            total += float64(s[a][c])
+        }
+        for c := range s[a] {
+            if s[a][c] != 0 {
+                ret -= float64(s[a][c]) / float64(count) * math.Log(float64(s[a][c])/float64(count)) / math.Log(2)
+            }
+        }
+        ret += total / float64(count) * math.Log(total/float64(count)) / math.Log(2)
+    }
+    return ret
+}
+
 // getSplitEntropy determines the entropy of the target
 // class distribution after splitting on an base.Attribute
 func getSplitEntropy(s map[string]map[string]int) float64 {
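For intuition, a small worked example (mine, not part of the commit): with splitDist = [[4, 0], [2, 2]], count = 8. The left branch is pure and contributes 0; the right branch is a 50/50 mix and contributes 1 bit. getSplitEntropyFast returns the row-weighted average, 4/8 * 0 + 4/8 * 1 = 0.5 bits, the same value the map-based getSplitEntropy would produce for the equivalent distribution.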
trees/tree_bench_test.go (new file, 20 lines)

@@ -0,0 +1,20 @@
+package trees_test
+
+import (
+    "github.com/sjwhitworth/golearn/base"
+    "github.com/sjwhitworth/golearn/ensemble"
+    "testing"
+)
+
+func BenchmarkRandomForestFit(b *testing.B) {
+    // benchdata.csv contains ../examples/datasets/iris.csv repeated 100 times.
+    data, err := base.ParseCSVToInstances("benchdata.csv", true)
+    if err != nil {
+        b.Fatalf("Cannot load benchdata.csv err:\n%v", err)
+    }
+    b.ResetTimer()
+    tree := ensemble.NewRandomForest(20, 4)
+    for i := 0; i < b.N; i++ {
+        tree.Fit(data)
+    }
+}
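As I read golearn's API, ensemble.NewRandomForest(forestSize, features) means NewRandomForest(20, 4) builds a forest of 20 trees with 4 candidate features per split. With benchdata.csv in place, the benchmark can be run on its own with something like:

    go test -bench RandomForestFit ./trees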