mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
351 lines
9.8 KiB
Go
351 lines
9.8 KiB
Go
package meta
|
|
|
|
import (
|
|
"fmt"
|
|
"github.com/sjwhitworth/golearn/base"
|
|
"math/rand"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
)
|
|
|
|
// BaggedModel trains base.Classifiers on subsets of the original
|
|
// Instances and combine the results through voting
|
|
type BaggedModel struct {
|
|
base.BaseClassifier
|
|
Models []base.Classifier
|
|
RandomFeatures int
|
|
lock sync.Mutex
|
|
selectedAttributes map[int][]base.Attribute
|
|
fitOn base.FixedDataGrid
|
|
}
|
|
|
|
// generateTrainingAttrs selects RandomFeatures number of base.Attributes from
|
|
// the provided base.Instances.
|
|
func (b *BaggedModel) generateTrainingAttrs(model int, from base.FixedDataGrid) []base.Attribute {
|
|
ret := make([]base.Attribute, 0)
|
|
attrs := base.NonClassAttributes(from)
|
|
if b.RandomFeatures == 0 {
|
|
ret = attrs
|
|
} else {
|
|
for {
|
|
if len(ret) >= b.RandomFeatures {
|
|
break
|
|
}
|
|
attrIndex := rand.Intn(len(attrs))
|
|
attr := attrs[attrIndex]
|
|
matched := false
|
|
for _, a := range ret {
|
|
if a.Equals(attr) {
|
|
matched = true
|
|
break
|
|
}
|
|
}
|
|
if !matched {
|
|
ret = append(ret, attr)
|
|
}
|
|
}
|
|
}
|
|
for _, a := range from.AllClassAttributes() {
|
|
ret = append(ret, a)
|
|
}
|
|
b.lock.Lock()
|
|
b.selectedAttributes[model] = ret
|
|
b.lock.Unlock()
|
|
return ret
|
|
}
|
|
|
|
// generatePredictionInstances returns a modified version of the
|
|
// requested base.Instances with only the base.Attributes selected
|
|
// for training the model.
|
|
func (b *BaggedModel) generatePredictionInstances(model int, from base.FixedDataGrid) base.FixedDataGrid {
|
|
selected := b.selectedAttributes[model]
|
|
return base.NewInstancesViewFromAttrs(from, selected)
|
|
}
|
|
|
|
// generateTrainingInstances generates RandomFeatures number of
|
|
// attributes and returns a modified version of base.Instances
|
|
// for training the model
|
|
func (b *BaggedModel) generateTrainingInstances(model int, from base.FixedDataGrid) base.FixedDataGrid {
|
|
_, rows := from.Size()
|
|
insts := base.SampleWithReplacement(from, rows)
|
|
selected := b.generateTrainingAttrs(model, from)
|
|
return base.NewInstancesViewFromAttrs(insts, selected)
|
|
}
|
|
|
|
// AddModel adds a base.Classifier to the current model
|
|
func (b *BaggedModel) AddModel(m base.Classifier) {
|
|
b.Models = append(b.Models, m)
|
|
}
|
|
|
|
// Fit generates and trains each model on a randomised subset of
|
|
// Instances.
|
|
func (b *BaggedModel) Fit(from base.FixedDataGrid) {
|
|
var wait sync.WaitGroup
|
|
b.selectedAttributes = make(map[int][]base.Attribute)
|
|
for i, m := range b.Models {
|
|
wait.Add(1)
|
|
go func(c base.Classifier, f base.FixedDataGrid, model int) {
|
|
l := b.generateTrainingInstances(model, f)
|
|
c.Fit(l)
|
|
wait.Done()
|
|
}(m, from, i)
|
|
}
|
|
wait.Wait()
|
|
b.fitOn = base.NewStructuralCopy(from)
|
|
}
|
|
|
|
// Predict gathers predictions from all the classifiers
|
|
// and outputs the most common (majority) class
|
|
//
|
|
// IMPORTANT: in the event of a tie, the first class which
|
|
// achieved the tie value is output.
|
|
func (b *BaggedModel) Predict(from base.FixedDataGrid) (base.FixedDataGrid, error) {
|
|
n := runtime.NumCPU()
|
|
// Channel to receive the results as they come in
|
|
votes := make(chan base.DataGrid, n)
|
|
// Count the votes for each class
|
|
voting := make(map[int](map[string]int))
|
|
|
|
// Create a goroutine to collect the votes
|
|
var votingwait sync.WaitGroup
|
|
votingwait.Add(1)
|
|
go func() {
|
|
for { // Need to resolve the voting problem
|
|
incoming, ok := <-votes
|
|
if ok {
|
|
cSpecs := base.ResolveAttributes(incoming, incoming.AllClassAttributes())
|
|
incoming.MapOverRows(cSpecs, func(row [][]byte, predRow int) (bool, error) {
|
|
// Check if we've seen this class before...
|
|
if _, ok := voting[predRow]; !ok {
|
|
// If we haven't, create an entry
|
|
voting[predRow] = make(map[string]int)
|
|
// Continue on the current row
|
|
}
|
|
voting[predRow][base.GetClass(incoming, predRow)]++
|
|
return true, nil
|
|
})
|
|
} else {
|
|
votingwait.Done()
|
|
break
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Create workers to process the predictions
|
|
processpipe := make(chan int, n)
|
|
var processwait sync.WaitGroup
|
|
for i := 0; i < n; i++ {
|
|
processwait.Add(1)
|
|
go func() {
|
|
for {
|
|
if i, ok := <-processpipe; ok {
|
|
c := b.Models[i]
|
|
l := b.generatePredictionInstances(i, from)
|
|
v, _ := c.Predict(l)
|
|
votes <- v
|
|
} else {
|
|
processwait.Done()
|
|
break
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
// Send all the models to the workers for prediction
|
|
for i := range b.Models {
|
|
processpipe <- i
|
|
}
|
|
close(processpipe) // Finished sending models to be predicted
|
|
processwait.Wait() // Predictors all finished processing
|
|
close(votes) // Close the vote channel and allow it to drain
|
|
votingwait.Wait() // All the votes are in
|
|
|
|
// Generate the overall consensus
|
|
ret := base.GeneratePredictionVector(from)
|
|
for i := range voting {
|
|
maxClass := ""
|
|
maxCount := 0
|
|
// Find the most popular class
|
|
for c := range voting[i] {
|
|
votes := voting[i][c]
|
|
if votes > maxCount {
|
|
maxClass = c
|
|
maxCount = votes
|
|
}
|
|
}
|
|
base.SetClass(ret, i, maxClass)
|
|
}
|
|
return ret, nil
|
|
}
|
|
|
|
// String returns a human-readable representation of the
|
|
// BaggedModel and everything it contains
|
|
func (b *BaggedModel) String() string {
|
|
children := make([]string, 0)
|
|
for i, m := range b.Models {
|
|
children = append(children, fmt.Sprintf("%d: %s", i, m))
|
|
}
|
|
return fmt.Sprintf("BaggedModel(\n%s)", strings.Join(children, "\n\t"))
|
|
}
|
|
|
|
// GetMetadata returns required serialization information for this classifier
|
|
func (b *BaggedModel) GetMetadata() base.ClassifierMetadataV1 {
|
|
|
|
return base.ClassifierMetadataV1{
|
|
FormatVersion: 1,
|
|
ClassifierName: "BaggedModel",
|
|
ClassifierVersion: "1.0",
|
|
ClassifierMetadata: nil,
|
|
}
|
|
|
|
}
|
|
|
|
func (b *BaggedModel) Save(filePath string) error {
|
|
writer, err := base.CreateSerializedClassifierStub(filePath, b.GetMetadata())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = b.SaveWithPrefix(writer, "")
|
|
writer.Close()
|
|
return err
|
|
}
|
|
|
|
func (b *BaggedModel) Load(filePath string) error {
|
|
reader, err := base.ReadSerializedClassifierStub(filePath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = b.LoadWithPrefix(reader, "")
|
|
reader.Close()
|
|
return err
|
|
}
|
|
|
|
/* type BaggedModel struct {
|
|
base.BaseClassifier
|
|
Models []base.Classifier
|
|
RandomFeatures int
|
|
lock sync.Mutex
|
|
selectedAttributes map[int][]base.Attribute, always RandomFeatures in length
|
|
}*/
|
|
func (b *BaggedModel) SaveWithPrefix(writer *base.ClassifierSerializer, prefix string) error {
|
|
b.lock.Lock()
|
|
defer b.lock.Unlock()
|
|
|
|
pI := func(n string, i int) string {
|
|
return writer.Prefix(prefix, writer.Prefix(n, fmt.Sprintf("%d", i)))
|
|
}
|
|
|
|
// Export the number of random features
|
|
randomFeaturesKey := writer.Prefix(prefix, "NUM_RANDOM_FEATURES")
|
|
err := writer.WriteU64ForKey(randomFeaturesKey, uint64(b.RandomFeatures))
|
|
if err != nil {
|
|
return base.DescribeError("Can't write NUM_RANDOM_FEATURES", err)
|
|
}
|
|
|
|
// Write the number of classifiers
|
|
classifiersKey := writer.Prefix(prefix, "NUM_CLASSIFIERS")
|
|
err = writer.WriteU64ForKey(classifiersKey, uint64(len(b.Models)))
|
|
if err != nil {
|
|
return base.DescribeError("Can't write NUM_CLASSIFIERS", err)
|
|
}
|
|
|
|
// Save the classifiers
|
|
for i, c := range b.Models {
|
|
clsPrefix := fmt.Sprintf("%s/", pI("CLASSIFIERS", i))
|
|
err = c.SaveWithPrefix(writer, clsPrefix)
|
|
if err != nil {
|
|
return base.FormatError(err, "Can't save classifier %d", i)
|
|
}
|
|
}
|
|
|
|
// Save the instances template
|
|
err = writer.WriteInstancesForKey(writer.Prefix(prefix, "REFERENCE_INSTANCES"), b.fitOn, false)
|
|
if err != nil {
|
|
return base.DescribeError("Can't write REFERENCE_INSTANCES", err)
|
|
}
|
|
|
|
// Save the selectedAttributes map
|
|
selectedAttributesKey := writer.Prefix(prefix, "SELECTED_ATTRIBUTES")
|
|
ser := make(map[int][][]byte)
|
|
for key := range b.selectedAttributes {
|
|
ser[key] = make([][]byte, 0)
|
|
for _, a := range b.selectedAttributes[key] {
|
|
bytes, err := base.SerializeAttribute(a)
|
|
if err != nil {
|
|
return base.DescribeError("Can't serialize Attribute", err)
|
|
}
|
|
ser[key] = append(ser[key], bytes)
|
|
}
|
|
}
|
|
|
|
err = writer.WriteJSONForKey(selectedAttributesKey, ser)
|
|
if err != nil {
|
|
return base.DescribeError("Can't write selected attributes map", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Remember: have to add the models before you use this.
|
|
func (b *BaggedModel) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
|
|
b.lock.Lock()
|
|
defer b.lock.Unlock()
|
|
|
|
pI := func(n string, i int) string {
|
|
return reader.Prefix(prefix, reader.Prefix(n, fmt.Sprintf("%d", i)))
|
|
}
|
|
|
|
// Read the essential info
|
|
randomFeaturesKey := reader.Prefix(prefix, "NUM_RANDOM_FEATURES")
|
|
randomFeatures, err := reader.GetU64ForKey(randomFeaturesKey)
|
|
if err != nil {
|
|
return base.DescribeError("Can't read NUM_RANDOM_FEATURES", err)
|
|
}
|
|
|
|
b.RandomFeatures = int(randomFeatures)
|
|
|
|
// Reload the classifiers
|
|
for i, m := range b.Models {
|
|
clsPrefix := fmt.Sprintf("%s/", pI("CLASSIFIERS", i))
|
|
err := m.LoadWithPrefix(reader, clsPrefix)
|
|
if err != nil {
|
|
return base.DescribeError("Can't read classifier", err)
|
|
}
|
|
}
|
|
|
|
// Load the instances template
|
|
tmp, err := reader.GetInstancesForKey(reader.Prefix(prefix, "REFERENCE_INSTANCES"))
|
|
if err != nil {
|
|
return base.DescribeError("Can't read REFERENCE_INSTACES", err)
|
|
}
|
|
b.fitOn = tmp
|
|
|
|
// Reload the selected attributes
|
|
selectedAttributesKey := reader.Prefix(prefix, "SELECTED_ATTRIBUTES")
|
|
ser := make(map[int][][]byte)
|
|
err = reader.GetJSONForKey(selectedAttributesKey, &ser)
|
|
if err != nil {
|
|
return base.DescribeError("Can't reload selected attributes", err)
|
|
}
|
|
|
|
b.selectedAttributes = make(map[int][]base.Attribute)
|
|
|
|
for key := range ser {
|
|
b.selectedAttributes[key] = make([]base.Attribute, len(ser[key]))
|
|
for i, attrBytes := range ser[key] {
|
|
attr, err := base.DeserializeAttribute(attrBytes)
|
|
if err != nil {
|
|
return base.DescribeError("Can't deserialize attribute", err)
|
|
}
|
|
attrNew, err := base.ReplaceDeserializedAttributeWithVersionFromInstances(attr, b.fitOn)
|
|
if err != nil {
|
|
return base.DescribeError("Can't replace attribute", err)
|
|
}
|
|
b.selectedAttributes[key][i] = attrNew
|
|
}
|
|
}
|
|
return nil
|
|
}
|