1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-28 13:48:56 +08:00
golearn/trees/random.go

89 lines
2.4 KiB
Go
Raw Normal View History

package trees
import (
"fmt"
base "github.com/sjwhitworth/golearn/base"
"math/rand"
)
2014-05-19 12:59:11 +01:00
// RandomTreeRuleGenerator is used to generate decision rules for Random Trees
type RandomTreeRuleGenerator struct {
2014-05-17 17:28:51 +01:00
Attributes int
internalRule InformationGainRuleGenerator
}
2014-05-19 12:59:11 +01:00
// GenerateSplitAttribute returns the best attribute out of those randomly chosen
// which maximises Information Gain
func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
// First step is to generate the random attributes that we'll consider
maximumAttribute := f.GetAttributeCount()
consideredAttributes := make([]int, r.Attributes)
attrCounter := 0
for {
2014-05-17 20:37:19 +01:00
if len(consideredAttributes) >= r.Attributes {
break
}
selectedAttribute := rand.Intn(maximumAttribute)
2014-05-17 20:37:19 +01:00
fmt.Println(selectedAttribute, attrCounter, consideredAttributes, len(consideredAttributes))
if selectedAttribute != f.ClassIndex {
2014-05-17 20:37:19 +01:00
matched := false
for _, a := range consideredAttributes {
if a == selectedAttribute {
matched = true
break
}
}
if matched {
continue
}
consideredAttributes = append(consideredAttributes, selectedAttribute)
attrCounter++
}
}
2014-05-17 17:28:51 +01:00
return r.internalRule.GetSplitAttributeFromSelection(consideredAttributes, f)
}
2014-05-19 12:59:11 +01:00
// RandomTree builds a decision tree by considering a fixed number
// of randomly-chosen attributes at each node
type RandomTree struct {
base.BaseClassifier
Root *DecisionTreeNode
2014-05-17 17:35:10 +01:00
Rule *RandomTreeRuleGenerator
}
2014-05-19 12:59:11 +01:00
// NewRandomTree returns a new RandomTree which considers attrs randomly
// chosen attributes at each node.
func NewRandomTree(attrs int) *RandomTree {
return &RandomTree{
base.BaseClassifier{},
nil,
2014-05-17 17:35:10 +01:00
&RandomTreeRuleGenerator{
attrs,
2014-05-17 17:28:51 +01:00
InformationGainRuleGenerator{},
},
}
}
// Train builds a RandomTree suitable for prediction
func (rt *RandomTree) Fit(from *base.Instances) {
2014-05-17 17:35:10 +01:00
rt.Root = InferID3Tree(from, rt.Rule)
}
// Predict returns a set of Instances containing predictions for from.
// The work is delegated to the root node; rt.Root is not checked for nil
// here, so Fit should have been called beforehand.
func (rt *RandomTree) Predict(from *base.Instances) *base.Instances {
	return rt.Root.Predict(from)
}
2014-05-19 12:59:11 +01:00
// String returns a human-readable representation of this structure:
// the root node formatted with %s and wrapped in "RandomTree(...)".
func (rt *RandomTree) String() string {
	return fmt.Sprintf("RandomTree(%s)", rt.Root)
}
2014-05-17 18:06:01 +01:00
2014-05-19 12:59:11 +01:00
// Prune removes nodes from the tree which are detrimental
// to determining the accuracy of the test set (with)
2014-05-17 18:06:01 +01:00
func (rt *RandomTree) Prune(with *base.Instances) {
rt.Root.Prune(with)
}