mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
133 lines
3.6 KiB
Go
133 lines
3.6 KiB
Go
package trees
|
|
|
|
import (
|
|
"fmt"
|
|
"github.com/sjwhitworth/golearn/base"
|
|
"math/rand"
|
|
)
|
|
|
|
// RandomTreeRuleGenerator is used to generate decision rules for Random Trees
|
|
type RandomTreeRuleGenerator struct {
|
|
Attributes int
|
|
internalRule InformationGainRuleGenerator
|
|
}
|
|
|
|
// GenerateSplitRule returns the best attribute out of those randomly chosen
|
|
// which maximises Information Gain
|
|
func (r *RandomTreeRuleGenerator) GenerateSplitRule(f base.FixedDataGrid) *DecisionTreeRule {
|
|
|
|
var consideredAttributes []base.Attribute
|
|
|
|
// First step is to generate the random attributes that we'll consider
|
|
allAttributes := base.AttributeDifferenceReferences(f.AllAttributes(), f.AllClassAttributes())
|
|
maximumAttribute := len(allAttributes)
|
|
|
|
attrCounter := 0
|
|
for {
|
|
if len(consideredAttributes) >= r.Attributes {
|
|
break
|
|
}
|
|
selectedAttrIndex := rand.Intn(maximumAttribute)
|
|
selectedAttribute := allAttributes[selectedAttrIndex]
|
|
matched := false
|
|
for _, a := range consideredAttributes {
|
|
if a.Equals(selectedAttribute) {
|
|
matched = true
|
|
break
|
|
}
|
|
}
|
|
if matched {
|
|
continue
|
|
}
|
|
consideredAttributes = append(consideredAttributes, selectedAttribute)
|
|
attrCounter++
|
|
}
|
|
|
|
return r.internalRule.GetSplitRuleFromSelection(consideredAttributes, f)
|
|
}
|
|
|
|
// RandomTree builds a decision tree by considering a fixed number
|
|
// of randomly-chosen attributes at each node
|
|
type RandomTree struct {
|
|
base.BaseClassifier
|
|
Root *DecisionTreeNode
|
|
Rule *RandomTreeRuleGenerator
|
|
}
|
|
|
|
// NewRandomTree returns a new RandomTree which considers attrs randomly
|
|
// chosen attributes at each node.
|
|
func NewRandomTree(attrs int) *RandomTree {
|
|
return &RandomTree{
|
|
base.BaseClassifier{},
|
|
nil,
|
|
&RandomTreeRuleGenerator{
|
|
attrs,
|
|
InformationGainRuleGenerator{},
|
|
},
|
|
}
|
|
}
|
|
|
|
// Fit builds a RandomTree suitable for prediction
|
|
func (rt *RandomTree) Fit(from base.FixedDataGrid) error {
|
|
rt.Root = InferID3Tree(from, rt.Rule)
|
|
return nil
|
|
}
|
|
|
|
// Predict returns a set of Instances containing predictions
|
|
func (rt *RandomTree) Predict(from base.FixedDataGrid) (base.FixedDataGrid, error) {
|
|
return rt.Root.Predict(from)
|
|
}
|
|
|
|
// String returns a human-readable representation of this structure
|
|
func (rt *RandomTree) String() string {
|
|
return fmt.Sprintf("RandomTree(%s)", rt.Root)
|
|
}
|
|
|
|
// Prune removes nodes from the tree which are detrimental
|
|
// to determining the accuracy of the test set (with)
|
|
func (rt *RandomTree) Prune(with base.FixedDataGrid) {
|
|
rt.Root.Prune(with)
|
|
}
|
|
|
|
// Save outputs this model to a file
|
|
func (rt *RandomTree) Save(filePath string) error {
|
|
writer, err := base.CreateSerializedClassifierStub(filePath, rt.GetMetadata())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
writer.Close()
|
|
}()
|
|
return rt.SaveWithPrefix(writer, "")
|
|
}
|
|
|
|
// SaveWithPrefix outputs this model to a file with a prefix.
|
|
func (rt *RandomTree) SaveWithPrefix(writer *base.ClassifierSerializer, prefix string) error {
|
|
return rt.Root.SaveWithPrefix(writer, prefix)
|
|
}
|
|
|
|
// Load retrieves this model from a file
|
|
func (rt *RandomTree) Load(filePath string) error {
|
|
reader, err := base.ReadSerializedClassifierStub(filePath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return rt.LoadWithPrefix(reader, "")
|
|
}
|
|
|
|
// LoadWithPrefix retrives this random tree from disk with a given prefix.
|
|
func (rt *RandomTree) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
|
|
rt.Root = &DecisionTreeNode{}
|
|
return rt.Root.LoadWithPrefix(reader, prefix)
|
|
}
|
|
|
|
// GetMetadata returns required serialization metadata
|
|
func (rt *RandomTree) GetMetadata() base.ClassifierMetadataV1 {
|
|
return base.ClassifierMetadataV1{
|
|
FormatVersion: 1,
|
|
ClassifierName: "KNN",
|
|
ClassifierVersion: "1.0",
|
|
ClassifierMetadata: nil,
|
|
}
|
|
}
|