mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
Identified source of the low accuracy
This commit is contained in:
parent
13c0dc3eba
commit
12ace9def5
@ -13,13 +13,13 @@ func TestRandomForest1(testEnv *testing.T) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
insts := base.InstancesTrainTestSplit(inst, 0.6)
|
||||
insts := base.InstancesTrainTestSplit(inst, 0.80)
|
||||
filt := filters.NewChiMergeFilter(insts[0], 0.90)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
filt.Run(insts[1])
|
||||
filt.Run(insts[0])
|
||||
rf := NewRandomForest(15, 2)
|
||||
rf := NewRandomForest(10, 3)
|
||||
rf.Fit(insts[0])
|
||||
predictions := rf.Predict(insts[1])
|
||||
fmt.Println(predictions)
|
||||
|
@ -55,7 +55,6 @@ func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(considered
|
||||
}
|
||||
|
||||
// Pick the one which maximises IG
|
||||
|
||||
return f.GetAttr(selectedAttribute)
|
||||
}
|
||||
|
||||
|
34
trees/id3.go
34
trees/id3.go
@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
base "github.com/sjwhitworth/golearn/base"
|
||||
eval "github.com/sjwhitworth/golearn/evaluation"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// NodeType determines whether a DecisionTreeNode is a leaf or not
|
||||
@ -66,9 +67,9 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
|
||||
}
|
||||
}
|
||||
|
||||
// If there are no more attribute left to split on,
|
||||
// If there are no more Attributes left to split on,
|
||||
// return a DecisionTreeLeaf with the majority class
|
||||
if from.GetAttributeCount() == 1 {
|
||||
if from.GetAttributeCount() == 2 {
|
||||
ret := &DecisionTreeNode{
|
||||
LeafNode,
|
||||
nil,
|
||||
@ -82,7 +83,7 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
|
||||
|
||||
ret := &DecisionTreeNode{
|
||||
RuleNode,
|
||||
make(map[string]*DecisionTreeNode),
|
||||
nil,
|
||||
nil,
|
||||
classes,
|
||||
maxClass,
|
||||
@ -92,9 +93,14 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
|
||||
// Generate a return structure
|
||||
// Generate the splitting attribute
|
||||
splitOnAttribute := with.GenerateSplitAttribute(from)
|
||||
if splitOnAttribute == nil {
|
||||
// Can't determine, just return what we have
|
||||
return ret
|
||||
}
|
||||
// Split the attributes based on this attribute's value
|
||||
splitInstances := from.DecomposeOnAttributeValues(splitOnAttribute)
|
||||
// Create new children from these attributes
|
||||
ret.Children = make(map[string]*DecisionTreeNode)
|
||||
for k := range splitInstances {
|
||||
newInstances := splitInstances[k]
|
||||
ret.Children[k] = InferID3Tree(newInstances, with)
|
||||
@ -114,7 +120,12 @@ func (d *DecisionTreeNode) getNestedString(level int) string {
|
||||
buf.WriteString(fmt.Sprintf("Leaf(%s)", d.Class))
|
||||
} else {
|
||||
buf.WriteString(fmt.Sprintf("Rule(%s)", d.SplitAttr.GetName()))
|
||||
keys := make([]string, 0)
|
||||
for k := range d.Children {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
for _, k := range keys {
|
||||
buf.WriteString("\n")
|
||||
buf.WriteString(tmp.String())
|
||||
buf.WriteString("\t")
|
||||
@ -171,16 +182,25 @@ func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances {
|
||||
predictions := base.NewInstances(outputAttrs, what.Rows)
|
||||
for i := 0; i < what.Rows; i++ {
|
||||
cur := d
|
||||
for j := 0; j < what.Cols; j++ {
|
||||
at := what.GetAttr(j)
|
||||
classVar := at.GetStringFromSysVal(what.Get(i, j))
|
||||
for {
|
||||
if cur.Children == nil {
|
||||
predictions.SetAttrStr(i, 0, cur.Class)
|
||||
break
|
||||
} else {
|
||||
at := cur.SplitAttr
|
||||
j := what.GetAttrIndex(at)
|
||||
classVar := at.GetStringFromSysVal(what.Get(i, j))
|
||||
if next, ok := cur.Children[classVar]; ok {
|
||||
cur = next
|
||||
} else {
|
||||
predictions.SetAttrStr(i, 0, cur.Class)
|
||||
var bestChild string
|
||||
for c := range cur.Children {
|
||||
bestChild = c
|
||||
if c > classVar {
|
||||
break
|
||||
}
|
||||
}
|
||||
cur = cur.Children[bestChild]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -20,11 +20,22 @@ func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base
|
||||
consideredAttributes := make([]int, r.Attributes)
|
||||
attrCounter := 0
|
||||
for {
|
||||
if attrCounter >= r.Attributes {
|
||||
if len(consideredAttributes) >= r.Attributes {
|
||||
break
|
||||
}
|
||||
selectedAttribute := rand.Intn(maximumAttribute)
|
||||
fmt.Println(selectedAttribute, attrCounter, consideredAttributes, len(consideredAttributes))
|
||||
if selectedAttribute != f.ClassIndex {
|
||||
matched := false
|
||||
for _, a := range consideredAttributes {
|
||||
if a == selectedAttribute {
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if matched {
|
||||
continue
|
||||
}
|
||||
consideredAttributes = append(consideredAttributes, selectedAttribute)
|
||||
attrCounter++
|
||||
}
|
||||
|
@ -56,7 +56,7 @@ func TestRandomTreeClassification2(testEnv *testing.T) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
insts := base.InstancesTrainTestSplit(inst, 0.6)
|
||||
insts := base.InstancesTrainTestSplit(inst, 0.4)
|
||||
filt := filters.NewChiMergeFilter(inst, 0.90)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
@ -170,3 +170,27 @@ func TestID3Inference(testEnv *testing.T) {
|
||||
testEnv.Error(overcastChild)
|
||||
}
|
||||
}
|
||||
|
||||
func TestID3Classification(testEnv *testing.T) {
|
||||
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
filt := filters.NewBinningFilter(inst, 10)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
filt.Run(inst)
|
||||
fmt.Println(inst)
|
||||
insts := base.InstancesTrainTestSplit(inst, 0.70)
|
||||
// Build the decision tree
|
||||
rule := new(InformationGainRuleGenerator)
|
||||
root := InferID3Tree(insts[0], rule)
|
||||
fmt.Println(root)
|
||||
predictions := root.Predict(insts[1])
|
||||
fmt.Println(predictions)
|
||||
confusionMat := eval.GetConfusionMatrix(insts[1], predictions)
|
||||
fmt.Println(confusionMat)
|
||||
fmt.Println(eval.GetMacroPrecision(confusionMat))
|
||||
fmt.Println(eval.GetMacroRecall(confusionMat))
|
||||
fmt.Println(eval.GetSummary(confusionMat))
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user