1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-25 13:48:49 +08:00
golearn/ensemble/randomforest.go
2021-01-10 00:56:45 -03:00

106 lines
2.7 KiB
Go

package ensemble
import (
"errors"
"fmt"
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/meta"
"github.com/sjwhitworth/golearn/trees"
)
// RandomForest classifies instances using an ensemble
// of bagged random decision trees.
//
// A RandomForest is normally created with NewRandomForest and must be
// trained with Fit (or restored with Load/LoadWithPrefix) before use.
type RandomForest struct {
base.BaseClassifier
// ForestSize is the number of decision trees added to the bagged
// model by Fit and LoadWithPrefix.
ForestSize int
// Features is copied into Model.RandomFeatures by Fit. Fit rejects
// data grids that have fewer non-class attributes than this value.
Features int
// Model is the underlying bagged ensemble. It is nil until Fit or
// LoadWithPrefix assigns it.
Model *meta.BaggedModel
}
// NewRandomForest creates and returns a new, untrained random forest.
//
// forestSize controls the number of trees that get built, and
// features controls the number of features used to build each tree.
// The returned forest has a nil Model until it is fitted or loaded.
func NewRandomForest(forestSize int, features int) *RandomForest {
	return &RandomForest{
		BaseClassifier: base.BaseClassifier{},
		ForestSize:     forestSize,
		Features:       features,
		Model:          nil,
	}
}
// Fit builds the RandomForest on the specified instances.
//
// It returns an error, without touching f.Model, when the data grid has
// fewer non-class attributes than f.Features. Otherwise it replaces
// f.Model with a fresh bagged model containing f.ForestSize ID3 decision
// trees, trains that model on the instances, and returns nil.
func (f *RandomForest) Fit(on base.FixedDataGrid) error {
	available := len(base.NonClassAttributes(on))
	if f.Features > available {
		msg := fmt.Sprintf(
			"Random forest with %d features cannot fit data grid with %d non-class attributes",
			f.Features,
			available,
		)
		return errors.New(msg)
	}

	bagged := &meta.BaggedModel{RandomFeatures: f.Features}
	f.Model = bagged
	for remaining := f.ForestSize; remaining > 0; remaining-- {
		bagged.AddModel(trees.NewID3DecisionTree(0.00))
	}

	bagged.Fit(on)
	return nil
}
// Predict generates predictions from a trained RandomForest.
//
// The work is delegated to the underlying bagged model. f.Model is nil
// until Fit or LoadWithPrefix has assigned it, so calling Predict on an
// untrained forest returns an error instead of invoking a method on a
// nil model.
func (f *RandomForest) Predict(with base.FixedDataGrid) (base.FixedDataGrid, error) {
	if f.Model == nil {
		return nil, errors.New("random forest has not been fitted or loaded")
	}
	return f.Model.Predict(with)
}
// String returns a human-readable representation of this forest,
// listing its ForestSize, its Features and the underlying Model.
func (f *RandomForest) String() string {
	const layout = "RandomForest(ForestSize: %d, Features:%d, %s\n)"
	return fmt.Sprintf(layout, f.ForestSize, f.Features, f.Model)
}
// GetMetadata returns the metadata record written alongside a serialized
// RandomForest: format version 1, classifier version "1.0" and no extra
// classifier-specific metadata.
//
// The classifier name identifies this type as "RandomForest"; it was
// previously reported as "KNN", which mislabelled saved forests.
func (f *RandomForest) GetMetadata() base.ClassifierMetadataV1 {
	return base.ClassifierMetadataV1{
		FormatVersion:      1,
		ClassifierName:     "RandomForest",
		ClassifierVersion:  "1.0",
		ClassifierMetadata: nil,
	}
}
// Save serializes the forest to the file at filePath.
//
// It creates a serialized classifier stub carrying GetMetadata, writes
// the model under the "model" prefix and then closes the writer. The
// writer is always closed once it has been created. An error from writing
// the model takes precedence; otherwise an error from closing the writer
// is returned, so that a failed close is not reported as a successful
// save.
func (f *RandomForest) Save(filePath string) error {
	writer, err := base.CreateSerializedClassifierStub(filePath, f.GetMetadata())
	if err != nil {
		return err
	}

	saveErr := f.SaveWithPrefix(writer, "model")
	closeErr := writer.Close()
	if saveErr != nil {
		return saveErr
	}
	return closeErr
}
// SaveWithPrefix writes the underlying bagged model to writer under the
// given prefix.
//
// f.Model is nil until Fit or LoadWithPrefix has assigned it, so an
// untrained forest yields an error instead of invoking a method on a
// nil model.
func (f *RandomForest) SaveWithPrefix(writer *base.ClassifierSerializer, prefix string) error {
	if f.Model == nil {
		return errors.New("random forest has not been fitted or loaded")
	}
	return f.Model.SaveWithPrefix(writer, prefix)
}
// Load restores the forest from the serialized classifier file at
// filePath.
//
// It opens the serialized stub and, if that succeeds, delegates to
// LoadWithPrefix using the "model" prefix. Any error from either step is
// returned. LoadWithPrefix builds f.ForestSize trees, so ForestSize must
// already be set on the receiver.
func (f *RandomForest) Load(filePath string) error {
	reader, err := base.ReadSerializedClassifierStub(filePath)
	if err == nil {
		err = f.LoadWithPrefix(reader, "model")
	}
	return err
}
// LoadWithPrefix replaces f.Model with a new bagged model containing
// f.ForestSize ID3 decision trees and then asks that model to load its
// state from reader under the given prefix, returning the model's error.
//
// NOTE(review): f.Features is not copied into the new model here;
// presumably the bagged model restores it while loading — confirm.
func (f *RandomForest) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
	bagged := new(meta.BaggedModel)
	f.Model = bagged
	for remaining := f.ForestSize; remaining > 0; remaining-- {
		bagged.AddModel(trees.NewID3DecisionTree(0.00))
	}
	return bagged.LoadWithPrefix(reader, prefix)
}