mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
ChiMerge seems to improve accuracy
This commit is contained in:
parent
fdb67a4355
commit
cf165695c8
@ -31,10 +31,11 @@ func NewRandomForest(forestSize int, features int) RandomForest {
|
||||
// Train builds the RandomForest on the specified instances
|
||||
func (f *RandomForest) Train(on *base.Instances) {
|
||||
f.Model = new(meta.BaggedModel)
|
||||
f.Model.RandomFeatures = f.Features
|
||||
f.Model.SelectedFeatures = make(map[int][]base.Attribute)
|
||||
for i := 0; i < f.ForestSize; i++ {
|
||||
f.Model.AddModel(new(trees.RandomTree))
|
||||
tree := new(trees.RandomTree)
|
||||
tree.Rules = new(trees.RandomTreeRule)
|
||||
tree.Attributes = f.Features
|
||||
f.Model.AddModel(tree)
|
||||
}
|
||||
f.Model.Train(on)
|
||||
}
|
||||
|
@ -43,6 +43,21 @@ func (c *ChiMergeFilter) Build() {
|
||||
}
|
||||
}
|
||||
|
||||
// AddAllNumericAttributes adds every suitable attribute
|
||||
// to the ChiMergeFilter for discretisation
|
||||
func (b *ChiMergeFilter) AddAllNumericAttributes() {
|
||||
for i := 0; i < b.Instances.Cols; i++ {
|
||||
if i == b.Instances.ClassIndex {
|
||||
continue
|
||||
}
|
||||
attr := b.Instances.GetAttr(i)
|
||||
if attr.GetType() != base.Float64Type {
|
||||
continue
|
||||
}
|
||||
b.Attributes = append(b.Attributes, i)
|
||||
}
|
||||
}
|
||||
|
||||
// Run discretises the set of Instances `on'
|
||||
//
|
||||
// IMPORTANT: ChiMergeFilter discretises in place.
|
||||
|
@ -3,70 +3,20 @@ package meta
|
||||
import (
|
||||
"fmt"
|
||||
base "github.com/sjwhitworth/golearn/base"
|
||||
"math/rand"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BaggedModel trains Classifiers on subsets of the original
|
||||
// Instances and combines the results through voting
|
||||
type BaggedModel struct {
|
||||
base.BaseClassifier
|
||||
Models []base.Classifier
|
||||
SelectedFeatures map[int][]base.Attribute
|
||||
// If this is greater than 0, select up to d features
|
||||
// for feeding into each classifier
|
||||
RandomFeatures int
|
||||
Models []base.Classifier
|
||||
}
|
||||
|
||||
func (b *BaggedModel) generateRandomAttributes(from *base.Instances) []base.Attribute {
|
||||
if b.RandomFeatures > from.GetAttributeCount()-1 {
|
||||
panic("Can't have more random features")
|
||||
}
|
||||
ret := make([]base.Attribute, 0)
|
||||
for {
|
||||
if len(ret) > b.RandomFeatures {
|
||||
break
|
||||
}
|
||||
attrIndex := rand.Intn(from.GetAttributeCount())
|
||||
if attrIndex == from.ClassIndex {
|
||||
continue
|
||||
}
|
||||
matched := false
|
||||
newAttr := from.GetAttr(attrIndex)
|
||||
for _, a := range ret {
|
||||
if a.Equals(newAttr) {
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !matched {
|
||||
ret = append(ret, newAttr)
|
||||
}
|
||||
}
|
||||
ret = append(ret, from.GetClassAttr())
|
||||
return ret
|
||||
}
|
||||
|
||||
func (b *BaggedModel) generateTrainingInstances(from *base.Instances) ([]base.Attribute, *base.Instances) {
|
||||
|
||||
var attrs []base.Attribute
|
||||
func (b *BaggedModel) generateTrainingInstances(from *base.Instances) *base.Instances {
|
||||
from = from.SampleWithReplacement(from.Rows)
|
||||
|
||||
if b.RandomFeatures > 0 {
|
||||
attrs = b.generateRandomAttributes(from)
|
||||
from = from.SelectAttributes(attrs)
|
||||
} else {
|
||||
attrs = make([]base.Attribute, 0)
|
||||
}
|
||||
|
||||
return attrs, from
|
||||
}
|
||||
|
||||
func (b *BaggedModel) generateTestingInstances(from *base.Instances, model int) *base.Instances {
|
||||
attrs := b.SelectedFeatures[model]
|
||||
return from.SelectAttributes(attrs)
|
||||
return from
|
||||
}
|
||||
|
||||
func (b *BaggedModel) AddModel(m base.Classifier) {
|
||||
@ -78,11 +28,9 @@ func (b *BaggedModel) AddModel(m base.Classifier) {
|
||||
func (b *BaggedModel) Fit(from *base.Instances) {
|
||||
n := runtime.GOMAXPROCS(0)
|
||||
block := make(chan bool, n)
|
||||
for i, m := range b.Models {
|
||||
for _, m := range b.Models {
|
||||
go func(c base.Classifier, f *base.Instances) {
|
||||
a, f := b.generateTrainingInstances(f)
|
||||
b.SelectedFeatures[i] = a
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
f = b.generateTrainingInstances(f)
|
||||
c.Fit(f)
|
||||
block <- true
|
||||
}(m, from)
|
||||
@ -102,9 +50,8 @@ func (b *BaggedModel) Predict(from *base.Instances) *base.Instances {
|
||||
// Channel to receive the results as they come in
|
||||
votes := make(chan *base.Instances, n)
|
||||
// Dispatch prediction generation
|
||||
for i, m := range b.Models {
|
||||
for _, m := range b.Models {
|
||||
go func(c base.Classifier, f *base.Instances) {
|
||||
f = b.generateTestingInstances(f, i)
|
||||
p := c.Predict(f)
|
||||
votes <- p
|
||||
}(m, from)
|
||||
|
@ -19,14 +19,12 @@ func TestRandomForest1(testEnv *testing.T) {
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
insts := base.InstancesTrainTestSplit(inst, 0.6)
|
||||
filt := filters.NewBinningFilter(insts[0], 10)
|
||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
filt.Run(insts[1])
|
||||
filt.Run(insts[0])
|
||||
rf := new(BaggedModel)
|
||||
rf.RandomFeatures = 2
|
||||
rf.SelectedFeatures = make(map[int][]base.Attribute)
|
||||
for i := 0; i < 10; i++ {
|
||||
rf.AddModel(trees.NewRandomTree(2))
|
||||
}
|
||||
|
@ -14,7 +14,8 @@ func TestRandomTree(testEnv *testing.T) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
filt := filters.NewBinningFilter(inst, 10)
|
||||
|
||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
filt.Run(inst)
|
||||
@ -31,7 +32,7 @@ func TestRandomTreeClassification(testEnv *testing.T) {
|
||||
panic(err)
|
||||
}
|
||||
insts := base.InstancesTrainTestSplit(inst, 0.6)
|
||||
filt := filters.NewBinningFilter(insts[0], 10)
|
||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
filt.Run(insts[0])
|
||||
@ -56,7 +57,7 @@ func TestRandomTreeClassification2(testEnv *testing.T) {
|
||||
panic(err)
|
||||
}
|
||||
insts := base.InstancesTrainTestSplit(inst, 0.6)
|
||||
filt := filters.NewBinningFilter(insts[0], 10)
|
||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
fmt.Println(insts[1])
|
||||
|
Loading…
x
Reference in New Issue
Block a user