1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00

ChiMerge seems to improve accuracy

This commit is contained in:
Richard Townsend 2014-05-17 16:20:56 +01:00
parent fdb67a4355
commit cf165695c8
5 changed files with 30 additions and 68 deletions

View File

@ -31,10 +31,11 @@ func NewRandomForest(forestSize int, features int) RandomForest {
// Train builds the RandomForest on the specified instances.
//
// It replaces f.Model with a new meta.BaggedModel, adds RandomTree models
// to it inside a loop that runs f.ForestSize times, and finally calls
// Train on the bagged model with `on`.
//
// NOTE(review): the hunk header for this section reports 10 old and 11 new
// lines while 14 lines are listed, so this listing appears to interleave
// removed and added statements without +/- markers. Confirm against the
// repository which statements are live before relying on it.
func (f *RandomForest) Train(on *base.Instances) {
f.Model = new(meta.BaggedModel)
// Pass the forest's feature count on to the bagged model and give it an
// empty per-model attribute map.
f.Model.RandomFeatures = f.Features
f.Model.SelectedFeatures = make(map[int][]base.Attribute)
for i := 0; i < f.ForestSize; i++ {
// Adds a RandomTree with no further configuration.
f.Model.AddModel(new(trees.RandomTree))
// Adds a RandomTree whose Rules and Attributes are set explicitly;
// Attributes receives f.Features.
tree := new(trees.RandomTree)
tree.Rules = new(trees.RandomTreeRule)
tree.Attributes = f.Features
f.Model.AddModel(tree)
}
f.Model.Train(on)
}

View File

@ -43,6 +43,21 @@ func (c *ChiMergeFilter) Build() {
}
}
// AddAllNumericAttributes walks every column of the filter's Instances and
// appends the index of each attribute whose type is base.Float64Type to
// the list of attributes to discretise. The class column is never added.
func (c *ChiMergeFilter) AddAllNumericAttributes() {
	for col := 0; col < c.Instances.Cols; col++ {
		// The class check comes first so the class attribute is never
		// fetched or inspected.
		if col != c.Instances.ClassIndex && c.Instances.GetAttr(col).GetType() == base.Float64Type {
			c.Attributes = append(c.Attributes, col)
		}
	}
}
// Run discretises the set of Instances `on'
//
// IMPORTANT: ChiMergeFilter discretises in place.

View File

@ -3,70 +3,20 @@ package meta
import (
"fmt"
base "github.com/sjwhitworth/golearn/base"
"math/rand"
"runtime"
"strings"
"time"
)
// BaggedModel trains Classifiers on subsets of the original
// Instances and combines the results through voting.
//
// NOTE(review): the Models field is listed twice below, so this listing
// appears to show both the removed and the added form of the struct from
// the diff. Confirm against the repository which fields are current.
type BaggedModel struct {
base.BaseClassifier
// Models are the classifiers that Fit and Predict iterate over.
Models []base.Classifier
// SelectedFeatures maps a model's index in Models to the attributes
// chosen for it; Fit writes it and generateTestingInstances reads it.
SelectedFeatures map[int][]base.Attribute
// If this is greater than 0, select up to d features
// for feeding into each classifier
RandomFeatures int
Models []base.Classifier
}
// generateRandomAttributes picks b.RandomFeatures distinct non-class
// attributes of `from` at random and returns them followed by the class
// attribute as the final element.
//
// It panics if b.RandomFeatures is larger than the number of non-class
// attributes (from.GetAttributeCount()-1).
func (b *BaggedModel) generateRandomAttributes(from *base.Instances) []base.Attribute {
	if b.RandomFeatures > from.GetAttributeCount()-1 {
		panic("Can't have more random features")
	}
	var ret []base.Attribute
	// Collect exactly RandomFeatures attributes. Asking for one more could
	// never finish when RandomFeatures equals the number of non-class
	// attributes, because there would be no unused attribute left to draw.
	for len(ret) < b.RandomFeatures {
		attrIndex := rand.Intn(from.GetAttributeCount())
		if attrIndex == from.ClassIndex {
			// The class attribute is appended once at the end instead.
			continue
		}
		newAttr := from.GetAttr(attrIndex)
		// Reject attributes that were already drawn.
		matched := false
		for _, a := range ret {
			if a.Equals(newAttr) {
				matched = true
				break
			}
		}
		if !matched {
			ret = append(ret, newAttr)
		}
	}
	ret = append(ret, from.GetClassAttr())
	return ret
}
// generateTrainingInstances resamples `from` with replacement, requesting
// from.Rows rows, and returns the resampled Instances.
//
// NOTE(review): two signatures for generateTrainingInstances follow one
// another here and generateTestingInstances shares its closing lines with
// them, so this listing appears to interleave the removed and the added
// text of the diff. Confirm against the repository which lines are live.
func (b *BaggedModel) generateTrainingInstances(from *base.Instances) ([]base.Attribute, *base.Instances) {
var attrs []base.Attribute
func (b *BaggedModel) generateTrainingInstances(from *base.Instances) *base.Instances {
from = from.SampleWithReplacement(from.Rows)
// With RandomFeatures > 0 the sample is narrowed to a randomly chosen
// attribute list; otherwise an empty attribute list is reported.
if b.RandomFeatures > 0 {
attrs = b.generateRandomAttributes(from)
from = from.SelectAttributes(attrs)
} else {
attrs = make([]base.Attribute, 0)
}
return attrs, from
}
// generateTestingInstances looks up the attributes stored in
// b.SelectedFeatures for the given model index and passes them to
// from.SelectAttributes.
func (b *BaggedModel) generateTestingInstances(from *base.Instances, model int) *base.Instances {
attrs := b.SelectedFeatures[model]
return from.SelectAttributes(attrs)
return from
}
func (b *BaggedModel) AddModel(m base.Classifier) {
@ -78,11 +28,9 @@ func (b *BaggedModel) AddModel(m base.Classifier) {
func (b *BaggedModel) Fit(from *base.Instances) {
n := runtime.GOMAXPROCS(0)
block := make(chan bool, n)
for i, m := range b.Models {
for _, m := range b.Models {
go func(c base.Classifier, f *base.Instances) {
a, f := b.generateTrainingInstances(f)
b.SelectedFeatures[i] = a
rand.Seed(time.Now().UnixNano())
f = b.generateTrainingInstances(f)
c.Fit(f)
block <- true
}(m, from)
@ -102,9 +50,8 @@ func (b *BaggedModel) Predict(from *base.Instances) *base.Instances {
// Channel to receive the results as they come in
votes := make(chan *base.Instances, n)
// Dispatch prediction generation
for i, m := range b.Models {
for _, m := range b.Models {
go func(c base.Classifier, f *base.Instances) {
f = b.generateTestingInstances(f, i)
p := c.Predict(f)
votes <- p
}(m, from)

View File

@ -19,14 +19,12 @@ func TestRandomForest1(testEnv *testing.T) {
rand.Seed(time.Now().UnixNano())
insts := base.InstancesTrainTestSplit(inst, 0.6)
filt := filters.NewBinningFilter(insts[0], 10)
filt := filters.NewChiMergeFilter(inst, 0.90)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(insts[1])
filt.Run(insts[0])
rf := new(BaggedModel)
rf.RandomFeatures = 2
rf.SelectedFeatures = make(map[int][]base.Attribute)
for i := 0; i < 10; i++ {
rf.AddModel(trees.NewRandomTree(2))
}

View File

@ -14,7 +14,8 @@ func TestRandomTree(testEnv *testing.T) {
if err != nil {
panic(err)
}
filt := filters.NewBinningFilter(inst, 10)
filt := filters.NewChiMergeFilter(inst, 0.90)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(inst)
@ -31,7 +32,7 @@ func TestRandomTreeClassification(testEnv *testing.T) {
panic(err)
}
insts := base.InstancesTrainTestSplit(inst, 0.6)
filt := filters.NewBinningFilter(insts[0], 10)
filt := filters.NewChiMergeFilter(inst, 0.90)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(insts[0])
@ -56,7 +57,7 @@ func TestRandomTreeClassification2(testEnv *testing.T) {
panic(err)
}
insts := base.InstancesTrainTestSplit(inst, 0.6)
filt := filters.NewBinningFilter(insts[0], 10)
filt := filters.NewChiMergeFilter(inst, 0.90)
filt.AddAllNumericAttributes()
filt.Build()
fmt.Println(insts[1])