2014-05-14 14:00:22 +01:00
|
|
|
package ensemble
|
|
|
|
|
|
|
|
import (
|
2014-07-18 13:48:28 +03:00
|
|
|
"fmt"
|
2014-08-22 07:21:24 +00:00
|
|
|
"github.com/sjwhitworth/golearn/base"
|
|
|
|
"github.com/sjwhitworth/golearn/meta"
|
|
|
|
"github.com/sjwhitworth/golearn/trees"
|
2014-05-14 14:00:22 +01:00
|
|
|
)
|
|
|
|
|
|
|
|
// RandomForest classifies instances using an ensemble
|
2014-08-02 16:22:14 +01:00
|
|
|
// of bagged random decision trees.
|
2014-05-14 14:00:22 +01:00
|
|
|
type RandomForest struct {
|
|
|
|
base.BaseClassifier
|
|
|
|
ForestSize int
|
|
|
|
Features int
|
|
|
|
Model *meta.BaggedModel
|
|
|
|
}
|
|
|
|
|
2014-07-18 13:48:28 +03:00
|
|
|
// NewRandomForest generates and return a new random forests
|
2014-05-14 14:00:22 +01:00
|
|
|
// forestSize controls the number of trees that get built
|
2014-08-02 16:22:14 +01:00
|
|
|
// features controls the number of features used to build each tree.
|
2014-05-19 12:42:03 +01:00
|
|
|
func NewRandomForest(forestSize int, features int) *RandomForest {
|
|
|
|
ret := &RandomForest{
|
2014-05-14 14:00:22 +01:00
|
|
|
base.BaseClassifier{},
|
|
|
|
forestSize,
|
|
|
|
features,
|
|
|
|
nil,
|
|
|
|
}
|
|
|
|
return ret
|
|
|
|
}
|
|
|
|
|
2014-07-18 13:48:28 +03:00
|
|
|
// Fit builds the RandomForest on the specified instances
|
2014-08-02 16:22:14 +01:00
|
|
|
func (f *RandomForest) Fit(on base.FixedDataGrid) {
|
2014-08-22 07:39:14 +00:00
|
|
|
numNonClassAttributes := len(base.NonClassAttributes(on))
|
|
|
|
if numNonClassAttributes < f.Features {
|
|
|
|
panic(fmt.Sprintf(
|
|
|
|
"Random forest with %d features cannot fit data grid with %d non-class attributes",
|
|
|
|
f.Features,
|
|
|
|
numNonClassAttributes,
|
|
|
|
))
|
|
|
|
}
|
|
|
|
|
2014-05-14 14:00:22 +01:00
|
|
|
f.Model = new(meta.BaggedModel)
|
2014-05-18 11:49:35 +01:00
|
|
|
f.Model.RandomFeatures = f.Features
|
2014-05-14 14:00:22 +01:00
|
|
|
for i := 0; i < f.ForestSize; i++ {
|
2014-05-18 11:49:35 +01:00
|
|
|
tree := trees.NewID3DecisionTree(0.00)
|
2014-05-17 16:20:56 +01:00
|
|
|
f.Model.AddModel(tree)
|
2014-05-14 14:00:22 +01:00
|
|
|
}
|
2014-05-17 17:35:10 +01:00
|
|
|
f.Model.Fit(on)
|
2014-05-14 14:00:22 +01:00
|
|
|
}
|
|
|
|
|
2014-08-02 16:22:14 +01:00
|
|
|
// Predict generates predictions from a trained RandomForest.
|
|
|
|
func (f *RandomForest) Predict(with base.FixedDataGrid) base.FixedDataGrid {
|
2014-05-14 14:00:22 +01:00
|
|
|
return f.Model.Predict(with)
|
|
|
|
}
|
2014-05-19 12:42:03 +01:00
|
|
|
|
2014-08-02 16:22:14 +01:00
|
|
|
// String returns a human-readable representation of this tree.
|
2014-05-19 12:42:03 +01:00
|
|
|
func (f *RandomForest) String() string {
|
|
|
|
return fmt.Sprintf("RandomForest(ForestSize: %d, Features:%d, %s\n)", f.ForestSize, f.Features, f.Model)
|
2014-07-18 13:48:28 +03:00
|
|
|
}
|