1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00

trees: speed-up training

Avoid quadratic loop in getNumericAttributeEntropy.
We don't need to recalculate the whole distribution for each split,
just move the changed values. Also use an array of slices instead of
a map of maps keyed by strings, to avoid map overhead.

For our case I see time reductions from 100+ hours to 50 minutes.
I've added benchmark with synthetic data (iris.csv repeated 100 times)
and it also shows a nice improvement:

name               old time/op  new time/op  delta
RandomForestFit-8    117s ± 4%      0s ± 1%  -99.61%  (p=0.001 n=5+10)

0 is a rounding quirk of benchstat, it should be closer to 0.5s:

name               time/op
RandomForestFit-8  460ms ± 1%
This commit is contained in:
Ilya Tocar 2018-05-01 16:25:32 -05:00
parent 623af61265
commit 676f69a426
3 changed files with 15081 additions and 22 deletions

15000
trees/benchdata.csv Normal file

File diff suppressed because it is too large Load Diff

View File

@ -83,7 +83,7 @@ func (r *InformationGainRuleGenerator) GetSplitRuleFromSelection(consideredAttri
// numericSplitRef pairs an attribute value with the integer-encoded
// class of its row, so candidate split points can be sorted cheaply.
type numericSplitRef struct {
	val   float64
	class int // index assigned via class2Int; avoids string-keyed map overhead
}

// splitVec makes []numericSplitRef sortable (sort.Interface methods
// are defined elsewhere in this file; presumably ordering by val —
// TODO confirm against the Less implementation).
type splitVec []numericSplitRef
@ -103,40 +103,52 @@ func getNumericAttributeEntropy(f base.FixedDataGrid, attr *base.FloatAttribute)
// Build sortable vector
_, rows := f.Size()
refs := make([]numericSplitRef, rows)
numClasses := 0
class2Int := make(map[string]int)
f.MapOverRows([]base.AttributeSpec{attrSpec}, func(val [][]byte, row int) (bool, error) {
cls := base.GetClass(f, row)
i, ok := class2Int[cls]
if !ok {
i = numClasses
class2Int[cls] = i
numClasses++
}
v := base.UnpackBytesToFloat(val[0])
refs[row] = numericSplitRef{v, cls}
refs[row] = numericSplitRef{v, i}
return true, nil
})
// Sort
sort.Sort(splitVec(refs))
generateCandidateSplitDistribution := func(val float64) map[string]map[string]int {
presplit := make(map[string]int)
postplit := make(map[string]int)
for _, i := range refs {
if i.val < val {
presplit[i.class]++
} else {
postplit[i.class]++
}
}
ret := make(map[string]map[string]int)
ret["0"] = presplit
ret["1"] = postplit
return ret
}
minSplitEntropy := math.Inf(1)
minSplitVal := math.Inf(1)
prevVal := math.NaN()
prevInd := 0
splitDist := [2][]int{make([]int, numClasses), make([]int, numClasses)}
// Before first split all refs are not smaller than val
for _, x := range refs {
splitDist[1][x.class]++
}
// Consider each possible function
for i := 0; i < len(refs)-1; i++ {
for i := 0; i < len(refs)-1; {
val := refs[i].val + refs[i+1].val
val /= 2
splitDist := generateCandidateSplitDistribution(val)
splitEntropy := getSplitEntropy(splitDist)
if val == prevVal {
i++
continue
}
// refs is sorted, so we only need to update values that are
// bigger than prevVal, but are lower than val
for j := prevInd; j < len(refs) && refs[j].val < val; j++ {
splitDist[0][refs[j].class]++
splitDist[1][refs[j].class]--
i++
prevInd++
}
prevVal = val
splitEntropy := getSplitEntropyFast(splitDist)
if splitEntropy < minSplitEntropy {
minSplitEntropy = splitEntropy
minSplitVal = val
@ -146,6 +158,33 @@ func getNumericAttributeEntropy(f base.FixedDataGrid, attr *base.FloatAttribute)
return minSplitEntropy, minSplitVal
}
// getSplitEntropyFast determines the entropy of the target
// class distribution after splitting on an base.Attribute.
// It is similar to getSplitEntropy, but accepts array of slices,
// to avoid map access overhead.
// getSplitEntropyFast determines the entropy of the target
// class distribution after splitting on an base.Attribute.
// It is similar to getSplitEntropy, but accepts an array of slices,
// to avoid map access overhead.
//
// s[0] holds the per-class counts on one side of the split and s[1]
// the other; the result is the count-weighted entropy (in bits) of
// the two sides.
func getSplitEntropyFast(s [2][]int) float64 {
	// Total number of instances across both sides of the split.
	count := 0
	for a := range s {
		for c := range s[a] {
			count += s[a][c]
		}
	}
	// An empty distribution has zero entropy; without this guard the
	// divisions below produce 0/0 -> NaN.
	if count == 0 {
		return 0
	}
	ret := 0.0
	for a := range s {
		total := 0.0
		for c := range s[a] {
			total += float64(s[a][c])
		}
		for c := range s[a] {
			if s[a][c] != 0 {
				p := float64(s[a][c]) / float64(count)
				ret -= p * math.Log2(p)
			}
		}
		// Skip empty sides: lim x->0 of x*log(x) is 0, but computing
		// it directly gives 0*(-Inf) = NaN, which would poison the
		// entropy comparison in the caller.
		if total != 0 {
			frac := total / float64(count)
			ret += frac * math.Log2(frac)
		}
	}
	return ret
}
// getSplitEntropy determines the entropy of the target
// class distribution after splitting on an base.Attribute
func getSplitEntropy(s map[string]map[string]int) float64 {

20
trees/tree_bench_test.go Normal file
View File

@ -0,0 +1,20 @@
package trees_test
import (
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/ensemble"
"testing"
)
// BenchmarkRandomForestFit measures the cost of fitting a 20-tree
// random forest (4 candidate features per split) to the synthetic
// benchmark data set.
func BenchmarkRandomForestFit(b *testing.B) {
	// benchdata.csv contains ../examples/datasets/iris.csv repeated 100 times.
	data, err := base.ParseCSVToInstances("benchdata.csv", true)
	if err != nil {
		b.Fatalf("Cannot load benchdata.csv err:\n%v", err)
	}
	// Construct the forest before resetting the timer so that only
	// Fit is measured, not CSV parsing or forest construction.
	tree := ensemble.NewRandomForest(20, 4)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		tree.Fit(data)
	}
}