mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
Passes all the tests
This commit is contained in:
parent
db3ac3c695
commit
c516907b13
@ -29,15 +29,15 @@ func NewRandomForest(forestSize int, features int) RandomForest {
|
||||
}
|
||||
|
||||
// Train builds the RandomForest on the specified instances
|
||||
func (f *RandomForest) Train(on *base.Instances) {
|
||||
func (f *RandomForest) Fit(on *base.Instances) {
|
||||
f.Model = new(meta.BaggedModel)
|
||||
for i := 0; i < f.ForestSize; i++ {
|
||||
tree := new(trees.RandomTree)
|
||||
tree.Rules = new(trees.RandomTreeRule)
|
||||
tree.Attributes = f.Features
|
||||
tree.Rule = new(trees.RandomTreeRuleGenerator)
|
||||
tree.Rule.Attributes = f.Features
|
||||
f.Model.AddModel(tree)
|
||||
}
|
||||
f.Model.Train(on)
|
||||
f.Model.Fit(on)
|
||||
}
|
||||
|
||||
// Predict generates predictions from a trained RandomForest
|
||||
|
@ -20,7 +20,7 @@ func TestRandomForest1(testEnv *testing.T) {
|
||||
filt.Run(insts[1])
|
||||
filt.Run(insts[0])
|
||||
rf := NewRandomForest(10, 2)
|
||||
rf.Train(insts[0])
|
||||
rf.Fit(insts[0])
|
||||
predictions := rf.Predict(insts[1])
|
||||
fmt.Println(predictions)
|
||||
confusionMat := eval.GetConfusionMatrix(insts[1], predictions)
|
||||
|
@ -15,10 +15,10 @@ type BaggedModel struct {
|
||||
}
|
||||
|
||||
func (b *BaggedModel) generateTrainingInstances(from *base.Instances) *base.Instances {
|
||||
from = from.SampleWithReplacement(from.Rows)
|
||||
return from
|
||||
return from.SampleWithReplacement(from.Rows)
|
||||
}
|
||||
|
||||
// AddModel adds a base.Classifier to the current model
|
||||
func (b *BaggedModel) AddModel(m base.Classifier) {
|
||||
b.Models = append(b.Models, m)
|
||||
}
|
||||
|
@ -36,14 +36,14 @@ func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base
|
||||
type RandomTree struct {
|
||||
base.BaseClassifier
|
||||
Root *DecisionTreeNode
|
||||
Rule RandomTreeRuleGenerator
|
||||
Rule *RandomTreeRuleGenerator
|
||||
}
|
||||
|
||||
func NewRandomTree(attrs int) *RandomTree {
|
||||
return &RandomTree{
|
||||
base.BaseClassifier{},
|
||||
nil,
|
||||
RandomTreeRuleGenerator{
|
||||
&RandomTreeRuleGenerator{
|
||||
attrs,
|
||||
InformationGainRuleGenerator{},
|
||||
},
|
||||
@ -52,7 +52,7 @@ func NewRandomTree(attrs int) *RandomTree {
|
||||
|
||||
// Train builds a RandomTree suitable for prediction
|
||||
func (rt *RandomTree) Fit(from *base.Instances) {
|
||||
rt.Root = InferID3Tree(from, &rt.Rule)
|
||||
rt.Root = InferID3Tree(from, rt.Rule)
|
||||
}
|
||||
|
||||
// Predict returns a set of Instances containing predictions
|
||||
|
Loading…
x
Reference in New Issue
Block a user