1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-25 13:48:49 +08:00

ensemble: merge from v2-instances

This commit is contained in:
Richard Townsend 2014-08-02 16:22:14 +01:00
parent 2bb7c2de75
commit 2e5deb1476
3 changed files with 27 additions and 23 deletions

View File

@ -1,13 +1,12 @@
/*
//
//
// Ensemble contains classifiers which combine other classifiers.
//
// RandomForest:
// Generates ForestSize bagged decision trees (currently ID3-based)
// each considering a fixed number of random features.
//
// Built on meta.Bagging
//
Ensemble contains classifiers which combine other classifiers.
RandomForest:
Generates ForestSize bagged decision trees (currently ID3-based)
each considering a fixed number of random features.
Built on meta.Bagging
*/
package ensemble
package ensemble

View File

@ -8,7 +8,7 @@ import (
)
// RandomForest classifies instances using an ensemble
// of bagged random decision trees
// of bagged random decision trees.
type RandomForest struct {
base.BaseClassifier
ForestSize int
@ -18,7 +18,7 @@ type RandomForest struct {
// NewRandomForest generates and returns a new random forest.
// forestSize controls the number of trees that get built
// features controls the number of features used to build each tree
// features controls the number of features used to build each tree.
func NewRandomForest(forestSize int, features int) *RandomForest {
ret := &RandomForest{
base.BaseClassifier{},
@ -30,7 +30,7 @@ func NewRandomForest(forestSize int, features int) *RandomForest {
}
// Fit builds the RandomForest on the specified instances
func (f *RandomForest) Fit(on *base.Instances) {
func (f *RandomForest) Fit(on base.FixedDataGrid) {
f.Model = new(meta.BaggedModel)
f.Model.RandomFeatures = f.Features
for i := 0; i < f.ForestSize; i++ {
@ -40,11 +40,12 @@ func (f *RandomForest) Fit(on *base.Instances) {
f.Model.Fit(on)
}
// Predict generates predictions from a trained RandomForest
func (f *RandomForest) Predict(with *base.Instances) *base.Instances {
// Predict generates predictions from a trained RandomForest.
func (f *RandomForest) Predict(with base.FixedDataGrid) base.FixedDataGrid {
return f.Model.Predict(with)
}
// String renders this forest as text: the ForestSize and Features settings
// followed by the formatted underlying Model.
func (f *RandomForest) String() string {
	// Layout is kept in one place so the output shape is easy to see at a glance.
	const layout = "RandomForest(ForestSize: %d, Features:%d, %s\n)"
	return fmt.Sprintf(layout, f.ForestSize, f.Features, f.Model)
}

View File

@ -13,12 +13,16 @@ func TestRandomForest1(testEnv *testing.T) {
if err != nil {
panic(err)
}
trainData, testData := base.InstancesTrainTestSplit(inst, 0.60)
filt := filters.NewChiMergeFilter(trainData, 0.90)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(testData)
filt.Run(trainData)
filt := filters.NewChiMergeFilter(inst, 0.90)
for _, a := range base.NonClassFloatAttributes(inst) {
filt.AddAttribute(a)
}
filt.Train()
instf := base.NewLazilyFilteredInstances(inst, filt)
trainData, testData := base.InstancesTrainTestSplit(instf, 0.60)
rf := NewRandomForest(10, 3)
rf.Fit(trainData)
predictions := rf.Predict(testData)