mirror of https://github.com/sjwhitworth/golearn.git

trees: merge from v2-instances

Author: Richard Townsend, 2014-08-02 16:22:15 +01:00
commit c2d040af30, parent 3b3b23a221
4 changed files with 132 additions and 95 deletions

View File

@@ -17,35 +17,40 @@ type InformationGainRuleGenerator struct {
 //
 // IMPORTANT: passing a base.Instances with no Attributes other than the class
 // variable will panic()
-func (r *InformationGainRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
-	allAttributes := make([]int, 0)
-	for i := 0; i < f.Cols; i++ {
-		if i != f.ClassIndex {
-			allAttributes = append(allAttributes, i)
-		}
-	}
-	return r.GetSplitAttributeFromSelection(allAttributes, f)
+func (r *InformationGainRuleGenerator) GenerateSplitAttribute(f base.FixedDataGrid) base.Attribute {
+	attrs := f.AllAttributes()
+	classAttrs := f.AllClassAttributes()
+	candidates := base.AttributeDifferenceReferences(attrs, classAttrs)
+	return r.GetSplitAttributeFromSelection(candidates, f)
 }
 
 // GetSplitAttributeFromSelection returns the class Attribute which maximises
 // the information gain amongst consideredAttributes
 //
 // IMPORTANT: passing a zero-length consideredAttributes parameter will panic()
-func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(consideredAttributes []int, f *base.Instances) base.Attribute {
+func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(consideredAttributes []base.Attribute, f base.FixedDataGrid) base.Attribute {
+
+	var selectedAttribute base.Attribute
+
+	// Parameter check
+	if len(consideredAttributes) == 0 {
+		panic("More Attributes should be considered")
+	}
+
 	// Next step is to compute the information gain at this node
 	// for each randomly chosen attribute, and pick the one
 	// which maximises it
 	maxGain := math.Inf(-1)
-	selectedAttribute := -1
 
 	// Compute the base entropy
-	classDist := f.GetClassDistribution()
+	classDist := base.GetClassDistribution(f)
 	baseEntropy := getBaseEntropy(classDist)
 
 	// Compute the information gain for each attribute
 	for _, s := range consideredAttributes {
-		proposedClassDist := f.GetClassDistributionAfterSplit(f.GetAttr(s))
+		proposedClassDist := base.GetClassDistributionAfterSplit(f, s)
 		localEntropy := getSplitEntropy(proposedClassDist)
 		informationGain := baseEntropy - localEntropy
 		if informationGain > maxGain {
@@ -55,7 +60,7 @@ func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(considered
 	}
 
 	// Pick the one which maximises IG
-	return f.GetAttr(selectedAttribute)
+	return selectedAttribute
 }
 //
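The generator above implements plain information gain: take the entropy of the node's class distribution, subtract the weighted entropy left after splitting on a candidate attribute, and keep the candidate that maximises the difference. A minimal self-contained sketch of that arithmetic, with toy counts and helper names of our own (not golearn's):

```go
package main

import (
	"fmt"
	"math"
)

// entropy computes -sum(p*log2(p)) over a map of class counts.
func entropy(dist map[string]int) float64 {
	total := 0
	for _, c := range dist {
		total += c
	}
	e := 0.0
	for _, c := range dist {
		if c > 0 {
			p := float64(c) / float64(total)
			e -= p * math.Log2(p)
		}
	}
	return e
}

func main() {
	// Class counts at the node: 9 "yes", 5 "no" (the classic tennis split).
	baseEntropy := entropy(map[string]int{"yes": 9, "no": 5})

	// Class counts in each subset induced by a two-valued attribute.
	subsets := []map[string]int{
		{"yes": 6, "no": 1},
		{"yes": 3, "no": 4},
	}
	splitEntropy := 0.0
	for _, s := range subsets {
		n := 0
		for _, c := range s {
			n += c
		}
		splitEntropy += float64(n) / 14.0 * entropy(s)
	}
	// informationGain := baseEntropy - localEntropy, as in the loop above.
	fmt.Printf("information gain: %.3f\n", baseEntropy-splitEntropy)
}
```

For the 9/5 distribution the base entropy is about 0.940 bits, so any useful split must push the weighted subset entropy below that.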

View File

@@ -21,7 +21,7 @@ const (
 // RuleGenerator implementations analyse instances and determine
 // the best value to split on
 type RuleGenerator interface {
-	GenerateSplitAttribute(*base.Instances) base.Attribute
+	GenerateSplitAttribute(base.FixedDataGrid) base.Attribute
 }
 
 // DecisionTreeNode represents a given portion of a decision tree
@@ -31,14 +31,19 @@ type DecisionTreeNode struct {
 	SplitAttr base.Attribute
 	ClassDist map[string]int
 	Class string
-	ClassAttr *base.Attribute
+	ClassAttr base.Attribute
 }
 
+func getClassAttr(from base.FixedDataGrid) base.Attribute {
+	allClassAttrs := from.AllClassAttributes()
+	return allClassAttrs[0]
+}
+
 // InferID3Tree builds a decision tree using a RuleGenerator
 // from a set of Instances (implements the ID3 algorithm)
-func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
+func InferID3Tree(from base.FixedDataGrid, with RuleGenerator) *DecisionTreeNode {
 	// Count the number of classes at this node
-	classes := from.CountClassValues()
+	classes := base.GetClassDistribution(from)
 	// If there's only one class, return a DecisionTreeLeaf with
 	// the only class available
 	if len(classes) == 1 {
@@ -52,7 +57,7 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
 			nil,
 			classes,
 			maxClass,
-			from.GetClassAttrPtr(),
+			getClassAttr(from),
 		}
 		return ret
 	}
@@ -69,28 +74,29 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
 	// If there are no more Attributes left to split on,
 	// return a DecisionTreeLeaf with the majority class
-	if from.GetAttributeCount() == 2 {
+	cols, _ := from.Size()
+	if cols == 2 {
 		ret := &DecisionTreeNode{
 			LeafNode,
 			nil,
 			nil,
 			classes,
 			maxClass,
-			from.GetClassAttrPtr(),
+			getClassAttr(from),
 		}
 		return ret
 	}
 
 	// Generate a return structure
 	ret := &DecisionTreeNode{
 		RuleNode,
 		nil,
 		nil,
 		classes,
 		maxClass,
-		from.GetClassAttrPtr(),
+		getClassAttr(from),
 	}
 
-	// Generate a return structure
+	// Generate the splitting attribute
 	splitOnAttribute := with.GenerateSplitAttribute(from)
 	if splitOnAttribute == nil {
@@ -98,7 +104,7 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
 		return ret
 	}
 
 	// Split the attributes based on this attribute's value
-	splitInstances := from.DecomposeOnAttributeValues(splitOnAttribute)
+	splitInstances := base.DecomposeOnAttributeValues(from, splitOnAttribute)
 	// Create new children from these attributes
 	ret.Children = make(map[string]*DecisionTreeNode)
 	for k := range splitInstances {
@@ -146,13 +152,13 @@ func (d *DecisionTreeNode) String() string {
 }
 
 // computeAccuracy is a helper method for Prune()
-func computeAccuracy(predictions *base.Instances, from *base.Instances) float64 {
+func computeAccuracy(predictions base.FixedDataGrid, from base.FixedDataGrid) float64 {
 	cf := eval.GetConfusionMatrix(from, predictions)
 	return eval.GetAccuracy(cf)
 }
 
 // Prune eliminates branches which hurt accuracy
-func (d *DecisionTreeNode) Prune(using *base.Instances) {
+func (d *DecisionTreeNode) Prune(using base.FixedDataGrid) {
 	// If you're a leaf, you're already pruned
 	if d.Children == nil {
 		return
@@ -162,11 +168,15 @@ func (d *DecisionTreeNode) Prune(using *base.Instances) {
 	}
 
 	// Recursively prune children of this node
-	sub := using.DecomposeOnAttributeValues(d.SplitAttr)
+	sub := base.DecomposeOnAttributeValues(using, d.SplitAttr)
 	for k := range d.Children {
 		if sub[k] == nil {
 			continue
 		}
+		subH, subV := sub[k].Size()
+		if subH == 0 || subV == 0 {
+			continue
+		}
 		d.Children[k].Prune(sub[k])
 	}
@@ -185,24 +195,30 @@ func (d *DecisionTreeNode) Prune(using *base.Instances) {
 }
 
 // Predict outputs a base.Instances containing predictions from this tree
-func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances {
-	outputAttrs := make([]base.Attribute, 1)
-	outputAttrs[0] = what.GetClassAttr()
-	predictions := base.NewInstances(outputAttrs, what.Rows)
-	for i := 0; i < what.Rows; i++ {
+func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) base.FixedDataGrid {
+	predictions := base.GeneratePredictionVector(what)
+	classAttr := getClassAttr(predictions)
+	classAttrSpec, err := predictions.GetAttribute(classAttr)
+	if err != nil {
+		panic(err)
+	}
+	predAttrs := base.AttributeDifferenceReferences(what.AllAttributes(), predictions.AllClassAttributes())
+	predAttrSpecs := base.ResolveAllAttributes(what, predAttrs)
+	what.MapOverRows(predAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
 		cur := d
 		for {
 			if cur.Children == nil {
-				predictions.SetAttrStr(i, 0, cur.Class)
+				predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
 				break
 			} else {
 				at := cur.SplitAttr
-				j := what.GetAttrIndex(at)
-				if j == -1 {
-					predictions.SetAttrStr(i, 0, cur.Class)
+				ats, err := what.GetAttribute(at)
+				if err != nil {
+					predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
 					break
 				}
-				classVar := at.GetStringFromSysVal(what.Get(i, j))
+				classVar := ats.GetAttribute().GetStringFromSysVal(what.Get(ats, rowNo))
 				if next, ok := cur.Children[classVar]; ok {
 					cur = next
 				} else {
@@ -217,7 +233,8 @@ func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances {
 				}
 			}
 		}
-	}
+		return true, nil
+	})
 	return predictions
 }
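Predict now walks the grid through MapOverRows, which calls the supplied closure once per row and stops early if the closure returns false or an error. A small sketch of that callback pattern in isolation, using only calls visible in the hunk above (the helper name is ours):

```go
package example

import "github.com/sjwhitworth/golearn/base"

// countRows illustrates the iteration style Predict relies on: MapOverRows
// invokes the closure once per row; returning true continues the scan.
func countRows(grid base.FixedDataGrid) int {
	n := 0
	specs := base.ResolveAllAttributes(grid, grid.AllAttributes())
	grid.MapOverRows(specs, func(row [][]byte, rowNo int) (bool, error) {
		n++ // row[i] holds the raw bytes for specs[i] at row rowNo
		return true, nil
	})
	return n
}
```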
@@ -245,7 +262,7 @@ func NewID3DecisionTree(prune float64) *ID3DecisionTree {
 }
 
 // Fit builds the ID3 decision tree
-func (t *ID3DecisionTree) Fit(on *base.Instances) {
+func (t *ID3DecisionTree) Fit(on base.FixedDataGrid) {
 	rule := new(InformationGainRuleGenerator)
 	if t.PruneSplit > 0.001 {
 		trainData, testData := base.InstancesTrainTestSplit(on, t.PruneSplit)
@@ -257,7 +274,7 @@ func (t *ID3DecisionTree) Fit(on *base.Instances) {
 }
 
 // Predict outputs predictions from the ID3 decision tree
-func (t *ID3DecisionTree) Predict(what *base.Instances) *base.Instances {
+func (t *ID3DecisionTree) Predict(what base.FixedDataGrid) base.FixedDataGrid {
 	return t.Root.Predict(what)
 }
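With these signatures the tree consumes and returns base.FixedDataGrid throughout, so any grid implementation (dense instances, lazily filtered views) can be used directly. A rough end-to-end sketch of the new API, assembled only from calls that appear in this commit; the dataset path is the one the tests use:

```go
package main

import (
	"fmt"

	"github.com/sjwhitworth/golearn/base"
	"github.com/sjwhitworth/golearn/eval"
	"github.com/sjwhitworth/golearn/trees"
)

func main() {
	// Parse a CSV with a header row; the result implements FixedDataGrid.
	inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
	if err != nil {
		panic(err)
	}

	// Hold out part of the rows for testing.
	trainData, testData := base.InstancesTrainTestSplit(inst, 0.70)

	// NewID3DecisionTree takes a prune fraction (see Fit above).
	tree := trees.NewID3DecisionTree(0.6)
	tree.Fit(trainData)

	// Predict returns a FixedDataGrid of class predictions.
	predictions := tree.Predict(testData)
	cf := eval.GetConfusionMatrix(testData, predictions)
	fmt.Println(eval.GetAccuracy(cf))
}
```

When the prune fraction exceeds 0.001, Fit splits the training data internally (visible in the hunk above); the held-out share is presumably handed to Prune for reduced-error pruning.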

View File

@@ -14,32 +14,32 @@ type RandomTreeRuleGenerator struct {
 
 // GenerateSplitAttribute returns the best attribute out of those randomly chosen
 // which maximises Information Gain
-func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
+func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f base.FixedDataGrid) base.Attribute {
 	// First step is to generate the random attributes that we'll consider
-	maximumAttribute := f.GetAttributeCount()
-	consideredAttributes := make([]int, r.Attributes)
+	allAttributes := base.AttributeDifferenceReferences(f.AllAttributes(), f.AllClassAttributes())
+	maximumAttribute := len(allAttributes)
+	consideredAttributes := make([]base.Attribute, 0)
 	attrCounter := 0
 	for {
 		if len(consideredAttributes) >= r.Attributes {
 			break
 		}
-		selectedAttribute := rand.Intn(maximumAttribute)
-		base.Logger.Println(selectedAttribute, attrCounter, consideredAttributes, len(consideredAttributes))
-		if selectedAttribute != f.ClassIndex {
-			matched := false
-			for _, a := range consideredAttributes {
-				if a == selectedAttribute {
-					matched = true
-					break
-				}
-			}
-			if matched {
-				continue
-			}
-			consideredAttributes = append(consideredAttributes, selectedAttribute)
-			attrCounter++
-		}
+		selectedAttrIndex := rand.Intn(maximumAttribute)
+		selectedAttribute := allAttributes[selectedAttrIndex]
+		matched := false
+		for _, a := range consideredAttributes {
+			if a.Equals(selectedAttribute) {
+				matched = true
+				break
+			}
+		}
+		if matched {
+			continue
+		}
+		consideredAttributes = append(consideredAttributes, selectedAttribute)
+		attrCounter++
 	}
 	return r.internalRule.GetSplitAttributeFromSelection(consideredAttributes, f)
@@ -67,12 +67,12 @@ func NewRandomTree(attrs int) *RandomTree {
 }
 
 // Fit builds a RandomTree suitable for prediction
-func (rt *RandomTree) Fit(from *base.Instances) {
+func (rt *RandomTree) Fit(from base.FixedDataGrid) {
 	rt.Root = InferID3Tree(from, rt.Rule)
 }
 
 // Predict returns a set of Instances containing predictions
-func (rt *RandomTree) Predict(from *base.Instances) *base.Instances {
+func (rt *RandomTree) Predict(from base.FixedDataGrid) base.FixedDataGrid {
 	return rt.Root.Predict(from)
 }
@@ -83,6 +83,6 @@ func (rt *RandomTree) String() string {
 
 // Prune removes nodes from the tree which are detrimental
 // to determining the accuracy of the test set (with)
-func (rt *RandomTree) Prune(with *base.Instances) {
+func (rt *RandomTree) Prune(with base.FixedDataGrid) {
 	rt.Root.Prune(with)
 }
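The selection loop above keeps drawing random indices and discards duplicates until r.Attributes distinct candidates have been collected. An equivalent, retry-free way to draw distinct attributes is a single shuffle; a standalone sketch with names of our own, not a drop-in for the code above:

```go
package main

import (
	"fmt"
	"math/rand"
)

// sampleDistinct returns k distinct values drawn from [0, n).
// Shuffling once avoids the retry-on-duplicate loop entirely.
func sampleDistinct(n, k int) []int {
	if k > n {
		k = n
	}
	idx := rand.Perm(n) // random permutation of 0..n-1
	return idx[:k]
}

func main() {
	// e.g. pick 2 of 4 non-class attribute positions at random
	fmt.Println(sampleDistinct(4, 2))
}
```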

View File

@@ -14,15 +14,17 @@ func TestRandomTree(testEnv *testing.T) {
 	if err != nil {
 		panic(err)
 	}
 	filt := filters.NewChiMergeFilter(inst, 0.90)
-	filt.AddAllNumericAttributes()
-	filt.Build()
-	filt.Run(inst)
-	fmt.Println(inst)
+	for _, a := range base.NonClassFloatAttributes(inst) {
+		filt.AddAttribute(a)
+	}
+	filt.Train()
+	instf := base.NewLazilyFilteredInstances(inst, filt)
 	r := new(RandomTreeRuleGenerator)
 	r.Attributes = 2
-	root := InferID3Tree(inst, r)
+	fmt.Println(instf)
+	root := InferID3Tree(instf, r)
 	fmt.Println(root)
 }
@@ -33,18 +35,20 @@ func TestRandomTreeClassification(testEnv *testing.T) {
 	}
 	trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
 	filt := filters.NewChiMergeFilter(inst, 0.90)
-	filt.AddAllNumericAttributes()
-	filt.Build()
-	filt.Run(trainData)
-	filt.Run(testData)
-	fmt.Println(inst)
+	for _, a := range base.NonClassFloatAttributes(inst) {
+		filt.AddAttribute(a)
+	}
+	filt.Train()
+	trainDataF := base.NewLazilyFilteredInstances(trainData, filt)
+	testDataF := base.NewLazilyFilteredInstances(testData, filt)
 	r := new(RandomTreeRuleGenerator)
 	r.Attributes = 2
-	root := InferID3Tree(trainData, r)
+	root := InferID3Tree(trainDataF, r)
 	fmt.Println(root)
-	predictions := root.Predict(testData)
+	predictions := root.Predict(testDataF)
 	fmt.Println(predictions)
-	confusionMat := eval.GetConfusionMatrix(testData, predictions)
+	confusionMat := eval.GetConfusionMatrix(testDataF, predictions)
 	fmt.Println(confusionMat)
 	fmt.Println(eval.GetMacroPrecision(confusionMat))
 	fmt.Println(eval.GetMacroRecall(confusionMat))
@@ -58,17 +62,19 @@ func TestRandomTreeClassification2(testEnv *testing.T) {
 	}
 	trainData, testData := base.InstancesTrainTestSplit(inst, 0.4)
 	filt := filters.NewChiMergeFilter(inst, 0.90)
-	filt.AddAllNumericAttributes()
-	filt.Build()
-	fmt.Println(testData)
-	filt.Run(testData)
-	filt.Run(trainData)
+	for _, a := range base.NonClassFloatAttributes(inst) {
+		filt.AddAttribute(a)
+	}
+	filt.Train()
+	trainDataF := base.NewLazilyFilteredInstances(trainData, filt)
+	testDataF := base.NewLazilyFilteredInstances(testData, filt)
 	root := NewRandomTree(2)
-	root.Fit(trainData)
+	root.Fit(trainDataF)
 	fmt.Println(root)
-	predictions := root.Predict(testData)
+	predictions := root.Predict(testDataF)
 	fmt.Println(predictions)
-	confusionMat := eval.GetConfusionMatrix(testData, predictions)
+	confusionMat := eval.GetConfusionMatrix(testDataF, predictions)
 	fmt.Println(confusionMat)
 	fmt.Println(eval.GetMacroPrecision(confusionMat))
 	fmt.Println(eval.GetMacroRecall(confusionMat))
@@ -82,19 +88,21 @@ func TestPruning(testEnv *testing.T) {
 	}
 	trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
 	filt := filters.NewChiMergeFilter(inst, 0.90)
-	filt.AddAllNumericAttributes()
-	filt.Build()
-	fmt.Println(testData)
-	filt.Run(testData)
-	filt.Run(trainData)
+	for _, a := range base.NonClassFloatAttributes(inst) {
+		filt.AddAttribute(a)
+	}
+	filt.Train()
+	trainDataF := base.NewLazilyFilteredInstances(trainData, filt)
+	testDataF := base.NewLazilyFilteredInstances(testData, filt)
 	root := NewRandomTree(2)
-	fittrainData, fittestData := base.InstancesTrainTestSplit(trainData, 0.6)
+	fittrainData, fittestData := base.InstancesTrainTestSplit(trainDataF, 0.6)
 	root.Fit(fittrainData)
 	root.Prune(fittestData)
 	fmt.Println(root)
-	predictions := root.Predict(testData)
+	predictions := root.Predict(testDataF)
 	fmt.Println(predictions)
-	confusionMat := eval.GetConfusionMatrix(testData, predictions)
+	confusionMat := eval.GetConfusionMatrix(testDataF, predictions)
 	fmt.Println(confusionMat)
 	fmt.Println(eval.GetMacroPrecision(confusionMat))
 	fmt.Println(eval.GetMacroRecall(confusionMat))
@@ -142,6 +150,7 @@ func TestID3Inference(testEnv *testing.T) {
 		testEnv.Error(sunnyChild)
 	}
 	if rainyChild.SplitAttr.GetName() != "windy" {
+		fmt.Println(rainyChild.SplitAttr)
 		testEnv.Error(rainyChild)
 	}
 	if overcastChild.SplitAttr != nil {
@@ -156,7 +165,6 @@ func TestID3Inference(testEnv *testing.T) {
 	if sunnyLeafNormal.Class != "yes" {
 		testEnv.Error(sunnyLeafNormal)
 	}
-
 	windyLeafFalse := rainyChild.Children["false"]
 	windyLeafTrue := rainyChild.Children["true"]
 	if windyLeafFalse.Class != "yes" {
@@ -176,12 +184,18 @@ func TestID3Classification(testEnv *testing.T) {
 	if err != nil {
 		panic(err)
 	}
-	filt := filters.NewBinningFilter(inst, 10)
-	filt.AddAllNumericAttributes()
-	filt.Build()
-	filt.Run(inst)
-	fmt.Println(inst)
-	trainData, testData := base.InstancesTrainTestSplit(inst, 0.70)
+	filt := filters.NewBinningFilter(inst, 10)
+	for _, a := range base.NonClassFloatAttributes(inst) {
+		filt.AddAttribute(a)
+	}
+	filt.Train()
+	fmt.Println(filt)
+	instf := base.NewLazilyFilteredInstances(inst, filt)
+	fmt.Println("INSTFA", instf.AllAttributes())
+	fmt.Println("INSTF", instf)
+	trainData, testData := base.InstancesTrainTestSplit(instf, 0.70)
 	// Build the decision tree
 	rule := new(InformationGainRuleGenerator)
 	root := InferID3Tree(trainData, rule)
@@ -199,6 +213,7 @@ func TestID3(testEnv *testing.T) {
 	// Import the "PlayTennis" dataset
 	inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
+	fmt.Println(inst)
 	if err != nil {
 		panic(err)
 	}
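Every updated test repeats the same preparation before building a tree: construct a filter, register each non-class float attribute, train the filter, then wrap the original grid lazily. Distilled into one hypothetical helper, using only calls that appear in the tests (and assuming, as the tests' usage suggests, that the filter constructor accepts any parsed instance grid):

```go
package example

import (
	"github.com/sjwhitworth/golearn/base"
	"github.com/sjwhitworth/golearn/filters"
)

// discretise wraps a grid in a lazily-applied ChiMerge filter, mirroring
// the preparation step the updated tests perform before tree induction.
func discretise(inst base.FixedDataGrid) base.FixedDataGrid {
	filt := filters.NewChiMergeFilter(inst, 0.90)
	for _, a := range base.NonClassFloatAttributes(inst) {
		filt.AddAttribute(a) // register each numeric, non-class attribute
	}
	filt.Train()
	// The filtered view is computed on demand; pass it to InferID3Tree,
	// RandomTree.Fit, or ID3DecisionTree.Fit as a FixedDataGrid.
	return base.NewLazilyFilteredInstances(inst, filt)
}
```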