Mirror of https://github.com/sjwhitworth/golearn.git (synced 2025-04-28 13:48:56 +08:00)
Package documentation

parent 889fec4419
commit a6072ac9de

ensemble/ensemble.go (new file, +13 lines)
ensemble/ensemble.go (new file)
@@ -0,0 +1,13 @@
+/*
+
+    Ensemble contains classifiers which combine other classifiers.
+
+    RandomForest:
+        Generates ForestSize bagged decision trees (currently ID3-based)
+        each considering a fixed number of random features.
+
+        Built on meta.Bagging
+
+*/
+
+package ensemble
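As a usage sketch only: the RandomForest described in this package comment might be driven as below. The constructor ensemble.NewRandomForest(forestSize, features), the Fit/Predict flow, and the dataset path are assumptions for illustration, not part of this diff.

package main

import (
    "fmt"

    "github.com/sjwhitworth/golearn/base"
    "github.com/sjwhitworth/golearn/ensemble"
)

func main() {
    // Hypothetical, already-discretised dataset
    inst, err := base.ParseCSVToInstances("datasets/iris_categorical.csv", true)
    if err != nil {
        panic(err)
    }
    // 10 bagged ID3 trees, each considering 3 random features
    // (constructor assumed, mirroring ForestSize and the feature count above)
    rf := ensemble.NewRandomForest(10, 3)
    rf.Fit(inst)
    fmt.Println(rf.Predict(inst))
}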
@@ -8,7 +8,7 @@ import (
     "strings"
 )
 
-// BaggedModels train Classifiers on subsets of the original
+// BaggedModel trains base.Classifiers on subsets of the original
 // Instances and combine the results through voting
 type BaggedModel struct {
     base.BaseClassifier
@@ -17,6 +17,8 @@ type BaggedModel struct {
     RandomFeatures int
 }
 
+// generateTrainingAttrs selects RandomFeatures number of base.Attributes from
+// the provided base.Instances.
 func (b *BaggedModel) generateTrainingAttrs(model int, from *base.Instances) []base.Attribute {
     ret := make([]base.Attribute, 0)
     if b.RandomFeatures == 0 {
@@ -51,11 +53,17 @@ func (b *BaggedModel) generateTrainingAttrs(model int, from *base.Instances) []base.Attribute {
     return ret
 }
 
+// generatePredictionInstances returns a modified version of the
+// requested base.Instances with only the base.Attributes selected
+// for training the model.
 func (b *BaggedModel) generatePredictionInstances(model int, from *base.Instances) *base.Instances {
     selected := b.selectedAttributes[model]
     return from.SelectAttributes(selected)
 }
 
+// generateTrainingInstances generates RandomFeatures number of
+// attributes and returns a modified version of base.Instances
+// for training the model
 func (b *BaggedModel) generateTrainingInstances(model int, from *base.Instances) *base.Instances {
     insts := from.SampleWithReplacement(from.Rows)
     selected := b.generateTrainingAttrs(model, from)
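To show how these helpers relate, a hedged sketch of driving BaggedModel directly; AddModel, Fit, and Predict are assumed from the rest of the meta package rather than shown in these hunks.

package main

import (
    "fmt"

    "github.com/sjwhitworth/golearn/base"
    "github.com/sjwhitworth/golearn/meta"
    "github.com/sjwhitworth/golearn/trees"
)

func main() {
    inst, err := base.ParseCSVToInstances("datasets/tennis.csv", true) // hypothetical path
    if err != nil {
        panic(err)
    }
    bag := new(meta.BaggedModel)
    // Each model trains on 2 randomly-selected base.Attributes,
    // chosen by generateTrainingAttrs above
    bag.RandomFeatures = 2
    for i := 0; i < 10; i++ {
        bag.AddModel(trees.NewID3DecisionTree(0)) // AddModel assumed; 0 disables pruning
    }
    bag.Fit(inst)                  // samples with replacement per model
    fmt.Println(bag.Predict(inst)) // votes across the 10 trees
}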
@@ -109,6 +109,8 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
     return ret
 }
 
+// getNestedString returns the contents of node d
+// prefixed by level number of tags (also prints children)
 func (d *DecisionTreeNode) getNestedString(level int) string {
     buf := bytes.NewBuffer(nil)
     tmp := bytes.NewBuffer(nil)
@@ -143,6 +145,7 @@ func (d *DecisionTreeNode) String() string {
     return d.getNestedString(0)
 }
 
+// computeAccuracy is a helper method for Prune()
 func computeAccuracy(predictions *base.Instances, from *base.Instances) float64 {
     cf := eval.GetConfusionMatrix(from, predictions)
     return eval.GetAccuracy(cf)
@@ -231,6 +234,8 @@ type ID3DecisionTree struct {
     PruneSplit float64
 }
 
+// NewID3DecisionTree returns a new ID3DecisionTree with the specified test-prune
+// ratio. If the ratio is less than 0.001, the tree isn't pruned.
 func NewID3DecisionTree(prune float64) *ID3DecisionTree {
     return &ID3DecisionTree{
         base.BaseClassifier{},
@@ -256,7 +261,7 @@ func (t *ID3DecisionTree) Predict(what *base.Instances) *base.Instances {
     return t.Root.Predict(what)
 }
 
-// String returns a human-readable ID3 tree
+// String returns a human-readable version of this ID3 tree
 func (t *ID3DecisionTree) String() string {
     return fmt.Sprintf("ID3DecisionTree(%s\n)", t.Root)
 }
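A minimal sketch of the ID3 API documented above; Fit and the dataset path are assumptions (only Predict and String appear in these hunks), and the input must already be categorical, per the trees package doc below.

package main

import (
    "fmt"

    "github.com/sjwhitworth/golearn/base"
    "github.com/sjwhitworth/golearn/trees"
)

func main() {
    // Hypothetical categorical dataset
    inst, err := base.ParseCSVToInstances("datasets/tennis.csv", true)
    if err != nil {
        panic(err)
    }
    tree := trees.NewID3DecisionTree(0.6) // test-prune ratio; < 0.001 disables pruning
    tree.Fit(inst)                        // Fit assumed; not shown in these hunks
    fmt.Println(tree)                     // String() -> nested form via getNestedString
    fmt.Println(tree.Predict(inst))
}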
@@ -6,13 +6,14 @@ import (
     "math/rand"
 )
 
+// RandomTreeRuleGenerator is used to generate decision rules for Random Trees
 type RandomTreeRuleGenerator struct {
     Attributes   int
     internalRule InformationGainRuleGenerator
 }
 
-// So WEKA returns a couple of possible attributes and evaluates
-// the split criteria on each
+// GenerateSplitAttribute returns the best attribute out of those randomly chosen
+// which maximises Information Gain
 func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
 
     // First step is to generate the random attributes that we'll consider
@@ -44,12 +45,16 @@ func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
     return r.internalRule.GetSplitAttributeFromSelection(consideredAttributes, f)
 }
 
+// RandomTree builds a decision tree by considering a fixed number
+// of randomly-chosen attributes at each node
 type RandomTree struct {
     base.BaseClassifier
     Root *DecisionTreeNode
     Rule *RandomTreeRuleGenerator
 }
 
+// NewRandomTree returns a new RandomTree which considers attrs
+// randomly-chosen attributes at each node.
 func NewRandomTree(attrs int) *RandomTree {
     return &RandomTree{
         base.BaseClassifier{},
@@ -71,10 +76,13 @@ func (rt *RandomTree) Predict(from *base.Instances) *base.Instances {
     return rt.Root.Predict(from)
 }
 
+// String returns a human-readable representation of this structure
 func (rt *RandomTree) String() string {
     return fmt.Sprintf("RandomTree(%s)", rt.Root)
 }
 
+// Prune removes nodes from the tree which are detrimental
+// to accuracy as measured on the test set (with)
 func (rt *RandomTree) Prune(with *base.Instances) {
     rt.Root.Prune(with)
 }
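Likewise, a hedged end-to-end sketch for RandomTree; Fit and the dataset path are assumed, while Predict, String, and Prune are the methods documented above.

package main

import (
    "fmt"

    "github.com/sjwhitworth/golearn/base"
    "github.com/sjwhitworth/golearn/trees"
)

func main() {
    inst, err := base.ParseCSVToInstances("datasets/tennis.csv", true) // hypothetical path
    if err != nil {
        panic(err)
    }
    rt := trees.NewRandomTree(2) // consider 2 randomly-chosen attributes per node
    rt.Fit(inst)                 // Fit assumed; Predict/String/Prune appear above
    rt.Prune(inst)               // ideally a held-out set; inst reused for brevity
    fmt.Println(rt)
}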
@@ -1,2 +1,26 @@
-// Package trees provides a number of tree based ensemble learners.
+/*
+
+    This package implements decision trees.
+
+    ID3DecisionTree:
+        Builds a decision tree using the ID3 algorithm
+        by picking the Attribute which maximises
+        Information Gain at each node.
+
+        Attributes must be CategoricalAttributes at
+        present, so discretise beforehand (see
+        filters)
+
+    RandomTree:
+        Builds a decision tree using the ID3 algorithm
+        by picking the Attribute amongst those
+        randomly selected that maximises Information
+        Gain
+
+        Attributes must be CategoricalAttributes at
+        present, so discretise beforehand (see
+        filters)
+
+*/
+
 package trees
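Both package comments say numeric attributes must be discretised first. A sketch of that step, assuming a ChiMerge-style API (NewChiMergeFilter, AddAllNumericAttributes, Build, Run) in the filters package referenced above:

package main

import (
    "github.com/sjwhitworth/golearn/base"
    "github.com/sjwhitworth/golearn/filters"
    "github.com/sjwhitworth/golearn/trees"
)

func main() {
    inst, err := base.ParseCSVToInstances("datasets/iris.csv", true) // hypothetical path
    if err != nil {
        panic(err)
    }
    // Turn numeric columns into CategoricalAttributes before training
    // (ChiMerge API assumed from the filters package)
    filt := filters.NewChiMergeFilter(inst, 0.90)
    filt.AddAllNumericAttributes()
    filt.Build()
    filt.Run(inst)

    tree := trees.NewRandomTree(2)
    tree.Fit(inst) // Fit assumed, as in the sketches above
}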