diff --git a/ensemble/randomforest.go b/ensemble/randomforest.go index 3776ebe..a27cca0 100644 --- a/ensemble/randomforest.go +++ b/ensemble/randomforest.go @@ -29,15 +29,15 @@ func NewRandomForest(forestSize int, features int) RandomForest { } // Train builds the RandomForest on the specified instances -func (f *RandomForest) Train(on *base.Instances) { +func (f *RandomForest) Fit(on *base.Instances) { f.Model = new(meta.BaggedModel) for i := 0; i < f.ForestSize; i++ { tree := new(trees.RandomTree) - tree.Rules = new(trees.RandomTreeRule) - tree.Attributes = f.Features + tree.Rule = new(trees.RandomTreeRuleGenerator) + tree.Rule.Attributes = f.Features f.Model.AddModel(tree) } - f.Model.Train(on) + f.Model.Fit(on) } // Predict generates predictions from a trained RandomForest diff --git a/ensemble/randomforest_test.go b/ensemble/randomforest_test.go index 5ecdd04..b8dac72 100644 --- a/ensemble/randomforest_test.go +++ b/ensemble/randomforest_test.go @@ -20,7 +20,7 @@ func TestRandomForest1(testEnv *testing.T) { filt.Run(insts[1]) filt.Run(insts[0]) rf := NewRandomForest(10, 2) - rf.Train(insts[0]) + rf.Fit(insts[0]) predictions := rf.Predict(insts[1]) fmt.Println(predictions) confusionMat := eval.GetConfusionMatrix(insts[1], predictions) diff --git a/meta/bagging.go b/meta/bagging.go index 5a1a5be..23cb6aa 100644 --- a/meta/bagging.go +++ b/meta/bagging.go @@ -15,10 +15,10 @@ type BaggedModel struct { } func (b *BaggedModel) generateTrainingInstances(from *base.Instances) *base.Instances { - from = from.SampleWithReplacement(from.Rows) - return from + return from.SampleWithReplacement(from.Rows) } +// AddModel adds a base.Classifier to the current model func (b *BaggedModel) AddModel(m base.Classifier) { b.Models = append(b.Models, m) } diff --git a/trees/random.go b/trees/random.go index 0ab9539..0c8c2bf 100644 --- a/trees/random.go +++ b/trees/random.go @@ -36,14 +36,14 @@ func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base type RandomTree struct { base.BaseClassifier Root *DecisionTreeNode - Rule RandomTreeRuleGenerator + Rule *RandomTreeRuleGenerator } func NewRandomTree(attrs int) *RandomTree { return &RandomTree{ base.BaseClassifier{}, nil, - RandomTreeRuleGenerator{ + &RandomTreeRuleGenerator{ attrs, InformationGainRuleGenerator{}, }, @@ -52,7 +52,7 @@ func NewRandomTree(attrs int) *RandomTree { // Train builds a RandomTree suitable for prediction func (rt *RandomTree) Fit(from *base.Instances) { - rt.Root = InferID3Tree(from, &rt.Rule) + rt.Root = InferID3Tree(from, rt.Rule) } // Predict returns a set of Instances containing predictions