1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00

Identified source of the low accuracy

This commit is contained in:
Richard Townsend 2014-05-17 20:37:19 +01:00
parent 13c0dc3eba
commit 12ace9def5
5 changed files with 66 additions and 12 deletions

View File

@ -13,13 +13,13 @@ func TestRandomForest1(testEnv *testing.T) {
if err != nil {
panic(err)
}
insts := base.InstancesTrainTestSplit(inst, 0.6)
insts := base.InstancesTrainTestSplit(inst, 0.80)
filt := filters.NewChiMergeFilter(insts[0], 0.90)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(insts[1])
filt.Run(insts[0])
rf := NewRandomForest(15, 2)
rf := NewRandomForest(10, 3)
rf.Fit(insts[0])
predictions := rf.Predict(insts[1])
fmt.Println(predictions)

View File

@ -55,7 +55,6 @@ func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(considered
}
// Pick the one which maximises IG
return f.GetAttr(selectedAttribute)
}

View File

@ -5,6 +5,7 @@ import (
"fmt"
base "github.com/sjwhitworth/golearn/base"
eval "github.com/sjwhitworth/golearn/evaluation"
"sort"
)
// NodeType determines whether a DecisionTreeNode is a leaf or not
@ -66,9 +67,9 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
}
}
// If there are no more attribute left to split on,
// If there are no more Attributes left to split on,
// return a DecisionTreeLeaf with the majority class
if from.GetAttributeCount() == 1 {
if from.GetAttributeCount() == 2 {
ret := &DecisionTreeNode{
LeafNode,
nil,
@ -82,7 +83,7 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
ret := &DecisionTreeNode{
RuleNode,
make(map[string]*DecisionTreeNode),
nil,
nil,
classes,
maxClass,
@ -92,9 +93,14 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
// Generate a return structure
// Generate the splitting attribute
splitOnAttribute := with.GenerateSplitAttribute(from)
if splitOnAttribute == nil {
// Can't determine, just return what we have
return ret
}
// Split the attributes based on this attribute's value
splitInstances := from.DecomposeOnAttributeValues(splitOnAttribute)
// Create new children from these attributes
ret.Children = make(map[string]*DecisionTreeNode)
for k := range splitInstances {
newInstances := splitInstances[k]
ret.Children[k] = InferID3Tree(newInstances, with)
@ -114,7 +120,12 @@ func (d *DecisionTreeNode) getNestedString(level int) string {
buf.WriteString(fmt.Sprintf("Leaf(%s)", d.Class))
} else {
buf.WriteString(fmt.Sprintf("Rule(%s)", d.SplitAttr.GetName()))
keys := make([]string, 0)
for k := range d.Children {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
buf.WriteString("\n")
buf.WriteString(tmp.String())
buf.WriteString("\t")
@ -171,16 +182,25 @@ func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances {
predictions := base.NewInstances(outputAttrs, what.Rows)
for i := 0; i < what.Rows; i++ {
cur := d
for j := 0; j < what.Cols; j++ {
at := what.GetAttr(j)
classVar := at.GetStringFromSysVal(what.Get(i, j))
for {
if cur.Children == nil {
predictions.SetAttrStr(i, 0, cur.Class)
break
} else {
at := cur.SplitAttr
j := what.GetAttrIndex(at)
classVar := at.GetStringFromSysVal(what.Get(i, j))
if next, ok := cur.Children[classVar]; ok {
cur = next
} else {
predictions.SetAttrStr(i, 0, cur.Class)
var bestChild string
for c := range cur.Children {
bestChild = c
if c > classVar {
break
}
}
cur = cur.Children[bestChild]
}
}
}

View File

@ -20,11 +20,22 @@ func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base
consideredAttributes := make([]int, r.Attributes)
attrCounter := 0
for {
if attrCounter >= r.Attributes {
if len(consideredAttributes) >= r.Attributes {
break
}
selectedAttribute := rand.Intn(maximumAttribute)
fmt.Println(selectedAttribute, attrCounter, consideredAttributes, len(consideredAttributes))
if selectedAttribute != f.ClassIndex {
matched := false
for _, a := range consideredAttributes {
if a == selectedAttribute {
matched = true
break
}
}
if matched {
continue
}
consideredAttributes = append(consideredAttributes, selectedAttribute)
attrCounter++
}

View File

@ -56,7 +56,7 @@ func TestRandomTreeClassification2(testEnv *testing.T) {
if err != nil {
panic(err)
}
insts := base.InstancesTrainTestSplit(inst, 0.6)
insts := base.InstancesTrainTestSplit(inst, 0.4)
filt := filters.NewChiMergeFilter(inst, 0.90)
filt.AddAllNumericAttributes()
filt.Build()
@ -170,3 +170,27 @@ func TestID3Inference(testEnv *testing.T) {
testEnv.Error(overcastChild)
}
}
func TestID3Classification(testEnv *testing.T) {
// Smoke-test end-to-end ID3 classification on the iris dataset:
// load, discretise, split, train, predict, and print evaluation metrics.
// NOTE(review): this test only prints results; it asserts nothing, so it
// can only fail by panicking.
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
panic(err)
}
// Discretise all numeric attributes (presumably into 10 bins, going by the
// BinningFilter name — TODO confirm) so ID3 can split on discrete values.
filt := filters.NewBinningFilter(inst, 10)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(inst)
fmt.Println(inst)
// Split into two partitions; 0.70 looks like the train fraction —
// verify against InstancesTrainTestSplit's contract.
insts := base.InstancesTrainTestSplit(inst, 0.70)
// Build the decision tree
rule := new(InformationGainRuleGenerator)
root := InferID3Tree(insts[0], rule)
fmt.Println(root)
// Classify the held-out partition and print the confusion matrix
// plus macro-averaged precision/recall and a summary.
predictions := root.Predict(insts[1])
fmt.Println(predictions)
confusionMat := eval.GetConfusionMatrix(insts[1], predictions)
fmt.Println(confusionMat)
fmt.Println(eval.GetMacroPrecision(confusionMat))
fmt.Println(eval.GetMacroRecall(confusionMat))
fmt.Println(eval.GetSummary(confusionMat))
}