1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00

Passes all the tests

This commit is contained in:
Richard Townsend 2014-05-17 17:35:10 +01:00
parent db3ac3c695
commit c516907b13
4 changed files with 10 additions and 10 deletions

View File

@ -29,15 +29,15 @@ func NewRandomForest(forestSize int, features int) RandomForest {
}
// Train builds the RandomForest on the specified instances
func (f *RandomForest) Train(on *base.Instances) {
func (f *RandomForest) Fit(on *base.Instances) {
f.Model = new(meta.BaggedModel)
for i := 0; i < f.ForestSize; i++ {
tree := new(trees.RandomTree)
tree.Rules = new(trees.RandomTreeRule)
tree.Attributes = f.Features
tree.Rule = new(trees.RandomTreeRuleGenerator)
tree.Rule.Attributes = f.Features
f.Model.AddModel(tree)
}
f.Model.Train(on)
f.Model.Fit(on)
}
// Predict generates predictions from a trained RandomForest

View File

@ -20,7 +20,7 @@ func TestRandomForest1(testEnv *testing.T) {
filt.Run(insts[1])
filt.Run(insts[0])
rf := NewRandomForest(10, 2)
rf.Train(insts[0])
rf.Fit(insts[0])
predictions := rf.Predict(insts[1])
fmt.Println(predictions)
confusionMat := eval.GetConfusionMatrix(insts[1], predictions)

View File

@ -15,10 +15,10 @@ type BaggedModel struct {
}
func (b *BaggedModel) generateTrainingInstances(from *base.Instances) *base.Instances {
from = from.SampleWithReplacement(from.Rows)
return from
return from.SampleWithReplacement(from.Rows)
}
// AddModel adds a base.Classifier to the current model
func (b *BaggedModel) AddModel(m base.Classifier) {
b.Models = append(b.Models, m)
}

View File

@ -36,14 +36,14 @@ func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base
type RandomTree struct {
base.BaseClassifier
Root *DecisionTreeNode
Rule RandomTreeRuleGenerator
Rule *RandomTreeRuleGenerator
}
func NewRandomTree(attrs int) *RandomTree {
return &RandomTree{
base.BaseClassifier{},
nil,
RandomTreeRuleGenerator{
&RandomTreeRuleGenerator{
attrs,
InformationGainRuleGenerator{},
},
@ -52,7 +52,7 @@ func NewRandomTree(attrs int) *RandomTree {
// Train builds a RandomTree suitable for prediction
func (rt *RandomTree) Fit(from *base.Instances) {
rt.Root = InferID3Tree(from, &rt.Rule)
rt.Root = InferID3Tree(from, rt.Rule)
}
// Predict returns a set of Instances containing predictions