mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-28 13:48:56 +08:00
meta: merge from v2-instances
This commit is contained in:
parent
77596a32ed
commit
3b3b23a221
@ -21,23 +21,18 @@ type BaggedModel struct {
|
||||
|
||||
// generateTrainingAttrs selects RandomFeatures number of base.Attributes from
|
||||
// the provided base.Instances.
|
||||
func (b *BaggedModel) generateTrainingAttrs(model int, from *base.Instances) []base.Attribute {
|
||||
func (b *BaggedModel) generateTrainingAttrs(model int, from base.FixedDataGrid) []base.Attribute {
|
||||
ret := make([]base.Attribute, 0)
|
||||
attrs := base.NonClassAttributes(from)
|
||||
if b.RandomFeatures == 0 {
|
||||
for j := 0; j < from.Cols; j++ {
|
||||
attr := from.GetAttr(j)
|
||||
ret = append(ret, attr)
|
||||
}
|
||||
ret = attrs
|
||||
} else {
|
||||
for {
|
||||
if len(ret) >= b.RandomFeatures {
|
||||
break
|
||||
}
|
||||
attrIndex := rand.Intn(from.Cols)
|
||||
if attrIndex == from.ClassIndex {
|
||||
continue
|
||||
}
|
||||
attr := from.GetAttr(attrIndex)
|
||||
attrIndex := rand.Intn(len(attrs))
|
||||
attr := attrs[attrIndex]
|
||||
matched := false
|
||||
for _, a := range ret {
|
||||
if a.Equals(attr) {
|
||||
@ -50,7 +45,9 @@ func (b *BaggedModel) generateTrainingAttrs(model int, from *base.Instances) []b
|
||||
}
|
||||
}
|
||||
}
|
||||
ret = append(ret, from.GetClassAttr())
|
||||
for _, a := range from.AllClassAttributes() {
|
||||
ret = append(ret, a)
|
||||
}
|
||||
b.lock.Lock()
|
||||
b.selectedAttributes[model] = ret
|
||||
b.lock.Unlock()
|
||||
@ -60,18 +57,19 @@ func (b *BaggedModel) generateTrainingAttrs(model int, from *base.Instances) []b
|
||||
// generatePredictionInstances returns a modified version of the
|
||||
// requested base.Instances with only the base.Attributes selected
|
||||
// for training the model.
|
||||
func (b *BaggedModel) generatePredictionInstances(model int, from *base.Instances) *base.Instances {
|
||||
func (b *BaggedModel) generatePredictionInstances(model int, from base.FixedDataGrid) base.FixedDataGrid {
|
||||
selected := b.selectedAttributes[model]
|
||||
return from.SelectAttributes(selected)
|
||||
return base.NewInstancesViewFromAttrs(from, selected)
|
||||
}
|
||||
|
||||
// generateTrainingInstances generates RandomFeatures number of
|
||||
// attributes and returns a modified version of base.Instances
|
||||
// for training the model
|
||||
func (b *BaggedModel) generateTrainingInstances(model int, from *base.Instances) *base.Instances {
|
||||
insts := from.SampleWithReplacement(from.Rows)
|
||||
func (b *BaggedModel) generateTrainingInstances(model int, from base.FixedDataGrid) base.FixedDataGrid {
|
||||
_, rows := from.Size()
|
||||
insts := base.SampleWithReplacement(from, rows)
|
||||
selected := b.generateTrainingAttrs(model, from)
|
||||
return insts.SelectAttributes(selected)
|
||||
return base.NewInstancesViewFromAttrs(insts, selected)
|
||||
}
|
||||
|
||||
// AddModel adds a base.Classifier to the current model
|
||||
@ -81,12 +79,12 @@ func (b *BaggedModel) AddModel(m base.Classifier) {
|
||||
|
||||
// Fit generates and trains each model on a randomised subset of
|
||||
// Instances.
|
||||
func (b *BaggedModel) Fit(from *base.Instances) {
|
||||
func (b *BaggedModel) Fit(from base.FixedDataGrid) {
|
||||
var wait sync.WaitGroup
|
||||
b.selectedAttributes = make(map[int][]base.Attribute)
|
||||
for i, m := range b.Models {
|
||||
wait.Add(1)
|
||||
go func(c base.Classifier, f *base.Instances, model int) {
|
||||
go func(c base.Classifier, f base.FixedDataGrid, model int) {
|
||||
l := b.generateTrainingInstances(model, f)
|
||||
c.Fit(l)
|
||||
wait.Done()
|
||||
@ -100,10 +98,10 @@ func (b *BaggedModel) Fit(from *base.Instances) {
|
||||
//
|
||||
// IMPORTANT: in the event of a tie, the first class which
|
||||
// achieved the tie value is output.
|
||||
func (b *BaggedModel) Predict(from *base.Instances) *base.Instances {
|
||||
func (b *BaggedModel) Predict(from base.FixedDataGrid) base.FixedDataGrid {
|
||||
n := runtime.NumCPU()
|
||||
// Channel to receive the results as they come in
|
||||
votes := make(chan *base.Instances, n)
|
||||
votes := make(chan base.DataGrid, n)
|
||||
// Count the votes for each class
|
||||
voting := make(map[int](map[string]int))
|
||||
|
||||
@ -111,21 +109,20 @@ func (b *BaggedModel) Predict(from *base.Instances) *base.Instances {
|
||||
var votingwait sync.WaitGroup
|
||||
votingwait.Add(1)
|
||||
go func() {
|
||||
for {
|
||||
for { // Need to resolve the voting problem
|
||||
incoming, ok := <-votes
|
||||
if ok {
|
||||
// Step through each prediction
|
||||
for j := 0; j < incoming.Rows; j++ {
|
||||
cSpecs := base.ResolveAllAttributes(incoming, incoming.AllClassAttributes())
|
||||
incoming.MapOverRows(cSpecs, func(row [][]byte, predRow int) (bool, error) {
|
||||
// Check if we've seen this class before...
|
||||
if _, ok := voting[j]; !ok {
|
||||
if _, ok := voting[predRow]; !ok {
|
||||
// If we haven't, create an entry
|
||||
voting[j] = make(map[string]int)
|
||||
voting[predRow] = make(map[string]int)
|
||||
// Continue on the current row
|
||||
j--
|
||||
continue
|
||||
}
|
||||
voting[j][incoming.GetClass(j)]++
|
||||
}
|
||||
voting[predRow][base.GetClass(incoming, predRow)]++
|
||||
return true, nil
|
||||
})
|
||||
} else {
|
||||
votingwait.Done()
|
||||
break
|
||||
@ -162,7 +159,7 @@ func (b *BaggedModel) Predict(from *base.Instances) *base.Instances {
|
||||
votingwait.Wait() // All the votes are in
|
||||
|
||||
// Generate the overall consensus
|
||||
ret := from.GeneratePredictionVector()
|
||||
ret := base.GeneratePredictionVector(from)
|
||||
for i := range voting {
|
||||
maxClass := ""
|
||||
maxCount := 0
|
||||
@ -174,7 +171,7 @@ func (b *BaggedModel) Predict(from *base.Instances) *base.Instances {
|
||||
maxCount = votes
|
||||
}
|
||||
}
|
||||
ret.SetAttrStr(i, 0, maxClass)
|
||||
base.SetClass(ret, i, maxClass)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
@ -19,16 +19,18 @@ func BenchmarkBaggingRandomForestFit(testEnv *testing.B) {
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
filt.Run(inst)
|
||||
for _, a := range base.NonClassFloatAttributes(inst) {
|
||||
filt.AddAttribute(a)
|
||||
}
|
||||
filt.Train()
|
||||
instf := base.NewLazilyFilteredInstances(inst, filt)
|
||||
rf := new(BaggedModel)
|
||||
for i := 0; i < 10; i++ {
|
||||
rf.AddModel(trees.NewRandomTree(2))
|
||||
}
|
||||
testEnv.ResetTimer()
|
||||
for i := 0; i < 20; i++ {
|
||||
rf.Fit(inst)
|
||||
rf.Fit(instf)
|
||||
}
|
||||
}
|
||||
|
||||
@ -40,17 +42,19 @@ func BenchmarkBaggingRandomForestPredict(testEnv *testing.B) {
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
filt.Run(inst)
|
||||
for _, a := range base.NonClassFloatAttributes(inst) {
|
||||
filt.AddAttribute(a)
|
||||
}
|
||||
filt.Train()
|
||||
instf := base.NewLazilyFilteredInstances(inst, filt)
|
||||
rf := new(BaggedModel)
|
||||
for i := 0; i < 10; i++ {
|
||||
rf.AddModel(trees.NewRandomTree(2))
|
||||
}
|
||||
rf.Fit(inst)
|
||||
rf.Fit(instf)
|
||||
testEnv.ResetTimer()
|
||||
for i := 0; i < 20; i++ {
|
||||
rf.Predict(inst)
|
||||
rf.Predict(instf)
|
||||
}
|
||||
}
|
||||
|
||||
@ -63,19 +67,21 @@ func TestRandomForest1(testEnv *testing.T) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
|
||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
filt.Run(testData)
|
||||
filt.Run(trainData)
|
||||
for _, a := range base.NonClassFloatAttributes(inst) {
|
||||
filt.AddAttribute(a)
|
||||
}
|
||||
filt.Train()
|
||||
trainDataf := base.NewLazilyFilteredInstances(trainData, filt)
|
||||
testDataf := base.NewLazilyFilteredInstances(testData, filt)
|
||||
rf := new(BaggedModel)
|
||||
for i := 0; i < 10; i++ {
|
||||
rf.AddModel(trees.NewRandomTree(2))
|
||||
}
|
||||
rf.Fit(trainData)
|
||||
rf.Fit(trainDataf)
|
||||
fmt.Println(rf)
|
||||
predictions := rf.Predict(testData)
|
||||
predictions := rf.Predict(testDataf)
|
||||
fmt.Println(predictions)
|
||||
confusionMat := eval.GetConfusionMatrix(testData, predictions)
|
||||
confusionMat := eval.GetConfusionMatrix(testDataf, predictions)
|
||||
fmt.Println(confusionMat)
|
||||
fmt.Println(eval.GetMacroPrecision(confusionMat))
|
||||
fmt.Println(eval.GetMacroRecall(confusionMat))
|
||||
|
13
meta/meta.go
Normal file
13
meta/meta.go
Normal file
@ -0,0 +1,13 @@
|
||||
/*
|
||||
|
||||
Meta contains base.Classifier implementations which
|
||||
combine the outputs of others defined elsewhere.
|
||||
|
||||
Bagging:
|
||||
Bootstraps samples of the original training set
|
||||
with a number of selected attributes, and uses
|
||||
that to train an ensemble of models. Predictions
|
||||
are generated via majority voting.
|
||||
*/
|
||||
|
||||
package meta
|
Loading…
x
Reference in New Issue
Block a user