1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-30 13:48:57 +08:00

ensemble: merge from v2-instances

This commit is contained in:
Richard Townsend 2014-08-02 16:22:14 +01:00
parent 2bb7c2de75
commit 2e5deb1476
3 changed files with 27 additions and 23 deletions

View File

@ -1,13 +1,12 @@
/* //
//
Ensemble contains classifiers which combine other classifiers. // Ensemble contains classifiers which combine other classifiers.
//
RandomForest: // RandomForest:
Generates ForestSize bagged decision trees (currently ID3-based) // Generates ForestSize bagged decision trees (currently ID3-based)
each considering a fixed number of random features. // each considering a fixed number of random features.
//
Built on meta.Bagging // Built on meta.Bagging
//
*/
package ensemble package ensemble

View File

@ -8,7 +8,7 @@ import (
) )
// RandomForest classifies instances using an ensemble // RandomForest classifies instances using an ensemble
// of bagged random decision trees // of bagged random decision trees.
type RandomForest struct { type RandomForest struct {
base.BaseClassifier base.BaseClassifier
ForestSize int ForestSize int
@ -18,7 +18,7 @@ type RandomForest struct {
// NewRandomForest generates and returns a new random forest. // NewRandomForest generates and returns a new random forest.
// forestSize controls the number of trees that get built // forestSize controls the number of trees that get built
// features controls the number of features used to build each tree // features controls the number of features used to build each tree.
func NewRandomForest(forestSize int, features int) *RandomForest { func NewRandomForest(forestSize int, features int) *RandomForest {
ret := &RandomForest{ ret := &RandomForest{
base.BaseClassifier{}, base.BaseClassifier{},
@ -30,7 +30,7 @@ func NewRandomForest(forestSize int, features int) *RandomForest {
} }
// Fit builds the RandomForest on the specified instances // Fit builds the RandomForest on the specified instances
func (f *RandomForest) Fit(on *base.Instances) { func (f *RandomForest) Fit(on base.FixedDataGrid) {
f.Model = new(meta.BaggedModel) f.Model = new(meta.BaggedModel)
f.Model.RandomFeatures = f.Features f.Model.RandomFeatures = f.Features
for i := 0; i < f.ForestSize; i++ { for i := 0; i < f.ForestSize; i++ {
@ -40,11 +40,12 @@ func (f *RandomForest) Fit(on *base.Instances) {
f.Model.Fit(on) f.Model.Fit(on)
} }
// Predict generates predictions from a trained RandomForest // Predict generates predictions from a trained RandomForest.
func (f *RandomForest) Predict(with *base.Instances) *base.Instances { func (f *RandomForest) Predict(with base.FixedDataGrid) base.FixedDataGrid {
return f.Model.Predict(with) return f.Model.Predict(with)
} }
// String returns a human-readable representation of this forest.
func (f *RandomForest) String() string { func (f *RandomForest) String() string {
return fmt.Sprintf("RandomForest(ForestSize: %d, Features:%d, %s\n)", f.ForestSize, f.Features, f.Model) return fmt.Sprintf("RandomForest(ForestSize: %d, Features:%d, %s\n)", f.ForestSize, f.Features, f.Model)
} }

View File

@ -13,12 +13,16 @@ func TestRandomForest1(testEnv *testing.T) {
if err != nil { if err != nil {
panic(err) panic(err)
} }
trainData, testData := base.InstancesTrainTestSplit(inst, 0.60)
filt := filters.NewChiMergeFilter(trainData, 0.90) filt := filters.NewChiMergeFilter(inst, 0.90)
filt.AddAllNumericAttributes() for _, a := range base.NonClassFloatAttributes(inst) {
filt.Build() filt.AddAttribute(a)
filt.Run(testData) }
filt.Run(trainData) filt.Train()
instf := base.NewLazilyFilteredInstances(inst, filt)
trainData, testData := base.InstancesTrainTestSplit(instf, 0.60)
rf := NewRandomForest(10, 3) rf := NewRandomForest(10, 3)
rf.Fit(trainData) rf.Fit(trainData)
predictions := rf.Predict(testData) predictions := rf.Predict(testData)