1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-28 13:48:56 +08:00
golearn/meta/bagging.go

104 lines
2.6 KiB
Go
Raw Normal View History

package meta
import (
"fmt"
base "github.com/sjwhitworth/golearn/base"
"runtime"
"strings"
)
// BaggedModels train Classifiers on subsets of the original
// Instances and combine the results through voting
type BaggedModel struct {
base.BaseClassifier
2014-05-17 16:20:56 +01:00
Models []base.Classifier
}
2014-05-17 16:20:56 +01:00
func (b *BaggedModel) generateTrainingInstances(from *base.Instances) *base.Instances {
2014-05-17 17:35:10 +01:00
return from.SampleWithReplacement(from.Rows)
}
2014-05-17 17:35:10 +01:00
// AddModel adds a base.Classifier to the current model
func (b *BaggedModel) AddModel(m base.Classifier) {
b.Models = append(b.Models, m)
}
// Train generates and trains each model on a randomised subset of
// Instances.
func (b *BaggedModel) Fit(from *base.Instances) {
n := runtime.GOMAXPROCS(0)
block := make(chan bool, n)
2014-05-17 16:20:56 +01:00
for _, m := range b.Models {
go func(c base.Classifier, f *base.Instances) {
2014-05-17 16:20:56 +01:00
f = b.generateTrainingInstances(f)
c.Fit(f)
block <- true
}(m, from)
}
for i := 0; i < len(b.Models); i++ {
<-block
}
}
// Predict gathers predictions from all the classifiers
// and outputs the most common (majority) class
//
// IMPORTANT: in the event of a tie, the first class which
// achieved the tie value is output.
func (b *BaggedModel) Predict(from *base.Instances) *base.Instances {
n := runtime.GOMAXPROCS(0)
// Channel to receive the results as they come in
votes := make(chan *base.Instances, n)
// Dispatch prediction generation
2014-05-17 16:20:56 +01:00
for _, m := range b.Models {
go func(c base.Classifier, f *base.Instances) {
p := c.Predict(f)
votes <- p
}(m, from)
}
// Count the votes for each class
voting := make(map[int](map[string]int))
for _ = range b.Models { // Have to do this - deadlocks otherwise
incoming := <-votes
// Step through each prediction
for j := 0; j < incoming.Rows; j++ {
// Check if we've seen this class before...
if _, ok := voting[j]; !ok {
// If we haven't, create an entry
voting[j] = make(map[string]int)
// Continue on the current row
j--
continue
}
voting[j][incoming.GetClass(j)]++
}
}
// Generate the overall consensus
ret := from.GeneratePredictionVector()
for i := range voting {
maxClass := ""
maxCount := 0
// Find the most popular class
for c := range voting[i] {
votes := voting[i][c]
if votes > maxCount {
maxClass = c
maxCount = votes
}
}
ret.SetAttrStr(i, 0, maxClass)
}
return ret
}
// String returns a human-readable representation of the
// BaggedModel and everything it contains
func (b *BaggedModel) String() string {
children := make([]string, 0)
for i, m := range b.Models {
children = append(children, fmt.Sprintf("%d: %s", i, m))
}
return fmt.Sprintf("BaggedModel(\n%s)", strings.Join(children, "\n\t"))
}