mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-28 13:48:56 +08:00
meta: merge from v2-instances
This commit is contained in:
parent
77596a32ed
commit
3b3b23a221
@ -21,23 +21,18 @@ type BaggedModel struct {
|
|||||||
|
|
||||||
// generateTrainingAttrs selects RandomFeatures number of base.Attributes from
|
// generateTrainingAttrs selects RandomFeatures number of base.Attributes from
|
||||||
// the provided base.Instances.
|
// the provided base.Instances.
|
||||||
func (b *BaggedModel) generateTrainingAttrs(model int, from *base.Instances) []base.Attribute {
|
func (b *BaggedModel) generateTrainingAttrs(model int, from base.FixedDataGrid) []base.Attribute {
|
||||||
ret := make([]base.Attribute, 0)
|
ret := make([]base.Attribute, 0)
|
||||||
|
attrs := base.NonClassAttributes(from)
|
||||||
if b.RandomFeatures == 0 {
|
if b.RandomFeatures == 0 {
|
||||||
for j := 0; j < from.Cols; j++ {
|
ret = attrs
|
||||||
attr := from.GetAttr(j)
|
|
||||||
ret = append(ret, attr)
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
for {
|
for {
|
||||||
if len(ret) >= b.RandomFeatures {
|
if len(ret) >= b.RandomFeatures {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
attrIndex := rand.Intn(from.Cols)
|
attrIndex := rand.Intn(len(attrs))
|
||||||
if attrIndex == from.ClassIndex {
|
attr := attrs[attrIndex]
|
||||||
continue
|
|
||||||
}
|
|
||||||
attr := from.GetAttr(attrIndex)
|
|
||||||
matched := false
|
matched := false
|
||||||
for _, a := range ret {
|
for _, a := range ret {
|
||||||
if a.Equals(attr) {
|
if a.Equals(attr) {
|
||||||
@ -50,7 +45,9 @@ func (b *BaggedModel) generateTrainingAttrs(model int, from *base.Instances) []b
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ret = append(ret, from.GetClassAttr())
|
for _, a := range from.AllClassAttributes() {
|
||||||
|
ret = append(ret, a)
|
||||||
|
}
|
||||||
b.lock.Lock()
|
b.lock.Lock()
|
||||||
b.selectedAttributes[model] = ret
|
b.selectedAttributes[model] = ret
|
||||||
b.lock.Unlock()
|
b.lock.Unlock()
|
||||||
@ -60,18 +57,19 @@ func (b *BaggedModel) generateTrainingAttrs(model int, from *base.Instances) []b
|
|||||||
// generatePredictionInstances returns a modified version of the
|
// generatePredictionInstances returns a modified version of the
|
||||||
// requested base.Instances with only the base.Attributes selected
|
// requested base.Instances with only the base.Attributes selected
|
||||||
// for training the model.
|
// for training the model.
|
||||||
func (b *BaggedModel) generatePredictionInstances(model int, from *base.Instances) *base.Instances {
|
func (b *BaggedModel) generatePredictionInstances(model int, from base.FixedDataGrid) base.FixedDataGrid {
|
||||||
selected := b.selectedAttributes[model]
|
selected := b.selectedAttributes[model]
|
||||||
return from.SelectAttributes(selected)
|
return base.NewInstancesViewFromAttrs(from, selected)
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateTrainingInstances generates RandomFeatures number of
|
// generateTrainingInstances generates RandomFeatures number of
|
||||||
// attributes and returns a modified version of base.Instances
|
// attributes and returns a modified version of base.Instances
|
||||||
// for training the model
|
// for training the model
|
||||||
func (b *BaggedModel) generateTrainingInstances(model int, from *base.Instances) *base.Instances {
|
func (b *BaggedModel) generateTrainingInstances(model int, from base.FixedDataGrid) base.FixedDataGrid {
|
||||||
insts := from.SampleWithReplacement(from.Rows)
|
_, rows := from.Size()
|
||||||
|
insts := base.SampleWithReplacement(from, rows)
|
||||||
selected := b.generateTrainingAttrs(model, from)
|
selected := b.generateTrainingAttrs(model, from)
|
||||||
return insts.SelectAttributes(selected)
|
return base.NewInstancesViewFromAttrs(insts, selected)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddModel adds a base.Classifier to the current model
|
// AddModel adds a base.Classifier to the current model
|
||||||
@ -81,12 +79,12 @@ func (b *BaggedModel) AddModel(m base.Classifier) {
|
|||||||
|
|
||||||
// Fit generates and trains each model on a randomised subset of
|
// Fit generates and trains each model on a randomised subset of
|
||||||
// Instances.
|
// Instances.
|
||||||
func (b *BaggedModel) Fit(from *base.Instances) {
|
func (b *BaggedModel) Fit(from base.FixedDataGrid) {
|
||||||
var wait sync.WaitGroup
|
var wait sync.WaitGroup
|
||||||
b.selectedAttributes = make(map[int][]base.Attribute)
|
b.selectedAttributes = make(map[int][]base.Attribute)
|
||||||
for i, m := range b.Models {
|
for i, m := range b.Models {
|
||||||
wait.Add(1)
|
wait.Add(1)
|
||||||
go func(c base.Classifier, f *base.Instances, model int) {
|
go func(c base.Classifier, f base.FixedDataGrid, model int) {
|
||||||
l := b.generateTrainingInstances(model, f)
|
l := b.generateTrainingInstances(model, f)
|
||||||
c.Fit(l)
|
c.Fit(l)
|
||||||
wait.Done()
|
wait.Done()
|
||||||
@ -100,10 +98,10 @@ func (b *BaggedModel) Fit(from *base.Instances) {
|
|||||||
//
|
//
|
||||||
// IMPORTANT: in the event of a tie, the first class which
|
// IMPORTANT: in the event of a tie, the first class which
|
||||||
// achieved the tie value is output.
|
// achieved the tie value is output.
|
||||||
func (b *BaggedModel) Predict(from *base.Instances) *base.Instances {
|
func (b *BaggedModel) Predict(from base.FixedDataGrid) base.FixedDataGrid {
|
||||||
n := runtime.NumCPU()
|
n := runtime.NumCPU()
|
||||||
// Channel to receive the results as they come in
|
// Channel to receive the results as they come in
|
||||||
votes := make(chan *base.Instances, n)
|
votes := make(chan base.DataGrid, n)
|
||||||
// Count the votes for each class
|
// Count the votes for each class
|
||||||
voting := make(map[int](map[string]int))
|
voting := make(map[int](map[string]int))
|
||||||
|
|
||||||
@ -111,21 +109,20 @@ func (b *BaggedModel) Predict(from *base.Instances) *base.Instances {
|
|||||||
var votingwait sync.WaitGroup
|
var votingwait sync.WaitGroup
|
||||||
votingwait.Add(1)
|
votingwait.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for { // Need to resolve the voting problem
|
||||||
incoming, ok := <-votes
|
incoming, ok := <-votes
|
||||||
if ok {
|
if ok {
|
||||||
// Step through each prediction
|
cSpecs := base.ResolveAllAttributes(incoming, incoming.AllClassAttributes())
|
||||||
for j := 0; j < incoming.Rows; j++ {
|
incoming.MapOverRows(cSpecs, func(row [][]byte, predRow int) (bool, error) {
|
||||||
// Check if we've seen this class before...
|
// Check if we've seen this class before...
|
||||||
if _, ok := voting[j]; !ok {
|
if _, ok := voting[predRow]; !ok {
|
||||||
// If we haven't, create an entry
|
// If we haven't, create an entry
|
||||||
voting[j] = make(map[string]int)
|
voting[predRow] = make(map[string]int)
|
||||||
// Continue on the current row
|
// Continue on the current row
|
||||||
j--
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
voting[j][incoming.GetClass(j)]++
|
voting[predRow][base.GetClass(incoming, predRow)]++
|
||||||
}
|
return true, nil
|
||||||
|
})
|
||||||
} else {
|
} else {
|
||||||
votingwait.Done()
|
votingwait.Done()
|
||||||
break
|
break
|
||||||
@ -162,7 +159,7 @@ func (b *BaggedModel) Predict(from *base.Instances) *base.Instances {
|
|||||||
votingwait.Wait() // All the votes are in
|
votingwait.Wait() // All the votes are in
|
||||||
|
|
||||||
// Generate the overall consensus
|
// Generate the overall consensus
|
||||||
ret := from.GeneratePredictionVector()
|
ret := base.GeneratePredictionVector(from)
|
||||||
for i := range voting {
|
for i := range voting {
|
||||||
maxClass := ""
|
maxClass := ""
|
||||||
maxCount := 0
|
maxCount := 0
|
||||||
@ -174,7 +171,7 @@ func (b *BaggedModel) Predict(from *base.Instances) *base.Instances {
|
|||||||
maxCount = votes
|
maxCount = votes
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ret.SetAttrStr(i, 0, maxClass)
|
base.SetClass(ret, i, maxClass)
|
||||||
}
|
}
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
@ -19,16 +19,18 @@ func BenchmarkBaggingRandomForestFit(testEnv *testing.B) {
|
|||||||
|
|
||||||
rand.Seed(time.Now().UnixNano())
|
rand.Seed(time.Now().UnixNano())
|
||||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||||
filt.AddAllNumericAttributes()
|
for _, a := range base.NonClassFloatAttributes(inst) {
|
||||||
filt.Build()
|
filt.AddAttribute(a)
|
||||||
filt.Run(inst)
|
}
|
||||||
|
filt.Train()
|
||||||
|
instf := base.NewLazilyFilteredInstances(inst, filt)
|
||||||
rf := new(BaggedModel)
|
rf := new(BaggedModel)
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
rf.AddModel(trees.NewRandomTree(2))
|
rf.AddModel(trees.NewRandomTree(2))
|
||||||
}
|
}
|
||||||
testEnv.ResetTimer()
|
testEnv.ResetTimer()
|
||||||
for i := 0; i < 20; i++ {
|
for i := 0; i < 20; i++ {
|
||||||
rf.Fit(inst)
|
rf.Fit(instf)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -40,17 +42,19 @@ func BenchmarkBaggingRandomForestPredict(testEnv *testing.B) {
|
|||||||
|
|
||||||
rand.Seed(time.Now().UnixNano())
|
rand.Seed(time.Now().UnixNano())
|
||||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||||
filt.AddAllNumericAttributes()
|
for _, a := range base.NonClassFloatAttributes(inst) {
|
||||||
filt.Build()
|
filt.AddAttribute(a)
|
||||||
filt.Run(inst)
|
}
|
||||||
|
filt.Train()
|
||||||
|
instf := base.NewLazilyFilteredInstances(inst, filt)
|
||||||
rf := new(BaggedModel)
|
rf := new(BaggedModel)
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
rf.AddModel(trees.NewRandomTree(2))
|
rf.AddModel(trees.NewRandomTree(2))
|
||||||
}
|
}
|
||||||
rf.Fit(inst)
|
rf.Fit(instf)
|
||||||
testEnv.ResetTimer()
|
testEnv.ResetTimer()
|
||||||
for i := 0; i < 20; i++ {
|
for i := 0; i < 20; i++ {
|
||||||
rf.Predict(inst)
|
rf.Predict(instf)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -63,19 +67,21 @@ func TestRandomForest1(testEnv *testing.T) {
|
|||||||
rand.Seed(time.Now().UnixNano())
|
rand.Seed(time.Now().UnixNano())
|
||||||
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
|
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
|
||||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||||
filt.AddAllNumericAttributes()
|
for _, a := range base.NonClassFloatAttributes(inst) {
|
||||||
filt.Build()
|
filt.AddAttribute(a)
|
||||||
filt.Run(testData)
|
}
|
||||||
filt.Run(trainData)
|
filt.Train()
|
||||||
|
trainDataf := base.NewLazilyFilteredInstances(trainData, filt)
|
||||||
|
testDataf := base.NewLazilyFilteredInstances(testData, filt)
|
||||||
rf := new(BaggedModel)
|
rf := new(BaggedModel)
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
rf.AddModel(trees.NewRandomTree(2))
|
rf.AddModel(trees.NewRandomTree(2))
|
||||||
}
|
}
|
||||||
rf.Fit(trainData)
|
rf.Fit(trainDataf)
|
||||||
fmt.Println(rf)
|
fmt.Println(rf)
|
||||||
predictions := rf.Predict(testData)
|
predictions := rf.Predict(testDataf)
|
||||||
fmt.Println(predictions)
|
fmt.Println(predictions)
|
||||||
confusionMat := eval.GetConfusionMatrix(testData, predictions)
|
confusionMat := eval.GetConfusionMatrix(testDataf, predictions)
|
||||||
fmt.Println(confusionMat)
|
fmt.Println(confusionMat)
|
||||||
fmt.Println(eval.GetMacroPrecision(confusionMat))
|
fmt.Println(eval.GetMacroPrecision(confusionMat))
|
||||||
fmt.Println(eval.GetMacroRecall(confusionMat))
|
fmt.Println(eval.GetMacroRecall(confusionMat))
|
||||||
|
13
meta/meta.go
Normal file
13
meta/meta.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
/*
|
||||||
|
|
||||||
|
Meta contains base.Classifier implementations which
|
||||||
|
combine the outputs of others defined elsewhere.
|
||||||
|
|
||||||
|
Bagging:
|
||||||
|
Bootstraps samples of the original training set
|
||||||
|
with a number of selected attributes, and uses
|
||||||
|
that to train an ensemble of models. Predictions
|
||||||
|
are generated via majority voting.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package meta
|
Loading…
x
Reference in New Issue
Block a user