diff --git a/ensemble/ensemble.go b/ensemble/ensemble.go index a4bcd77..7a8e052 100644 --- a/ensemble/ensemble.go +++ b/ensemble/ensemble.go @@ -1,13 +1,12 @@ -/* +// +// +// Ensemble contains classifiers which combine other classifiers. +// +// RandomForest: +// Generates ForestSize bagged decision trees (currently ID3-based) +// each considering a fixed number of random features. +// +// Built on meta.Bagging +// - Ensemble contains classifiers which combine other classifiers. - - RandomForest: - Generates ForestSize bagged decision trees (currently ID3-based) - each considering a fixed number of random features. - - Built on meta.Bagging - -*/ - -package ensemble \ No newline at end of file +package ensemble diff --git a/ensemble/randomforest.go b/ensemble/randomforest.go index e00129a..cb3c374 100644 --- a/ensemble/randomforest.go +++ b/ensemble/randomforest.go @@ -8,7 +8,7 @@ import ( ) // RandomForest classifies instances using an ensemble -// of bagged random decision trees +// of bagged random decision trees. type RandomForest struct { base.BaseClassifier ForestSize int @@ -18,7 +18,7 @@ type RandomForest struct { // NewRandomForest generates and return a new random forests // forestSize controls the number of trees that get built -// features controls the number of features used to build each tree +// features controls the number of features used to build each tree. 
func NewRandomForest(forestSize int, features int) *RandomForest { ret := &RandomForest{ base.BaseClassifier{}, @@ -30,7 +30,7 @@ func NewRandomForest(forestSize int, features int) *RandomForest { } // Fit builds the RandomForest on the specified instances -func (f *RandomForest) Fit(on *base.Instances) { +func (f *RandomForest) Fit(on base.FixedDataGrid) { f.Model = new(meta.BaggedModel) f.Model.RandomFeatures = f.Features for i := 0; i < f.ForestSize; i++ { @@ -40,11 +40,12 @@ func (f *RandomForest) Fit(on *base.Instances) { f.Model.Fit(on) } -// Predict generates predictions from a trained RandomForest -func (f *RandomForest) Predict(with *base.Instances) *base.Instances { +// Predict generates predictions from a trained RandomForest. +func (f *RandomForest) Predict(with base.FixedDataGrid) base.FixedDataGrid { return f.Model.Predict(with) } +// String returns a human-readable representation of this forest. func (f *RandomForest) String() string { return fmt.Sprintf("RandomForest(ForestSize: %d, Features:%d, %s\n)", f.ForestSize, f.Features, f.Model) } diff --git a/ensemble/randomforest_test.go b/ensemble/randomforest_test.go index ce33d10..1524cc1 100644 --- a/ensemble/randomforest_test.go +++ b/ensemble/randomforest_test.go @@ -13,12 +13,16 @@ func TestRandomForest1(testEnv *testing.T) { if err != nil { panic(err) } - trainData, testData := base.InstancesTrainTestSplit(inst, 0.60) - filt := filters.NewChiMergeFilter(trainData, 0.90) - filt.AddAllNumericAttributes() - filt.Build() - filt.Run(testData) - filt.Run(trainData) + + filt := filters.NewChiMergeFilter(inst, 0.90) + for _, a := range base.NonClassFloatAttributes(inst) { + filt.AddAttribute(a) + } + filt.Train() + instf := base.NewLazilyFilteredInstances(inst, filt) + + trainData, testData := base.InstancesTrainTestSplit(instf, 0.60) + rf := NewRandomForest(10, 3) rf.Fit(trainData) predictions := rf.Predict(testData)