mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-25 13:48:49 +08:00
ensemble: merge from v2-instances
This commit is contained in:
parent
2bb7c2de75
commit
2e5deb1476
@ -1,13 +1,12 @@
|
||||
/*
|
||||
//
|
||||
//
|
||||
// Ensemble contains classifiers which combine other classifiers.
|
||||
//
|
||||
// RandomForest:
|
||||
// Generates ForestSize bagged decision trees (currently ID3-based)
|
||||
// each considering a fixed number of random features.
|
||||
//
|
||||
// Built on meta.Bagging
|
||||
//
|
||||
|
||||
Ensemble contains classifiers which combine other classifiers.
|
||||
|
||||
RandomForest:
|
||||
Generates ForestSize bagged decision trees (currently ID3-based)
|
||||
each considering a fixed number of random features.
|
||||
|
||||
Built on meta.Bagging
|
||||
|
||||
*/
|
||||
|
||||
package ensemble
|
||||
package ensemble
|
||||
|
@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
// RandomForest classifies instances using an ensemble
|
||||
// of bagged random decision trees
|
||||
// of bagged random decision trees.
|
||||
type RandomForest struct {
|
||||
base.BaseClassifier
|
||||
ForestSize int
|
||||
@ -18,7 +18,7 @@ type RandomForest struct {
|
||||
|
||||
// NewRandomForest generates and return a new random forests
|
||||
// forestSize controls the number of trees that get built
|
||||
// features controls the number of features used to build each tree
|
||||
// features controls the number of features used to build each tree.
|
||||
func NewRandomForest(forestSize int, features int) *RandomForest {
|
||||
ret := &RandomForest{
|
||||
base.BaseClassifier{},
|
||||
@ -30,7 +30,7 @@ func NewRandomForest(forestSize int, features int) *RandomForest {
|
||||
}
|
||||
|
||||
// Fit builds the RandomForest on the specified instances
|
||||
func (f *RandomForest) Fit(on *base.Instances) {
|
||||
func (f *RandomForest) Fit(on base.FixedDataGrid) {
|
||||
f.Model = new(meta.BaggedModel)
|
||||
f.Model.RandomFeatures = f.Features
|
||||
for i := 0; i < f.ForestSize; i++ {
|
||||
@ -40,11 +40,12 @@ func (f *RandomForest) Fit(on *base.Instances) {
|
||||
f.Model.Fit(on)
|
||||
}
|
||||
|
||||
// Predict generates predictions from a trained RandomForest
|
||||
func (f *RandomForest) Predict(with *base.Instances) *base.Instances {
|
||||
// Predict generates predictions from a trained RandomForest.
|
||||
func (f *RandomForest) Predict(with base.FixedDataGrid) base.FixedDataGrid {
|
||||
return f.Model.Predict(with)
|
||||
}
|
||||
|
||||
// String returns a human-readable representation of this tree.
|
||||
func (f *RandomForest) String() string {
|
||||
return fmt.Sprintf("RandomForest(ForestSize: %d, Features:%d, %s\n)", f.ForestSize, f.Features, f.Model)
|
||||
}
|
||||
|
@ -13,12 +13,16 @@ func TestRandomForest1(testEnv *testing.T) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
trainData, testData := base.InstancesTrainTestSplit(inst, 0.60)
|
||||
filt := filters.NewChiMergeFilter(trainData, 0.90)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
filt.Run(testData)
|
||||
filt.Run(trainData)
|
||||
|
||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||
for _, a := range base.NonClassFloatAttributes(inst) {
|
||||
filt.AddAttribute(a)
|
||||
}
|
||||
filt.Train()
|
||||
instf := base.NewLazilyFilteredInstances(inst, filt)
|
||||
|
||||
trainData, testData := base.InstancesTrainTestSplit(instf, 0.60)
|
||||
|
||||
rf := NewRandomForest(10, 3)
|
||||
rf.Fit(trainData)
|
||||
predictions := rf.Predict(testData)
|
||||
|
Loading…
x
Reference in New Issue
Block a user