2014-05-14 14:00:22 +01:00
|
|
|
package ensemble
|
|
|
|
|
|
|
|
import (
|
|
|
|
base "github.com/sjwhitworth/golearn/base"
|
|
|
|
meta "github.com/sjwhitworth/golearn/meta"
|
|
|
|
trees "github.com/sjwhitworth/golearn/trees"
|
2014-05-19 12:42:03 +01:00
|
|
|
"fmt"
|
2014-05-14 14:00:22 +01:00
|
|
|
)
|
|
|
|
|
|
|
|
// RandomForest classifies instances using an ensemble
|
|
|
|
// of bagged random decision trees
|
|
|
|
type RandomForest struct {
|
|
|
|
base.BaseClassifier
|
|
|
|
ForestSize int
|
|
|
|
Features int
|
|
|
|
Model *meta.BaggedModel
|
|
|
|
}
|
|
|
|
|
|
|
|
// NewRandomForests generates and return a new random forests
|
|
|
|
// forestSize controls the number of trees that get built
|
|
|
|
// features controls the number of features used to build each tree
|
2014-05-19 12:42:03 +01:00
|
|
|
func NewRandomForest(forestSize int, features int) *RandomForest {
|
|
|
|
ret := &RandomForest{
|
2014-05-14 14:00:22 +01:00
|
|
|
base.BaseClassifier{},
|
|
|
|
forestSize,
|
|
|
|
features,
|
|
|
|
nil,
|
|
|
|
}
|
|
|
|
return ret
|
|
|
|
}
|
|
|
|
|
|
|
|
// Train builds the RandomForest on the specified instances
|
2014-05-17 17:35:10 +01:00
|
|
|
func (f *RandomForest) Fit(on *base.Instances) {
|
2014-05-14 14:00:22 +01:00
|
|
|
f.Model = new(meta.BaggedModel)
|
2014-05-18 11:49:35 +01:00
|
|
|
f.Model.RandomFeatures = f.Features
|
2014-05-14 14:00:22 +01:00
|
|
|
for i := 0; i < f.ForestSize; i++ {
|
2014-05-18 11:49:35 +01:00
|
|
|
tree := trees.NewID3DecisionTree(0.00)
|
2014-05-17 16:20:56 +01:00
|
|
|
f.Model.AddModel(tree)
|
2014-05-14 14:00:22 +01:00
|
|
|
}
|
2014-05-17 17:35:10 +01:00
|
|
|
f.Model.Fit(on)
|
2014-05-14 14:00:22 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
// Predict generates predictions from a trained RandomForest
|
|
|
|
func (f *RandomForest) Predict(with *base.Instances) *base.Instances {
|
|
|
|
return f.Model.Predict(with)
|
|
|
|
}
|
2014-05-19 12:42:03 +01:00
|
|
|
|
|
|
|
func (f *RandomForest) String() string {
|
|
|
|
return fmt.Sprintf("RandomForest(ForestSize: %d, Features:%d, %s\n)", f.ForestSize, f.Features, f.Model)
|
|
|
|
}
|