mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00

Merge pull request #91 from Sentimentron/numeric-staging

trees: Handling FloatAttributes.
Stephen Whitworth 2014-10-27 07:56:19 +00:00
commit 056ccef9b6
10 changed files with 668 additions and 121 deletions


@ -78,6 +78,17 @@ func SetClass(at UpdatableDataGrid, row int, class string) {
at.Set(classAttrSpec, row, classBytes)
}
// GetAttributeByName returns an Attribute matching a given name.
// Returns nil if one doesn't exist.
func GetAttributeByName(inst FixedDataGrid, name string) Attribute {
for _, a := range inst.AllAttributes() {
if a.GetName() == name {
return a
}
}
return nil
}
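A minimal usage sketch (assuming the c45-numeric.csv dataset added later in this commit; the path and variable names are illustrative):

	inst, err := base.ParseCSVToInstances("../examples/datasets/c45-numeric.csv", true)
	if err != nil {
		panic(err)
	}
	attr := base.GetAttributeByName(inst, "Attribute2")
	if attr == nil {
		panic("Attribute2 not found")
	}
	// Numeric columns parse as *base.FloatAttribute, which the threshold helpers below require
	fAttr := attr.(*base.FloatAttribute)
	fmt.Println(fAttr.GetName())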
// GetClassDistribution returns a map containing the count of each
// class type (indexed by the class' string representation).
func GetClassDistribution(inst FixedDataGrid) map[string]int {
@ -90,6 +101,42 @@ func GetClassDistribution(inst FixedDataGrid) map[string]int {
return ret
}
// GetClassDistributionAfterThreshold returns the class distribution
// after a speculative split on a given Attribute using a threshold.
func GetClassDistributionAfterThreshold(inst FixedDataGrid, at Attribute, val float64) map[string]map[string]int {
ret := make(map[string]map[string]int)
// Find the attribute we're decomposing on
attrSpec, err := inst.GetAttribute(at)
if err != nil {
panic(fmt.Sprintf("Invalid attribute %s (%s)", at, err))
}
// Validate
if _, ok := at.(*FloatAttribute); !ok {
panic(fmt.Sprintf("Must be numeric!"))
}
_, rows := inst.Size()
for i := 0; i < rows; i++ {
splitVal := UnpackBytesToFloat(inst.Get(attrSpec, i)) > val
splitVar := "0"
if splitVal {
splitVar = "1"
}
classVar := GetClass(inst, i)
if _, ok := ret[splitVar]; !ok {
// First time we've seen this branch: initialise its distribution
ret[splitVar] = make(map[string]int)
}
ret[splitVar][classVar]++
}
return ret
}
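A hedged sketch of the returned shape, reusing inst and attr from the sketch above with a threshold of 82.5 (rows whose value exceeds the threshold are counted under key "1", the rest under "0"):

	dist := base.GetClassDistributionAfterThreshold(inst, attr, 82.5)
	// For the c45-numeric.csv data added in this commit:
	// dist["0"] == map[A:7 B:2] (Attribute2 <= 82.5)
	// dist["1"] == map[A:2 B:3] (Attribute2 > 82.5)
	fmt.Println(dist)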
// GetClassDistributionAfterSplit returns the class distribution
// after a speculative split on a given Attribute.
func GetClassDistributionAfterSplit(inst FixedDataGrid, at Attribute) map[string]map[string]int {
@ -118,6 +165,64 @@ func GetClassDistributionAfterSplit(inst FixedDataGrid, at Attribute) map[string
return ret
}
// DecomposeOnNumericAttributeThreshold divides the instance set depending on the
// value of a given numeric Attribute, constructs child instances, and returns
// them in a map keyed on whether that row had a higher value than the threshold
// or not.
//
// IMPORTANT: calls panic() if the AttributeSpec of at cannot be determined, or if
// the Attribute is not numeric.
func DecomposeOnNumericAttributeThreshold(inst FixedDataGrid, at Attribute, val float64) map[string]FixedDataGrid {
// Verify
if _, ok := at.(*FloatAttribute); !ok {
panic("Invalid argument")
}
// Find the Attribute we're decomposing on
attrSpec, err := inst.GetAttribute(at)
if err != nil {
panic(fmt.Sprintf("Invalid Attribute index %s", at))
}
// Construct the new Attribute set
newAttrs := make([]Attribute, 0)
for _, a := range inst.AllAttributes() {
if a.Equals(at) {
continue
}
newAttrs = append(newAttrs, a)
}
// Create the return map
ret := make(map[string]FixedDataGrid)
// Create the return row mapping
rowMaps := make(map[string][]int)
// Build full Attribute set
fullAttrSpec := ResolveAttributes(inst, newAttrs)
fullAttrSpec = append(fullAttrSpec, attrSpec)
// Decompose
inst.MapOverRows(fullAttrSpec, func(row [][]byte, rowNo int) (bool, error) {
// Find the output instance set
targetBytes := row[len(row)-1]
targetVal := UnpackBytesToFloat(targetBytes)
above := targetVal > val // renamed to avoid shadowing the threshold parameter
targetSet := "0"
if above {
targetSet = "1"
}
rowMap := rowMaps[targetSet]
rowMaps[targetSet] = append(rowMap, rowNo)
return true, nil
})
for a := range rowMaps {
ret[a] = NewInstancesViewFromVisible(inst, rowMaps[a], newAttrs)
}
return ret
}
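A usage sketch, again with the illustrative inst and attr from above:

	parts := base.DecomposeOnNumericAttributeThreshold(inst, attr, 82.5)
	// parts["0"] holds rows at or below the threshold, parts["1"] those above;
	// the split Attribute itself is excluded from the children.
	for k, child := range parts {
		_, rows := child.Size()
		fmt.Printf("branch %s: %d rows\n", k, rows)
	}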
// DecomposeOnAttributeValues divides the instance set depending on the
// value of a given Attribute, constructs child instances, and returns
// them in a map keyed on the string value of that Attribute.


@ -0,0 +1,15 @@
Attribute1,Attribute2,Attribute3,Class
A,70,T,A
A,90,T,B
A,85,F,B
A,95,F,B
A,70,F,A
B,90,T,A
B,78,F,A
B,65,T,A
B,75,F,A
C,80,T,B
C,70,T,B
C,80,F,A
C,80,F,A
C,96,F,A


@ -0,0 +1,4 @@
c45-numeric.csv: www.mgt.ncu.edu.tw/~wabble/School/C45.ppt
tennis.csv: "Machine Learning", Tom Mitchell, McGraw-Hill, 1997 (http://books.google.co.uk/books?id=xOGAngEACAAJ&dq=machine+learning,+mitchell&hl=en&sa=X&ei=zvpMVPz8IseN7Aa454DYBg&ved=0CFYQ6AEwBw)


@ -10,14 +10,13 @@ import (
"github.com/sjwhitworth/golearn/filters"
"github.com/sjwhitworth/golearn/trees"
"math/rand"
"time"
)
func main() {
var tree base.Classifier
rand.Seed(time.Now().UTC().UnixNano())
rand.Seed(44111342)
// Load in the iris dataset
iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
@ -26,7 +25,7 @@ func main() {
}
// Discretise the iris dataset with Chi-Merge
filt := filters.NewChiMergeFilter(iris, 0.99)
filt := filters.NewChiMergeFilter(iris, 0.999)
for _, a := range base.NonClassFloatAttributes(iris) {
filt.AddAttribute(a)
}
@ -55,13 +54,58 @@ func main() {
}
// Evaluate
fmt.Println("ID3 Performance")
fmt.Println("ID3 Performance (information gain)")
cf, err := evaluation.GetConfusionMatrix(testData, predictions)
if err != nil {
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
}
fmt.Println(evaluation.GetSummary(cf))
tree = trees.NewID3DecisionTreeFromRule(0.6, new(trees.InformationGainRatioRuleGenerator))
// (Parameter controls train-prune split.)
// Train the ID3 tree
err = tree.Fit(trainData)
if err != nil {
panic(err)
}
// Generate predictions
predictions, err = tree.Predict(testData)
if err != nil {
panic(err)
}
// Evaluate
fmt.Println("ID3 Performance (information gain ratio)")
cf, err = evaluation.GetConfusionMatrix(testData, predictions)
if err != nil {
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
}
fmt.Println(evaluation.GetSummary(cf))
tree = trees.NewID3DecisionTreeFromRule(0.6, new(trees.GiniCoefficientRuleGenerator))
// (Parameter controls train-prune split.)
// Train the ID3 tree
err = tree.Fit(trainData)
if err != nil {
panic(err)
}
// Generate predictions
predictions, err = tree.Predict(testData)
if err != nil {
panic(err)
}
// Evaluate
fmt.Println("ID3 Performance (gini index generator)")
cf, err = evaluation.GetConfusionMatrix(testData, predictions)
if err != nil {
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
}
fmt.Println(evaluation.GetSummary(cf))
//
// Next up, Random Trees
//
@ -86,7 +130,7 @@ func main() {
//
// Finally, Random Forests
//
tree = ensemble.NewRandomForest(100, 3)
tree = ensemble.NewRandomForest(70, 3)
err = tree.Fit(trainData)
if err != nil {
panic(err)


@ -3,34 +3,37 @@ package trees
import (
"github.com/sjwhitworth/golearn/base"
"math"
"sort"
)
//
// Information gain rule generator
//
// InformationGainRuleGenerator generates DecisionTreeRules which
// maximize information gain at each node.
type InformationGainRuleGenerator struct {
}
// GenerateSplitAttribute returns the non-class Attribute which maximises the
// information gain.
// GenerateSplitRule returns a DecisionTreeNode based on a non-class Attribute
// which maximises the information gain.
//
// IMPORTANT: passing a base.Instances with no Attributes other than the class
// variable will panic()
func (r *InformationGainRuleGenerator) GenerateSplitAttribute(f base.FixedDataGrid) base.Attribute {
func (r *InformationGainRuleGenerator) GenerateSplitRule(f base.FixedDataGrid) *DecisionTreeRule {
attrs := f.AllAttributes()
classAttrs := f.AllClassAttributes()
candidates := base.AttributeDifferenceReferences(attrs, classAttrs)
return r.GetSplitAttributeFromSelection(candidates, f)
return r.GetSplitRuleFromSelection(candidates, f)
}
// GetSplitAttributeFromSelection returns the class Attribute which maximises
// the information gain amongst consideredAttributes
// GetSplitRuleFromSelection returns a DecisionTreeRule which maximises
// the information gain amongst the considered Attributes.
//
// IMPORTANT: passing a zero-length consideredAttributes parameter will panic()
func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(consideredAttributes []base.Attribute, f base.FixedDataGrid) base.Attribute {
func (r *InformationGainRuleGenerator) GetSplitRuleFromSelection(consideredAttributes []base.Attribute, f base.FixedDataGrid) *DecisionTreeRule {
var selectedAttribute base.Attribute
@ -43,6 +46,7 @@ func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(considered
// for each randomly chosen attribute, and pick the one
// which maximises it
maxGain := math.Inf(-1)
selectedVal := math.Inf(1)
// Compute the base entropy
classDist := base.GetClassDistribution(f)
@ -50,23 +54,98 @@ func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(considered
// Compute the information gain for each attribute
for _, s := range consideredAttributes {
proposedClassDist := base.GetClassDistributionAfterSplit(f, s)
localEntropy := getSplitEntropy(proposedClassDist)
informationGain := baseEntropy - localEntropy
var informationGain float64
var splitVal float64
if fAttr, ok := s.(*base.FloatAttribute); ok {
var attributeEntropy float64
attributeEntropy, splitVal = getNumericAttributeEntropy(f, fAttr)
informationGain = baseEntropy - attributeEntropy
} else {
proposedClassDist := base.GetClassDistributionAfterSplit(f, s)
localEntropy := getSplitEntropy(proposedClassDist)
informationGain = baseEntropy - localEntropy
}
if informationGain > maxGain {
maxGain = informationGain
selectedAttribute = s
selectedVal = splitVal
}
}
// Pick the one which maximises IG
return selectedAttribute
return &DecisionTreeRule{selectedAttribute, selectedVal}
}
//
// Entropy functions
//
type numericSplitRef struct {
val float64
class string
}
type splitVec []numericSplitRef
func (a splitVec) Len() int { return len(a) }
func (a splitVec) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a splitVec) Less(i, j int) bool { return a[i].val < a[j].val }
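splitVec implements sort.Interface, so the sort.Sort call below can order the (value, class) pairs by attribute value before candidate thresholds are scanned.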
func getNumericAttributeEntropy(f base.FixedDataGrid, attr *base.FloatAttribute) (float64, float64) {
// Resolve Attribute
attrSpec, err := f.GetAttribute(attr)
if err != nil {
panic(err)
}
// Build sortable vector
_, rows := f.Size()
refs := make([]numericSplitRef, rows)
f.MapOverRows([]base.AttributeSpec{attrSpec}, func(val [][]byte, row int) (bool, error) {
cls := base.GetClass(f, row)
v := base.UnpackBytesToFloat(val[0])
refs[row] = numericSplitRef{v, cls}
return true, nil
})
// Sort
sort.Sort(splitVec(refs))
generateCandidateSplitDistribution := func(val float64) map[string]map[string]int {
presplit := make(map[string]int)
postsplit := make(map[string]int)
for _, i := range refs {
if i.val < val {
presplit[i.class]++
} else {
postsplit[i.class]++
}
}
ret := make(map[string]map[string]int)
ret["0"] = presplit
ret["1"] = postsplit
return ret
}
minSplitEntropy := math.Inf(1)
minSplitVal := math.Inf(1)
// Consider each possible split point: the midpoint of each adjacent pair of sorted values
for i := 0; i < len(refs)-1; i++ {
val := (refs[i].val + refs[i+1].val) / 2
splitDist := generateCandidateSplitDistribution(val)
splitEntropy := getSplitEntropy(splitDist)
if splitEntropy < minSplitEntropy {
minSplitEntropy = splitEntropy
minSplitVal = val
}
}
return minSplitEntropy, minSplitVal
}
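To make the candidate loop concrete: for Attribute2 of the c45-numeric.csv dataset added in this commit, the sorted values run 65, 70, 70, 70, 75, 78, 80, 80, 80, 85, 90, 90, 95, 96, so each candidate threshold is the midpoint of an adjacent pair, and the entropy-minimising candidate works out to (80 + 85)/2 = 82.5 — exactly what TestPRIVATEgetNumericAttributeEntropy asserts at the bottom of this commit.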
// getSplitEntropy determines the entropy of the target
// class distribution after splitting on a base.Attribute
func getSplitEntropy(s map[string]map[string]int) float64 {

trees/gini.go (new file, 113 lines)

@ -0,0 +1,113 @@
package trees
import (
"github.com/sjwhitworth/golearn/base"
"math"
)
//
// Gini-coefficient rule generator
//
// GiniCoefficientRuleGenerator generates DecisionTreeRules which minimize
// the Gini impurity coefficient at each node.
type GiniCoefficientRuleGenerator struct {
}
// GenerateSplitRule returns the non-class Attribute-based DecisionTreeRule
// which minimises the average Gini index after the split.
//
// IMPORTANT: passing a base.Instances with no Attributes other than the class
// variable will panic()
func (g *GiniCoefficientRuleGenerator) GenerateSplitRule(f base.FixedDataGrid) *DecisionTreeRule {
attrs := f.AllAttributes()
classAttrs := f.AllClassAttributes()
candidates := base.AttributeDifferenceReferences(attrs, classAttrs)
return g.GetSplitRuleFromSelection(candidates, f)
}
// GetSplitRuleFromSelection returns the DecisionTreeRule which minimises
// the average Gini index amongst consideredAttributes
//
// IMPORTANT: passing a zero-length consideredAttributes parameter will panic()
func (g *GiniCoefficientRuleGenerator) GetSplitRuleFromSelection(consideredAttributes []base.Attribute, f base.FixedDataGrid) *DecisionTreeRule {
var selectedAttribute base.Attribute
var selectedVal float64
// Parameter check
if len(consideredAttributes) == 0 {
panic("More Attributes should be considered")
}
// Minimize the average Gini index
minGini := math.Inf(1)
for _, s := range consideredAttributes {
var proposedDist map[string]map[string]int
var splitVal float64
if fAttr, ok := s.(*base.FloatAttribute); ok {
_, splitVal = getNumericAttributeEntropy(f, fAttr)
proposedDist = base.GetClassDistributionAfterThreshold(f, fAttr, splitVal)
} else {
proposedDist = base.GetClassDistributionAfterSplit(f, s)
}
avgGini := computeAverageGiniIndex(proposedDist)
if avgGini < minGini {
minGini = avgGini
selectedAttribute = s
selectedVal = splitVal
}
}
return &DecisionTreeRule{selectedAttribute, selectedVal}
}
//
// Utility functions
//
// computeGini computes the Gini impurity measure
func computeGini(s map[string]int) float64 {
// Compute the total count
total := 0
for i := range s {
total += s[i]
}
if total == 0 {
return 0.0 // degenerate case: no counts
}
// Sum the squared class probabilities
sum := 0.0
for i := range s {
p := float64(s[i]) / float64(total)
sum += p * p
}
return 1.0 - sum
}
// computeAverageGiniIndex computes the average Gini index of a
// proposed split
func computeAverageGiniIndex(s map[string]map[string]int) float64 {
// Figure out the total number of things in this map
total := 0
for i := range s {
for j := range s[i] {
total += s[i][j]
}
}
sum := 0.0
for i := range s {
subtotal := 0.0
for j := range s[i] {
subtotal += float64(s[i][j])
}
cf := subtotal / float64(total)
cf *= computeGini(s[i])
sum += cf
}
return sum
}
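Worked example, using the 82.5 threshold split of c45-numeric.csv from earlier: the left branch {A:7, B:2} has Gini 1 - (7/9)² - (2/9)² ≈ 0.346, the right branch {A:2, B:3} has Gini 1 - (2/5)² - (3/5)² = 0.48, and the weighted average is (9/14)(0.346) + (5/14)(0.48) ≈ 0.394.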

trees/gr.go (new file, 76 lines)

@ -0,0 +1,76 @@
package trees
import (
"github.com/sjwhitworth/golearn/base"
"math"
)
//
// Information gain ratio generator
//
// InformationGainRatioRuleGenerator generates DecisionTreeRules which
// maximise the information gain ratio at each node.
type InformationGainRatioRuleGenerator struct {
}
// GenerateSplitRule returns a DecisionTreeRule which maximises information
// gain ratio considering every available Attribute.
//
// IMPORTANT: passing a base.Instances with no Attributes other than the class
// variable will panic()
func (r *InformationGainRatioRuleGenerator) GenerateSplitRule(f base.FixedDataGrid) *DecisionTreeRule {
attrs := f.AllAttributes()
classAttrs := f.AllClassAttributes()
candidates := base.AttributeDifferenceReferences(attrs, classAttrs)
return r.GetSplitRuleFromSelection(candidates, f)
}
// GetSplitRuleFromSelection returns the DecisionTreeRule which maximises the information gain ratio,
// considering only a subset of Attributes.
//
// IMPORTANT: passing a zero-length consideredAttributes parameter will panic()
func (r *InformationGainRatioRuleGenerator) GetSplitRuleFromSelection(consideredAttributes []base.Attribute, f base.FixedDataGrid) *DecisionTreeRule {
var selectedAttribute base.Attribute
var selectedVal float64
// Parameter check
if len(consideredAttributes) == 0 {
panic("More Attributes should be considered")
}
// Next step is to compute the information gain at this node
// for each randomly chosen attribute, and pick the one
// which maximises it
maxRatio := math.Inf(-1)
// Compute the base entropy
classDist := base.GetClassDistribution(f)
baseEntropy := getBaseEntropy(classDist)
// Compute the information gain for each attribute
for _, s := range consideredAttributes {
var informationGain float64
var localEntropy float64
var splitVal float64
if fAttr, ok := s.(*base.FloatAttribute); ok {
localEntropy, splitVal = getNumericAttributeEntropy(f, fAttr)
} else {
proposedClassDist := base.GetClassDistributionAfterSplit(f, s)
localEntropy = getSplitEntropy(proposedClassDist)
}
informationGain = baseEntropy - localEntropy
informationGainRatio := informationGain / localEntropy
if informationGainRatio > maxRatio {
maxRatio = informationGainRatio
selectedAttribute = s
selectedVal = splitVal
}
}
// Pick the one which maximises the gain ratio
return &DecisionTreeRule{selectedAttribute, selectedVal}
}


@ -8,7 +8,7 @@ import (
"sort"
)
// NodeType determines whether a DecisionTreeNode is a leaf or not
// NodeType determines whether a DecisionTreeNode is a leaf or not.
type NodeType int
const (
@ -19,19 +19,33 @@ const (
)
// RuleGenerator implementations analyse instances and determine
// the best value to split on
// the best value to split on.
type RuleGenerator interface {
GenerateSplitAttribute(base.FixedDataGrid) base.Attribute
GenerateSplitRule(base.FixedDataGrid) *DecisionTreeRule
}
// DecisionTreeNode represents a given portion of a decision tree
// DecisionTreeRule represents the "decision" in "decision tree".
type DecisionTreeRule struct {
SplitAttr base.Attribute
SplitVal float64
}
// String returns a human-readable summary of this rule.
func (d *DecisionTreeRule) String() string {
if _, ok := d.SplitAttr.(*base.FloatAttribute); ok {
return fmt.Sprintf("DecisionTreeRule(%s <= %f)", d.SplitAttr.GetName(), d.SplitVal)
}
return fmt.Sprintf("DecisionTreeRule(%s)", d.SplitAttr.GetName())
}
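For a numeric split this renders as, e.g., DecisionTreeRule(Attribute2 <= 82.500000) (%f defaults to six decimal places); for a categorical split it renders as just DecisionTreeRule(outlook).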
// DecisionTreeNode represents a given portion of a decision tree.
type DecisionTreeNode struct {
Type NodeType
Children map[string]*DecisionTreeNode
SplitAttr base.Attribute
ClassDist map[string]int
Class string
ClassAttr base.Attribute
SplitRule *DecisionTreeRule
}
func getClassAttr(from base.FixedDataGrid) base.Attribute {
@ -54,10 +68,10 @@ func InferID3Tree(from base.FixedDataGrid, with RuleGenerator) *DecisionTreeNode
ret := &DecisionTreeNode{
LeafNode,
nil,
nil,
classes,
maxClass,
getClassAttr(from),
&DecisionTreeRule{nil, 0.0},
}
return ret
}
@ -79,10 +93,10 @@ func InferID3Tree(from base.FixedDataGrid, with RuleGenerator) *DecisionTreeNode
ret := &DecisionTreeNode{
LeafNode,
nil,
nil,
classes,
maxClass,
getClassAttr(from),
&DecisionTreeRule{nil, 0.0},
}
return ret
}
@ -91,27 +105,34 @@ func InferID3Tree(from base.FixedDataGrid, with RuleGenerator) *DecisionTreeNode
ret := &DecisionTreeNode{
RuleNode,
nil,
nil,
classes,
maxClass,
getClassAttr(from),
nil,
}
// Generate the splitting attribute
splitOnAttribute := with.GenerateSplitAttribute(from)
if splitOnAttribute == nil {
// Generate the splitting rule
splitRule := with.GenerateSplitRule(from)
if splitRule == nil {
// Can't determine, just return what we have
return ret
}
// Split the attributes based on this attribute's value
splitInstances := base.DecomposeOnAttributeValues(from, splitOnAttribute)
var splitInstances map[string]base.FixedDataGrid
if _, ok := splitRule.SplitAttr.(*base.FloatAttribute); ok {
splitInstances = base.DecomposeOnNumericAttributeThreshold(from,
splitRule.SplitAttr, splitRule.SplitVal)
} else {
splitInstances = base.DecomposeOnAttributeValues(from, splitRule.SplitAttr)
}
// Create new children from these attributes
ret.Children = make(map[string]*DecisionTreeNode)
for k := range splitInstances {
newInstances := splitInstances[k]
ret.Children[k] = InferID3Tree(newInstances, with)
}
ret.SplitAttr = splitOnAttribute
ret.SplitRule = splitRule
return ret
}
@ -127,8 +148,8 @@ func (d *DecisionTreeNode) getNestedString(level int) string {
if d.Children == nil {
buf.WriteString(fmt.Sprintf("Leaf(%s)", d.Class))
} else {
buf.WriteString(fmt.Sprintf("Rule(%s)", d.SplitAttr.GetName()))
keys := make([]string, 0)
var keys []string
buf.WriteString(fmt.Sprintf("Rule(%s)", d.SplitRule))
for k := range d.Children {
keys = append(keys, k)
}
@ -163,12 +184,12 @@ func (d *DecisionTreeNode) Prune(using base.FixedDataGrid) {
if d.Children == nil {
return
}
if d.SplitAttr == nil {
if d.SplitRule == nil {
return
}
// Recursively prune children of this node
sub := base.DecomposeOnAttributeValues(using, d.SplitAttr)
sub := base.DecomposeOnAttributeValues(using, d.SplitRule.SplitAttr)
for k := range d.Children {
if sub[k] == nil {
continue
@ -214,17 +235,32 @@ func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
break
} else {
at := cur.SplitAttr
splitVal := cur.SplitRule.SplitVal
at := cur.SplitRule.SplitAttr
ats, err := what.GetAttribute(at)
if err != nil {
predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
break
//predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
//break
panic(err)
}
classVar := ats.GetAttribute().GetStringFromSysVal(what.Get(ats, rowNo))
var classVar string
if _, ok := ats.GetAttribute().(*base.FloatAttribute); ok {
// If it's a numeric Attribute (e.g. FloatAttribute), check whether
// this row's value exceeds the rule's split threshold
classVal := base.UnpackBytesToFloat(what.Get(ats, rowNo))
if classVal > splitVal {
classVar = "1"
} else {
classVar = "0"
}
} else {
classVar = ats.GetAttribute().GetStringFromSysVal(what.Get(ats, rowNo))
}
if next, ok := cur.Children[classVar]; ok {
cur = next
} else {
// No child matches this value, so fall back to an arbitrary one (suspicious of this)
var bestChild string
for c := range cur.Children {
bestChild = c
@ -252,27 +288,40 @@ type ID3DecisionTree struct {
base.BaseClassifier
Root *DecisionTreeNode
PruneSplit float64
Rule RuleGenerator
}
// NewID3DecisionTree returns a new ID3DecisionTree with the specified test-prune
// ratio. Of the ratio is less than 0.001, the tree isn't pruned
// ratio and InformationGain as the rule generator.
// If the ratio is less than 0.001, the tree isn't pruned.
func NewID3DecisionTree(prune float64) *ID3DecisionTree {
return &ID3DecisionTree{
base.BaseClassifier{},
nil,
prune,
new(InformationGainRuleGenerator),
}
}
// NewID3DecisionTreeFromRule returns a new ID3DecisionTree with the specified test-prune
// ratio and the given rule generator.
func NewID3DecisionTreeFromRule(prune float64, rule RuleGenerator) *ID3DecisionTree {
return &ID3DecisionTree{
base.BaseClassifier{},
nil,
prune,
rule,
}
}
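With this constructor, NewID3DecisionTree(prune) is just the special case NewID3DecisionTreeFromRule(prune, new(InformationGainRuleGenerator)).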
// Fit builds the ID3 decision tree
func (t *ID3DecisionTree) Fit(on base.FixedDataGrid) error {
rule := new(InformationGainRuleGenerator)
if t.PruneSplit > 0.001 {
trainData, testData := base.InstancesTrainTestSplit(on, t.PruneSplit)
t.Root = InferID3Tree(trainData, rule)
t.Root = InferID3Tree(trainData, t.Rule)
t.Root.Prune(testData)
} else {
t.Root = InferID3Tree(on, rule)
t.Root = InferID3Tree(on, t.Rule)
}
return nil
}


@ -12,14 +12,15 @@ type RandomTreeRuleGenerator struct {
internalRule InformationGainRuleGenerator
}
// GenerateSplitAttribute returns the best attribute out of those randomly chosen
// GenerateSplitRule returns the best attribute out of those randomly chosen
// which maximises Information Gain
func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f base.FixedDataGrid) base.Attribute {
func (r *RandomTreeRuleGenerator) GenerateSplitRule(f base.FixedDataGrid) *DecisionTreeRule {
var consideredAttributes []base.Attribute
// First step is to generate the random attributes that we'll consider
allAttributes := base.AttributeDifferenceReferences(f.AllAttributes(), f.AllClassAttributes())
maximumAttribute := len(allAttributes)
consideredAttributes := make([]base.Attribute, 0)
attrCounter := 0
for {
@ -42,7 +43,7 @@ func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f base.FixedDataGrid) b
attrCounter++
}
return r.internalRule.GetSplitAttributeFromSelection(consideredAttributes, f)
return r.internalRule.GetSplitRuleFromSelection(consideredAttributes, f)
}
// RandomTree builds a decision tree by considering a fixed number


@ -4,12 +4,118 @@ import (
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/evaluation"
"github.com/sjwhitworth/golearn/filters"
"testing"
. "github.com/smartystreets/goconvey/convey"
"math/rand"
"testing"
)
func TestRandomTreeClassification(t *testing.T) {
func verifyTreeClassification(trainData, testData base.FixedDataGrid) {
rand.Seed(44414515)
Convey("Using InferID3Tree to create the tree and do the fitting", func() {
Convey("Using a RandomTreeRule", func() {
randomTreeRuleGenerator := new(RandomTreeRuleGenerator)
randomTreeRuleGenerator.Attributes = 2
root := InferID3Tree(trainData, randomTreeRuleGenerator)
Convey("Predicting with the tree", func() {
predictions, err := root.Predict(testData)
So(err, ShouldBeNil)
confusionMatrix, err := evaluation.GetConfusionMatrix(testData, predictions)
So(err, ShouldBeNil)
Convey("Predictions should be somewhat accurate", func() {
So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.5)
})
})
})
Convey("Using a InformationGainRule", func() {
informationGainRuleGenerator := new(InformationGainRuleGenerator)
root := InferID3Tree(trainData, informationGainRuleGenerator)
Convey("Predicting with the tree", func() {
predictions, err := root.Predict(testData)
So(err, ShouldBeNil)
confusionMatrix, err := evaluation.GetConfusionMatrix(testData, predictions)
So(err, ShouldBeNil)
Convey("Predictions should be somewhat accurate", func() {
So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.5)
})
})
})
Convey("Using a GiniCoefficientRuleGenerator", func() {
gRuleGen := new(GiniCoefficientRuleGenerator)
root := InferID3Tree(trainData, gRuleGen)
Convey("Predicting with the tree", func() {
predictions, err := root.Predict(testData)
So(err, ShouldBeNil)
confusionMatrix, err := evaluation.GetConfusionMatrix(testData, predictions)
So(err, ShouldBeNil)
Convey("Predictions should be somewhat accurate", func() {
So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.5)
})
})
})
Convey("Using a InformationGainRatioRuleGenerator", func() {
gRuleGen := new(InformationGainRatioRuleGenerator)
root := InferID3Tree(trainData, gRuleGen)
Convey("Predicting with the tree", func() {
predictions, err := root.Predict(testData)
So(err, ShouldBeNil)
confusionMatrix, err := evaluation.GetConfusionMatrix(testData, predictions)
So(err, ShouldBeNil)
Convey("Predictions should be somewhat accurate", func() {
So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.5)
})
})
})
})
Convey("Using NewRandomTree to create the tree", func() {
root := NewRandomTree(2)
Convey("Fitting with the tree", func() {
err := root.Fit(trainData)
So(err, ShouldBeNil)
Convey("Predicting with the tree, *without* pruning first", func() {
predictions, err := root.Predict(testData)
So(err, ShouldBeNil)
confusionMatrix, err := evaluation.GetConfusionMatrix(testData, predictions)
So(err, ShouldBeNil)
Convey("Predictions should be somewhat accurate", func() {
So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.5)
})
})
Convey("Predicting with the tree, pruning first", func() {
root.Prune(testData)
predictions, err := root.Predict(testData)
So(err, ShouldBeNil)
confusionMatrix, err := evaluation.GetConfusionMatrix(testData, predictions)
So(err, ShouldBeNil)
Convey("Predictions should be somewhat accurate", func() {
So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.4)
})
})
})
})
}
func TestRandomTreeClassificationAfterDiscretisation(t *testing.T) {
Convey("Predictions on filtered data with a Random Tree", t, func() {
instances, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
So(err, ShouldBeNil)
@ -23,78 +129,18 @@ func TestRandomTreeClassification(t *testing.T) {
filter.Train()
filteredTrainData := base.NewLazilyFilteredInstances(trainData, filter)
filteredTestData := base.NewLazilyFilteredInstances(testData, filter)
verifyTreeClassification(filteredTrainData, filteredTestData)
})
}
Convey("Using InferID3Tree to create the tree and do the fitting", func() {
Convey("Using a RandomTreeRule", func() {
randomTreeRuleGenerator := new(RandomTreeRuleGenerator)
randomTreeRuleGenerator.Attributes = 2
root := InferID3Tree(filteredTrainData, randomTreeRuleGenerator)
func TestRandomTreeClassificationWithoutDiscretisation(t *testing.T) {
Convey("Predictions on filtered data with a Random Tree", t, func() {
instances, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
So(err, ShouldBeNil)
Convey("Predicting with the tree", func() {
predictions, err := root.Predict(filteredTestData)
So(err, ShouldBeNil)
trainData, testData := base.InstancesTrainTestSplit(instances, 0.6)
confusionMatrix, err := evaluation.GetConfusionMatrix(filteredTestData, predictions)
So(err, ShouldBeNil)
Convey("Predictions should be somewhat accurate", func() {
So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.5)
})
})
})
Convey("Using a InformationGainRule", func() {
informationGainRuleGenerator := new(InformationGainRuleGenerator)
root := InferID3Tree(filteredTrainData, informationGainRuleGenerator)
Convey("Predicting with the tree", func() {
predictions, err := root.Predict(filteredTestData)
So(err, ShouldBeNil)
confusionMatrix, err := evaluation.GetConfusionMatrix(filteredTestData, predictions)
So(err, ShouldBeNil)
Convey("Predictions should be somewhat accurate", func() {
So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.5)
})
})
})
})
Convey("Using NewRandomTree to create the tree", func() {
root := NewRandomTree(2)
Convey("Fitting with the tree", func() {
err = root.Fit(filteredTrainData)
So(err, ShouldBeNil)
Convey("Predicting with the tree, *without* pruning first", func() {
predictions, err := root.Predict(filteredTestData)
So(err, ShouldBeNil)
confusionMatrix, err := evaluation.GetConfusionMatrix(filteredTestData, predictions)
So(err, ShouldBeNil)
Convey("Predictions should be somewhat accurate", func() {
So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.5)
})
})
Convey("Predicting with the tree, pruning first", func() {
root.Prune(filteredTestData)
predictions, err := root.Predict(filteredTestData)
So(err, ShouldBeNil)
confusionMatrix, err := evaluation.GetConfusionMatrix(filteredTestData, predictions)
So(err, ShouldBeNil)
Convey("Predictions should be somewhat accurate", func() {
So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.4)
})
})
})
})
verifyTreeClassification(trainData, testData)
})
}
@ -136,9 +182,24 @@ func TestID3Inference(t *testing.T) {
})
}
func TestPRIVATEgetNumericAttributeEntropy(t *testing.T) {
Convey("Checking a particular split...", t, func() {
instances, err := base.ParseCSVToInstances("../examples/datasets/c45-numeric.csv", true)
So(err, ShouldBeNil)
Convey("Fetching the right Attribute", func() {
attr := base.GetAttributeByName(instances, "Attribute2")
So(attr, ShouldNotBeNil)
Convey("Finding the threshold...", func() {
_, threshold := getNumericAttributeEntropy(instances, attr.(*base.FloatAttribute))
So(threshold, ShouldAlmostEqual, 82.5)
})
})
})
}
func itBuildsTheCorrectDecisionTree(root *DecisionTreeNode) {
Convey("The root should be 'outlook'", func() {
So(root.SplitAttr.GetName(), ShouldEqual, "outlook")
So(root.SplitRule.SplitAttr.GetName(), ShouldEqual, "outlook")
})
sunny := root.Children["sunny"]
@ -146,13 +207,13 @@ func itBuildsTheCorrectDecisionTree(root *DecisionTreeNode) {
rainy := root.Children["rainy"]
Convey("After the 'sunny' node, the decision should split on 'humidity'", func() {
So(sunny.SplitAttr.GetName(), ShouldEqual, "humidity")
So(sunny.SplitRule.SplitAttr.GetName(), ShouldEqual, "humidity")
})
Convey("After the 'rainy' node, the decision should split on 'windy'", func() {
So(rainy.SplitAttr.GetName(), ShouldEqual, "windy")
So(rainy.SplitRule.SplitAttr.GetName(), ShouldEqual, "windy")
})
Convey("There should be no splits after the 'overcast' node", func() {
So(overcast.SplitAttr, ShouldBeNil)
So(overcast.SplitRule.SplitAttr, ShouldBeNil)
})
highHumidity := sunny.Children["high"]