mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-28 13:48:56 +08:00
ensemble: merge from v2-instances
This commit is contained in:
parent
2bb7c2de75
commit
2e5deb1476
@ -1,13 +1,12 @@
|
|||||||
/*
|
//
|
||||||
|
//
|
||||||
Ensemble contains classifiers which combine other classifiers.
|
// Ensemble contains classifiers which combine other classifiers.
|
||||||
|
//
|
||||||
RandomForest:
|
// RandomForest:
|
||||||
Generates ForestSize bagged decision trees (currently ID3-based)
|
// Generates ForestSize bagged decision trees (currently ID3-based)
|
||||||
each considering a fixed number of random features.
|
// each considering a fixed number of random features.
|
||||||
|
//
|
||||||
Built on meta.Bagging
|
// Built on meta.Bagging
|
||||||
|
//
|
||||||
*/
|
|
||||||
|
|
||||||
package ensemble
|
package ensemble
|
@ -8,7 +8,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// RandomForest classifies instances using an ensemble
|
// RandomForest classifies instances using an ensemble
|
||||||
// of bagged random decision trees
|
// of bagged random decision trees.
|
||||||
type RandomForest struct {
|
type RandomForest struct {
|
||||||
base.BaseClassifier
|
base.BaseClassifier
|
||||||
ForestSize int
|
ForestSize int
|
||||||
@ -18,7 +18,7 @@ type RandomForest struct {
|
|||||||
|
|
||||||
// NewRandomForest generates and return a new random forests
|
// NewRandomForest generates and return a new random forests
|
||||||
// forestSize controls the number of trees that get built
|
// forestSize controls the number of trees that get built
|
||||||
// features controls the number of features used to build each tree
|
// features controls the number of features used to build each tree.
|
||||||
func NewRandomForest(forestSize int, features int) *RandomForest {
|
func NewRandomForest(forestSize int, features int) *RandomForest {
|
||||||
ret := &RandomForest{
|
ret := &RandomForest{
|
||||||
base.BaseClassifier{},
|
base.BaseClassifier{},
|
||||||
@ -30,7 +30,7 @@ func NewRandomForest(forestSize int, features int) *RandomForest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fit builds the RandomForest on the specified instances
|
// Fit builds the RandomForest on the specified instances
|
||||||
func (f *RandomForest) Fit(on *base.Instances) {
|
func (f *RandomForest) Fit(on base.FixedDataGrid) {
|
||||||
f.Model = new(meta.BaggedModel)
|
f.Model = new(meta.BaggedModel)
|
||||||
f.Model.RandomFeatures = f.Features
|
f.Model.RandomFeatures = f.Features
|
||||||
for i := 0; i < f.ForestSize; i++ {
|
for i := 0; i < f.ForestSize; i++ {
|
||||||
@ -40,11 +40,12 @@ func (f *RandomForest) Fit(on *base.Instances) {
|
|||||||
f.Model.Fit(on)
|
f.Model.Fit(on)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Predict generates predictions from a trained RandomForest
|
// Predict generates predictions from a trained RandomForest.
|
||||||
func (f *RandomForest) Predict(with *base.Instances) *base.Instances {
|
func (f *RandomForest) Predict(with base.FixedDataGrid) base.FixedDataGrid {
|
||||||
return f.Model.Predict(with)
|
return f.Model.Predict(with)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// String returns a human-readable representation of this tree.
|
||||||
func (f *RandomForest) String() string {
|
func (f *RandomForest) String() string {
|
||||||
return fmt.Sprintf("RandomForest(ForestSize: %d, Features:%d, %s\n)", f.ForestSize, f.Features, f.Model)
|
return fmt.Sprintf("RandomForest(ForestSize: %d, Features:%d, %s\n)", f.ForestSize, f.Features, f.Model)
|
||||||
}
|
}
|
||||||
|
@ -13,12 +13,16 @@ func TestRandomForest1(testEnv *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
trainData, testData := base.InstancesTrainTestSplit(inst, 0.60)
|
|
||||||
filt := filters.NewChiMergeFilter(trainData, 0.90)
|
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||||
filt.AddAllNumericAttributes()
|
for _, a := range base.NonClassFloatAttributes(inst) {
|
||||||
filt.Build()
|
filt.AddAttribute(a)
|
||||||
filt.Run(testData)
|
}
|
||||||
filt.Run(trainData)
|
filt.Train()
|
||||||
|
instf := base.NewLazilyFilteredInstances(inst, filt)
|
||||||
|
|
||||||
|
trainData, testData := base.InstancesTrainTestSplit(instf, 0.60)
|
||||||
|
|
||||||
rf := NewRandomForest(10, 3)
|
rf := NewRandomForest(10, 3)
|
||||||
rf.Fit(trainData)
|
rf.Fit(trainData)
|
||||||
predictions := rf.Predict(testData)
|
predictions := rf.Predict(testData)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user