diff --git a/base/util_instances.go b/base/util_instances.go
index 3239b6c..b11e6a4 100644
--- a/base/util_instances.go
+++ b/base/util_instances.go
@@ -78,6 +78,17 @@ func SetClass(at UpdatableDataGrid, row int, class string) {
     at.Set(classAttrSpec, row, classBytes)
 }
 
+// GetAttributeByName returns an Attribute matching a given name.
+// Returns nil if one doesn't exist.
+func GetAttributeByName(inst FixedDataGrid, name string) Attribute {
+    for _, a := range inst.AllAttributes() {
+        if a.GetName() == name {
+            return a
+        }
+    }
+    return nil
+}
+
 // GetClassDistribution returns a map containing the count of each
 // class type (indexed by the class' string representation).
 func GetClassDistribution(inst FixedDataGrid) map[string]int {
@@ -90,6 +101,40 @@ func GetClassDistribution(inst FixedDataGrid) map[string]int {
     return ret
 }
 
+// GetClassDistributionAfterThreshold returns the class distribution
+// after a speculative split on a given Attribute using a threshold.
+func GetClassDistributionAfterThreshold(inst FixedDataGrid, at Attribute, val float64) map[string]map[string]int {
+    ret := make(map[string]map[string]int)
+
+    // Find the attribute we're decomposing on
+    attrSpec, err := inst.GetAttribute(at)
+    if err != nil {
+        panic(fmt.Sprintf("Invalid attribute %s (%s)", at, err))
+    }
+
+    // Validate
+    if _, ok := at.(*FloatAttribute); !ok {
+        panic("GetClassDistributionAfterThreshold: Attribute must be numeric")
+    }
+
+    _, rows := inst.Size()
+
+    for i := 0; i < rows; i++ {
+        splitVal := UnpackBytesToFloat(inst.Get(attrSpec, i)) > val
+        splitVar := "0"
+        if splitVal {
+            splitVar = "1"
+        }
+        classVar := GetClass(inst, i)
+        if _, ok := ret[splitVar]; !ok {
+            ret[splitVar] = make(map[string]int)
+        }
+        ret[splitVar][classVar]++
+    }
+
+    return ret
+}
+
 // GetClassDistributionAfterSplit returns the class distribution
 // after a speculative split on a given Attribute.
 func GetClassDistributionAfterSplit(inst FixedDataGrid, at Attribute) map[string]map[string]int {
@@ -118,6 +163,64 @@ func GetClassDistributionAfterSplit(inst FixedDataGrid, at Attribute) map[string
     return ret
 }
 
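As a usage note, the two helpers above combine naturally. A minimal sketch — the dataset path and attribute name are illustrative, borrowed from the test data added later in this patch:

package main

import (
    "fmt"

    "github.com/sjwhitworth/golearn/base"
)

func main() {
    // Load the numeric C4.5 example dataset
    inst, err := base.ParseCSVToInstances("examples/datasets/c45-numeric.csv", true)
    if err != nil {
        panic(err)
    }
    // Look the numeric Attribute up by name
    attr := base.GetAttributeByName(inst, "Attribute2")
    // Speculatively split at 82.5: dist["0"] counts classes for rows at
    // or below the threshold, dist["1"] for rows above it
    dist := base.GetClassDistributionAfterThreshold(inst, attr, 82.5)
    fmt.Println(dist)
}
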
+// DecomposeOnNumericAttributeThreshold divides the instance set depending on the
+// value of a given numeric Attribute, constructs child instances, and returns
+// them in a map keyed on whether that row had a higher value than the threshold
+// or not.
+//
+// IMPORTANT: calls panic() if the AttributeSpec of at cannot be determined, or if
+// the Attribute is not numeric.
+func DecomposeOnNumericAttributeThreshold(inst FixedDataGrid, at Attribute, val float64) map[string]FixedDataGrid {
+    // Verify
+    if _, ok := at.(*FloatAttribute); !ok {
+        panic("DecomposeOnNumericAttributeThreshold: Attribute must be numeric")
+    }
+    // Find the Attribute we're decomposing on
+    attrSpec, err := inst.GetAttribute(at)
+    if err != nil {
+        panic(fmt.Sprintf("Invalid Attribute index %s", at))
+    }
+    // Construct the new Attribute set
+    newAttrs := make([]Attribute, 0)
+    for _, a := range inst.AllAttributes() {
+        if a.Equals(at) {
+            continue
+        }
+        newAttrs = append(newAttrs, a)
+    }
+
+    // Create the return map
+    ret := make(map[string]FixedDataGrid)
+
+    // Create the return row mapping
+    rowMaps := make(map[string][]int)
+
+    // Build full Attribute set
+    fullAttrSpec := ResolveAttributes(inst, newAttrs)
+    fullAttrSpec = append(fullAttrSpec, attrSpec)
+
+    // Decompose
+    inst.MapOverRows(fullAttrSpec, func(row [][]byte, rowNo int) (bool, error) {
+        // Find the output instance set
+        targetBytes := row[len(row)-1]
+        targetVal := UnpackBytesToFloat(targetBytes)
+        gt := targetVal > val
+        targetSet := "0"
+        if gt {
+            targetSet = "1"
+        }
+        rowMap := rowMaps[targetSet]
+        rowMaps[targetSet] = append(rowMap, rowNo)
+        return true, nil
+    })
+
+    for a := range rowMaps {
+        ret[a] = NewInstancesViewFromVisible(inst, rowMaps[a], newAttrs)
+    }
+
+    return ret
+}
+
 // DecomposeOnAttributeValues divides the instance set depending on the
 // value of a given Attribute, constructs child instances, and returns
 // them in a map keyed on the string value of that Attribute.
diff --git a/examples/datasets/c45-numeric.csv b/examples/datasets/c45-numeric.csv
new file mode 100644
index 0000000..0164fda
--- /dev/null
+++ b/examples/datasets/c45-numeric.csv
@@ -0,0 +1,15 @@
+Attribute1,Attribute2,Attribute3,Class
+A,70,T,A
+A,90,T,B
+A,85,F,B
+A,95,F,B
+A,70,F,A
+B,90,T,A
+B,78,F,A
+B,65,T,A
+B,75,F,A
+C,80,T,B
+C,70,T,B
+C,80,F,A
+C,80,F,A
+C,96,F,A
diff --git a/examples/datasets/sources.txt b/examples/datasets/sources.txt
new file mode 100644
index 0000000..2518577
--- /dev/null
+++ b/examples/datasets/sources.txt
@@ -0,0 +1,4 @@
+c45-numeric.csv: www.mgt.ncu.edu.tw/~wabble/School/C45.ppt
+tennis.csv: "Machine Learning", Tom Mitchell, McGraw-Hill, 1997 (http://books.google.co.uk/books?id=xOGAngEACAAJ&dq=machine+learning,+mitchell&hl=en&sa=X&ei=zvpMVPz8IseN7Aa454DYBg&ved=0CFYQ6AEwBw)
+
+
diff --git a/examples/trees/trees.go b/examples/trees/trees.go
index 1c8c363..d8bcae2 100644
--- a/examples/trees/trees.go
+++ b/examples/trees/trees.go
@@ -10,14 +10,13 @@ import (
     "github.com/sjwhitworth/golearn/filters"
     "github.com/sjwhitworth/golearn/trees"
     "math/rand"
-    "time"
 )
 
 func main() {
 
     var tree base.Classifier
 
-    rand.Seed(time.Now().UTC().UnixNano())
+    rand.Seed(44111342)
 
     // Load in the iris dataset
     iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
@@ -26,7 +25,7 @@ func main() {
     }
 
     // Discretise the iris dataset with Chi-Merge
-    filt := filters.NewChiMergeFilter(iris, 0.99)
+    filt := filters.NewChiMergeFilter(iris, 0.999)
     for _, a := range base.NonClassFloatAttributes(iris) {
         filt.AddAttribute(a)
     }
@@ -55,13 +54,58 @@ func main() {
     }
 
     // Evaluate
-    fmt.Println("ID3 Performance")
+    fmt.Println("ID3 Performance (information gain)")
     cf, err := evaluation.GetConfusionMatrix(testData, predictions)
     if err != nil {
         panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
     }
     fmt.Println(evaluation.GetSummary(cf))
 
+    // Next, the same tree using the information gain ratio
+    // (the first parameter controls the train-prune split)
+    tree = trees.NewID3DecisionTreeFromRule(0.6, new(trees.InformationGainRatioRuleGenerator))
+
+    // Train the ID3 tree
+    err = tree.Fit(trainData)
+    if err != nil {
+        panic(err)
+    }
+
+    // Generate predictions
+    predictions, err = tree.Predict(testData)
+    if err != nil {
+        panic(err)
+    }
+
+    // Evaluate
+    fmt.Println("ID3 Performance (information gain ratio)")
+    cf, err = evaluation.GetConfusionMatrix(testData, predictions)
+    if err != nil {
+        panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
+    }
+    fmt.Println(evaluation.GetSummary(cf))
+
+    // And again with the Gini index; same train-prune split parameter
+    tree = trees.NewID3DecisionTreeFromRule(0.6, new(trees.GiniCoefficientRuleGenerator))
+
+    // Train the ID3 tree
+    err = tree.Fit(trainData)
+    if err != nil {
+        panic(err)
+    }
+
+    // Generate predictions
+    predictions, err = tree.Predict(testData)
+    if err != nil {
+        panic(err)
+    }
+
+    // Evaluate
+    fmt.Println("ID3 Performance (Gini index)")
+    cf, err = evaluation.GetConfusionMatrix(testData, predictions)
+    if err != nil {
+        panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
+    }
+    fmt.Println(evaluation.GetSummary(cf))
 
 //
 // Next up, Random Trees
 //
@@ -86,7 +130,7 @@ func main() {
 //
 // Finally, Random Forests
 //
-    tree = ensemble.NewRandomForest(100, 3)
+    tree = ensemble.NewRandomForest(70, 3)
     err = tree.Fit(trainData)
     if err != nil {
         panic(err)
diff --git a/trees/entropy.go b/trees/entropy.go
index 9339fcd..4327163 100644
--- a/trees/entropy.go
+++ b/trees/entropy.go
@@ -3,34 +3,37 @@ package trees
 
 import (
     "github.com/sjwhitworth/golearn/base"
     "math"
+    "sort"
 )
 
 //
 // Information gain rule generator
 //
 
+// InformationGainRuleGenerator generates DecisionTreeRules which
+// maximize information gain at each node.
 type InformationGainRuleGenerator struct {
 }
 
-// GenerateSplitAttribute returns the non-class Attribute which maximises the
-// information gain.
+// GenerateSplitRule returns a DecisionTreeRule based on a non-class Attribute
+// which maximises the information gain.
 //
 // IMPORTANT: passing a base.Instances with no Attributes other than the class
 // variable will panic()
-func (r *InformationGainRuleGenerator) GenerateSplitAttribute(f base.FixedDataGrid) base.Attribute {
+func (r *InformationGainRuleGenerator) GenerateSplitRule(f base.FixedDataGrid) *DecisionTreeRule {
 
     attrs := f.AllAttributes()
     classAttrs := f.AllClassAttributes()
     candidates := base.AttributeDifferenceReferences(attrs, classAttrs)
 
-    return r.GetSplitAttributeFromSelection(candidates, f)
+    return r.GetSplitRuleFromSelection(candidates, f)
 }
 
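Any type satisfying the RuleGenerator interface can now be passed to InferID3Tree, and the *DecisionTreeRule return type is what lets numeric and categorical splits share that interface. A minimal, purely hypothetical generator (not part of this patch) that always splits on the first candidate Attribute might look like this inside package trees; getNumericAttributeEntropy is defined later in this file:

// firstAttributeRuleGenerator is illustrative only.
type firstAttributeRuleGenerator struct{}

func (g *firstAttributeRuleGenerator) GenerateSplitRule(f base.FixedDataGrid) *DecisionTreeRule {
    candidates := base.AttributeDifferenceReferences(f.AllAttributes(), f.AllClassAttributes())
    attr := candidates[0]
    if fAttr, ok := attr.(*base.FloatAttribute); ok {
        // Numeric Attributes need a threshold as well as a name
        _, split := getNumericAttributeEntropy(f, fAttr)
        return &DecisionTreeRule{attr, split}
    }
    return &DecisionTreeRule{attr, 0.0}
}
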
-// GetSplitAttributeFromSelection returns the class Attribute which maximises
-// the information gain amongst consideredAttributes
+// GetSplitRuleFromSelection returns a DecisionTreeRule which maximises
+// the information gain amongst the considered Attributes.
 //
 // IMPORTANT: passing a zero-length consideredAttributes parameter will panic()
-func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(consideredAttributes []base.Attribute, f base.FixedDataGrid) base.Attribute {
+func (r *InformationGainRuleGenerator) GetSplitRuleFromSelection(consideredAttributes []base.Attribute, f base.FixedDataGrid) *DecisionTreeRule {
 
     var selectedAttribute base.Attribute
@@ -43,6 +46,7 @@ func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(considered
     // for each randomly chosen attribute, and pick the one
     // which maximises it
     maxGain := math.Inf(-1)
+    selectedVal := math.Inf(1)
 
     // Compute the base entropy
     classDist := base.GetClassDistribution(f)
@@ -50,23 +54,98 @@ func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(considered
     // Compute the information gain for each attribute
     for _, s := range consideredAttributes {
-        proposedClassDist := base.GetClassDistributionAfterSplit(f, s)
-        localEntropy := getSplitEntropy(proposedClassDist)
-        informationGain := baseEntropy - localEntropy
+        var informationGain float64
+        var splitVal float64
+        if fAttr, ok := s.(*base.FloatAttribute); ok {
+            var attributeEntropy float64
+            attributeEntropy, splitVal = getNumericAttributeEntropy(f, fAttr)
+            informationGain = baseEntropy - attributeEntropy
+        } else {
+            proposedClassDist := base.GetClassDistributionAfterSplit(f, s)
+            localEntropy := getSplitEntropy(proposedClassDist)
+            informationGain = baseEntropy - localEntropy
+        }
+
         if informationGain > maxGain {
             maxGain = informationGain
             selectedAttribute = s
+            selectedVal = splitVal
         }
     }
 
     // Pick the one which maximises IG
-    return selectedAttribute
+    return &DecisionTreeRule{selectedAttribute, selectedVal}
 }
 
 //
 // Entropy functions
 //
 
+type numericSplitRef struct {
+    val   float64
+    class string
+}
+
+type splitVec []numericSplitRef
+
+func (a splitVec) Len() int           { return len(a) }
+func (a splitVec) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
+func (a splitVec) Less(i, j int) bool { return a[i].val < a[j].val }
+
+func getNumericAttributeEntropy(f base.FixedDataGrid, attr *base.FloatAttribute) (float64, float64) {
+
+    // Resolve Attribute
+    attrSpec, err := f.GetAttribute(attr)
+    if err != nil {
+        panic(err)
+    }
+
+    // Build sortable vector
+    _, rows := f.Size()
+    refs := make([]numericSplitRef, rows)
+    f.MapOverRows([]base.AttributeSpec{attrSpec}, func(val [][]byte, row int) (bool, error) {
+        cls := base.GetClass(f, row)
+        v := base.UnpackBytesToFloat(val[0])
+        refs[row] = numericSplitRef{v, cls}
+        return true, nil
+    })
+
+    // Sort
+    sort.Sort(splitVec(refs))
+
+    generateCandidateSplitDistribution := func(val float64) map[string]map[string]int {
+        presplit := make(map[string]int)
+        postsplit := make(map[string]int)
+        for _, i := range refs {
+            if i.val < val {
+                presplit[i.class]++
+            } else {
+                postsplit[i.class]++
+            }
+        }
+        ret := make(map[string]map[string]int)
+        ret["0"] = presplit
+        ret["1"] = postsplit
+        return ret
+    }
+
+    minSplitEntropy := math.Inf(1)
+    minSplitVal := math.Inf(1)
+    // Consider the midpoint between each consecutive pair of sorted values
+    for i := 0; i < len(refs)-1; i++ {
+        val := refs[i].val + refs[i+1].val
+        val /= 2
+        splitDist := generateCandidateSplitDistribution(val)
+        splitEntropy := getSplitEntropy(splitDist)
+        if splitEntropy < minSplitEntropy {
+            minSplitEntropy = splitEntropy
+            minSplitVal = val
+        }
+    }
+
+    return minSplitEntropy, minSplitVal
+}
+
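To make the candidate-generation scheme concrete: getNumericAttributeEntropy only ever evaluates the midpoints between consecutive sorted values, so n values yield at most n-1 candidate thresholds. A standalone sketch with hypothetical values:

package main

import "fmt"

func main() {
    // Already-sorted attribute values; candidate thresholds are the
    // midpoints between each consecutive pair
    vals := []float64{65, 70, 80, 85}
    for i := 0; i < len(vals)-1; i++ {
        fmt.Println((vals[i] + vals[i+1]) / 2) // 67.5, 75, 82.5
    }
}
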
 // getSplitEntropy determines the entropy of the target
 // class distribution after splitting on an base.Attribute
 func getSplitEntropy(s map[string]map[string]int) float64 {
diff --git a/trees/gini.go b/trees/gini.go
new file mode 100644
index 0000000..23b1569
--- /dev/null
+++ b/trees/gini.go
@@ -0,0 +1,113 @@
+package trees
+
+import (
+    "github.com/sjwhitworth/golearn/base"
+    "math"
+)
+
+//
+// Gini-coefficient rule generator
+//
+
+// GiniCoefficientRuleGenerator generates DecisionTreeRules which minimize
+// the Gini impurity at each node.
+type GiniCoefficientRuleGenerator struct {
+}
+
+// GenerateSplitRule returns the non-class Attribute-based DecisionTreeRule
+// which minimises the average Gini index after the split.
+//
+// IMPORTANT: passing a base.Instances with no Attributes other than the class
+// variable will panic()
+func (g *GiniCoefficientRuleGenerator) GenerateSplitRule(f base.FixedDataGrid) *DecisionTreeRule {
+
+    attrs := f.AllAttributes()
+    classAttrs := f.AllClassAttributes()
+    candidates := base.AttributeDifferenceReferences(attrs, classAttrs)
+
+    return g.GetSplitRuleFromSelection(candidates, f)
+}
+
+// GetSplitRuleFromSelection returns the DecisionTreeRule which minimises
+// the average Gini index amongst consideredAttributes
+//
+// IMPORTANT: passing a zero-length consideredAttributes parameter will panic()
+func (g *GiniCoefficientRuleGenerator) GetSplitRuleFromSelection(consideredAttributes []base.Attribute, f base.FixedDataGrid) *DecisionTreeRule {
+
+    var selectedAttribute base.Attribute
+    var selectedVal float64
+
+    // Parameter check
+    if len(consideredAttributes) == 0 {
+        panic("More Attributes should be considered")
+    }
+
+    // Minimise the average Gini index
+    minGini := math.Inf(1)
+    for _, s := range consideredAttributes {
+        var proposedDist map[string]map[string]int
+        var splitVal float64
+        if fAttr, ok := s.(*base.FloatAttribute); ok {
+            _, splitVal = getNumericAttributeEntropy(f, fAttr)
+            proposedDist = base.GetClassDistributionAfterThreshold(f, fAttr, splitVal)
+        } else {
+            proposedDist = base.GetClassDistributionAfterSplit(f, s)
+        }
+        avgGini := computeAverageGiniIndex(proposedDist)
+        if avgGini < minGini {
+            minGini = avgGini
+            selectedAttribute = s
+            selectedVal = splitVal
+        }
+    }
+
+    return &DecisionTreeRule{selectedAttribute, selectedVal}
+}
+
+//
+// Utility functions
+//
+
+// computeGini computes the Gini impurity of a class distribution
+func computeGini(s map[string]int) float64 {
+    // Count the total number of instances
+    total := 0
+    for i := range s {
+        total += s[i]
+    }
+    if total == 0 {
+        return 0.0
+    }
+    // Sum the squared class probabilities
+    sum := 0.0
+    for i := range s {
+        p := float64(s[i]) / float64(total)
+        sum += p * p
+    }
+    return 1.0 - sum
+}
+
+// computeAverageGiniIndex computes the average Gini index of a
+// proposed split
+func computeAverageGiniIndex(s map[string]map[string]int) float64 {
+
+    // Figure out the total number of things in this map
+    total := 0
+    for i := range s {
+        for j := range s[i] {
+            total += s[i][j]
+        }
+    }
+
+    sum := 0.0
+    for i := range s {
+        subtotal := 0.0
+        for j := range s[i] {
+            subtotal += float64(s[i][j])
+        }
+        cf := subtotal / float64(total)
+        cf *= computeGini(s[i])
+        sum += cf
+    }
+    return sum
+}
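To sanity-check the arithmetic: a branch with class counts {A: 4, B: 1} has Gini impurity 1 - (0.8² + 0.2²) = 0.32, and a split into two such five-row branches averages to 0.32. A self-contained sketch (the counts are hypothetical):

package main

import "fmt"

func main() {
    // A hypothetical proposed split with two branches
    split := map[string]map[string]int{
        "0": {"A": 4, "B": 1},
        "1": {"A": 1, "B": 4},
    }
    gini := func(s map[string]int) float64 {
        total, sum := 0, 0.0
        for _, c := range s {
            total += c
        }
        for _, c := range s {
            p := float64(c) / float64(total)
            sum += p * p
        }
        return 1.0 - sum
    }
    // Weight each branch's impurity by its share of the rows
    grand := 0.0
    sizes := make(map[string]float64)
    for k, branch := range split {
        for _, c := range branch {
            sizes[k] += float64(c)
        }
        grand += sizes[k]
    }
    avg := 0.0
    for k, branch := range split {
        avg += (sizes[k] / grand) * gini(branch)
    }
    fmt.Println(avg) // ~0.32
}
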
diff --git a/trees/gr.go b/trees/gr.go
new file mode 100644
index 0000000..46577f5
--- /dev/null
+++ b/trees/gr.go
@@ -0,0 +1,76 @@
+package trees
+
+import (
+    "github.com/sjwhitworth/golearn/base"
+    "math"
+)
+
+//
+// Information Gain Ratio generator
+//
+
+// InformationGainRatioRuleGenerator generates DecisionTreeRules which
+// maximise the information gain ratio at each node.
+type InformationGainRatioRuleGenerator struct {
+}
+
+// GenerateSplitRule returns a DecisionTreeRule which maximises information
+// gain ratio considering every available Attribute.
+//
+// IMPORTANT: passing a base.Instances with no Attributes other than the class
+// variable will panic()
+func (r *InformationGainRatioRuleGenerator) GenerateSplitRule(f base.FixedDataGrid) *DecisionTreeRule {
+
+    attrs := f.AllAttributes()
+    classAttrs := f.AllClassAttributes()
+    candidates := base.AttributeDifferenceReferences(attrs, classAttrs)
+
+    return r.GetSplitRuleFromSelection(candidates, f)
+}
+
+// GetSplitRuleFromSelection returns the DecisionTreeRule which maximises the
+// information gain ratio, considering only a subset of Attributes.
+//
+// IMPORTANT: passing a zero-length consideredAttributes parameter will panic()
+func (r *InformationGainRatioRuleGenerator) GetSplitRuleFromSelection(consideredAttributes []base.Attribute, f base.FixedDataGrid) *DecisionTreeRule {
+
+    var selectedAttribute base.Attribute
+    var selectedVal float64
+
+    // Parameter check
+    if len(consideredAttributes) == 0 {
+        panic("More Attributes should be considered")
+    }
+
+    // Next step is to compute the information gain ratio at this
+    // node for each candidate attribute, and pick the one
+    // which maximises it
+    maxRatio := math.Inf(-1)
+
+    // Compute the base entropy
+    classDist := base.GetClassDistribution(f)
+    baseEntropy := getBaseEntropy(classDist)
+
+    // Compute the information gain ratio for each attribute
+    for _, s := range consideredAttributes {
+        var informationGain float64
+        var localEntropy float64
+        var splitVal float64
+        if fAttr, ok := s.(*base.FloatAttribute); ok {
+            localEntropy, splitVal = getNumericAttributeEntropy(f, fAttr)
+        } else {
+            proposedClassDist := base.GetClassDistributionAfterSplit(f, s)
+            localEntropy = getSplitEntropy(proposedClassDist)
+        }
+        informationGain = baseEntropy - localEntropy
+        informationGainRatio := informationGain / localEntropy
+        if informationGainRatio > maxRatio {
+            maxRatio = informationGainRatio
+            selectedAttribute = s
+            selectedVal = splitVal
+        }
+    }
+
+    // Pick the one which maximises the gain ratio
+    return &DecisionTreeRule{selectedAttribute, selectedVal}
+}
diff --git a/trees/id3.go b/trees/id3.go
index 55a1f4d..26efc19 100644
--- a/trees/id3.go
+++ b/trees/id3.go
@@ -8,7 +8,7 @@ import (
     "sort"
 )
 
-// NodeType determines whether a DecisionTreeNode is a leaf or not
+// NodeType determines whether a DecisionTreeNode is a leaf or not.
 type NodeType int
 
 const (
@@ -19,19 +19,33 @@ const (
 )
 
 // RuleGenerator implementations analyse instances and determine
-// the best value to split on
+// the best value to split on.
 type RuleGenerator interface {
-    GenerateSplitAttribute(base.FixedDataGrid) base.Attribute
+    GenerateSplitRule(base.FixedDataGrid) *DecisionTreeRule
 }
 
-// DecisionTreeNode represents a given portion of a decision tree
+// DecisionTreeRule represents the "decision" in "decision tree".
+type DecisionTreeRule struct {
+    SplitAttr base.Attribute
+    SplitVal  float64
+}
+
+// String returns a human-readable summary of this rule.
+func (d *DecisionTreeRule) String() string {
+    if _, ok := d.SplitAttr.(*base.FloatAttribute); ok {
+        return fmt.Sprintf("DecisionTreeRule(%s <= %f)", d.SplitAttr.GetName(), d.SplitVal)
+    }
+    return fmt.Sprintf("DecisionTreeRule(%s)", d.SplitAttr.GetName())
+}
+
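For a feel of the new type, a short sketch; the dataset path and attribute name are illustrative, and the CSV parser is expected to type "Attribute2" as a FloatAttribute, which makes String print the threshold form:

package main

import (
    "fmt"

    "github.com/sjwhitworth/golearn/base"
    "github.com/sjwhitworth/golearn/trees"
)

func main() {
    inst, err := base.ParseCSVToInstances("examples/datasets/c45-numeric.csv", true)
    if err != nil {
        panic(err)
    }
    // Build a rule by hand around a numeric Attribute
    attr := base.GetAttributeByName(inst, "Attribute2")
    rule := &trees.DecisionTreeRule{SplitAttr: attr, SplitVal: 82.5}
    fmt.Println(rule) // DecisionTreeRule(Attribute2 <= 82.500000)
}
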
+// DecisionTreeNode represents a given portion of a decision tree.
 type DecisionTreeNode struct {
     Type      NodeType
     Children  map[string]*DecisionTreeNode
-    SplitAttr base.Attribute
     ClassDist map[string]int
     Class     string
     ClassAttr base.Attribute
+    SplitRule *DecisionTreeRule
 }
 
 func getClassAttr(from base.FixedDataGrid) base.Attribute {
@@ -54,10 +68,10 @@ func InferID3Tree(from base.FixedDataGrid, with RuleGenerator) *DecisionTreeNode
         ret := &DecisionTreeNode{
             LeafNode,
             nil,
-            nil,
             classes,
             maxClass,
             getClassAttr(from),
+            &DecisionTreeRule{nil, 0.0},
         }
         return ret
     }
@@ -79,10 +93,10 @@ func InferID3Tree(from base.FixedDataGrid, with RuleGenerator) *DecisionTreeNode
         ret := &DecisionTreeNode{
             LeafNode,
             nil,
-            nil,
             classes,
             maxClass,
             getClassAttr(from),
+            &DecisionTreeRule{nil, 0.0},
         }
         return ret
     }
@@ -91,27 +105,34 @@ func InferID3Tree(from base.FixedDataGrid, with RuleGenerator) *DecisionTreeNode
     ret := &DecisionTreeNode{
         RuleNode,
         nil,
-        nil,
         classes,
         maxClass,
         getClassAttr(from),
+        nil,
     }
 
-    // Generate the splitting attribute
-    splitOnAttribute := with.GenerateSplitAttribute(from)
-    if splitOnAttribute == nil {
+    // Generate the splitting rule
+    splitRule := with.GenerateSplitRule(from)
+    if splitRule == nil {
         // Can't determine, just return what we have
         return ret
     }
+
     // Split the attributes based on this attribute's value
-    splitInstances := base.DecomposeOnAttributeValues(from, splitOnAttribute)
+    var splitInstances map[string]base.FixedDataGrid
+    if _, ok := splitRule.SplitAttr.(*base.FloatAttribute); ok {
+        splitInstances = base.DecomposeOnNumericAttributeThreshold(from,
+            splitRule.SplitAttr, splitRule.SplitVal)
+    } else {
+        splitInstances = base.DecomposeOnAttributeValues(from, splitRule.SplitAttr)
+    }
     // Create new children from these attributes
     ret.Children = make(map[string]*DecisionTreeNode)
     for k := range splitInstances {
         newInstances := splitInstances[k]
         ret.Children[k] = InferID3Tree(newInstances, with)
     }
-    ret.SplitAttr = splitOnAttribute
+    ret.SplitRule = splitRule
     return ret
 }
 
@@ -127,8 +148,8 @@ func (d *DecisionTreeNode) getNestedString(level int) string {
     if d.Children == nil {
         buf.WriteString(fmt.Sprintf("Leaf(%s)", d.Class))
     } else {
-        buf.WriteString(fmt.Sprintf("Rule(%s)", d.SplitAttr.GetName()))
-        keys := make([]string, 0)
+        var keys []string
+        buf.WriteString(fmt.Sprintf("Rule(%s)", d.SplitRule))
         for k := range d.Children {
             keys = append(keys, k)
         }
@@ -163,12 +184,19 @@ func (d *DecisionTreeNode) Prune(using base.FixedDataGrid) {
     if d.Children == nil {
         return
     }
-    if d.SplitAttr == nil {
+    if d.SplitRule == nil {
         return
     }
 
     // Recursively prune children of this node
-    sub := base.DecomposeOnAttributeValues(using, d.SplitAttr)
+    // (numeric splits produce children keyed "0"/"1", so the pruning
+    // set must be decomposed with the same threshold)
+    var sub map[string]base.FixedDataGrid
+    if _, ok := d.SplitRule.SplitAttr.(*base.FloatAttribute); ok {
+        sub = base.DecomposeOnNumericAttributeThreshold(using, d.SplitRule.SplitAttr, d.SplitRule.SplitVal)
+    } else {
+        sub = base.DecomposeOnAttributeValues(using, d.SplitRule.SplitAttr)
+    }
     for k := range d.Children {
         if sub[k] == nil {
             continue
         }
@@ -214,17 +235,31 @@ func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
             predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
             break
         } else {
-            at := cur.SplitAttr
+            splitVal := cur.SplitRule.SplitVal
+            at := cur.SplitRule.SplitAttr
             ats, err := what.GetAttribute(at)
             if err != nil {
-                predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
-                break
+                panic(err)
             }
-            classVar := ats.GetAttribute().GetStringFromSysVal(what.Get(ats, rowNo))
+            var classVar string
+            if _, ok := ats.GetAttribute().(*base.FloatAttribute); ok {
+                // Numeric Attribute: send rows above the rule's
+                // threshold to the "1" branch, the rest to "0"
+                classVal := base.UnpackBytesToFloat(what.Get(ats, rowNo))
+                if classVal > splitVal {
+                    classVar = "1"
+                } else {
+                    classVar = "0"
+                }
+            } else {
+                classVar = ats.GetAttribute().GetStringFromSysVal(what.Get(ats, rowNo))
+            }
             if next, ok := cur.Children[classVar]; ok {
                 cur = next
             } else {
+                // No child matches this value; fall back to an
+                // arbitrary child
                 var bestChild string
                 for c := range cur.Children {
                     bestChild = c
@@ -252,27 +288,40 @@ type ID3DecisionTree struct {
     base.BaseClassifier
     Root       *DecisionTreeNode
     PruneSplit float64
+    Rule       RuleGenerator
 }
 
 // NewID3DecisionTree returns a new ID3DecisionTree with the specified test-prune
-// ratio. Of the ratio is less than 0.001, the tree isn't pruned
+// ratio and InformationGain as the rule generator.
+// If the ratio is less than 0.001, the tree isn't pruned.
 func NewID3DecisionTree(prune float64) *ID3DecisionTree {
     return &ID3DecisionTree{
         base.BaseClassifier{},
         nil,
         prune,
+        new(InformationGainRuleGenerator),
+    }
+}
+
+// NewID3DecisionTreeFromRule returns a new ID3DecisionTree with the specified
+// test-prune ratio and the given rule generator.
+func NewID3DecisionTreeFromRule(prune float64, rule RuleGenerator) *ID3DecisionTree {
+    return &ID3DecisionTree{
+        base.BaseClassifier{},
+        nil,
+        prune,
+        rule,
     }
 }
 
 // Fit builds the ID3 decision tree
 func (t *ID3DecisionTree) Fit(on base.FixedDataGrid) error {
-    rule := new(InformationGainRuleGenerator)
     if t.PruneSplit > 0.001 {
         trainData, testData := base.InstancesTrainTestSplit(on, t.PruneSplit)
-        t.Root = InferID3Tree(trainData, rule)
+        t.Root = InferID3Tree(trainData, t.Rule)
         t.Root.Prune(testData)
     } else {
-        t.Root = InferID3Tree(on, rule)
+        t.Root = InferID3Tree(on, t.Rule)
     }
     return nil
 }
diff --git a/trees/random.go b/trees/random.go
index 891d50c..0c32a9e 100644
--- a/trees/random.go
+++ b/trees/random.go
@@ -12,14 +12,15 @@ type RandomTreeRuleGenerator struct {
     internalRule InformationGainRuleGenerator
 }
 
-// GenerateSplitAttribute returns the best attribute out of those randomly chosen
-// which maximises Information Gain
+// GenerateSplitRule returns a DecisionTreeRule based on the best attribute
+// out of those randomly chosen, i.e. the one which maximises Information Gain.
+func (r *RandomTreeRuleGenerator) GenerateSplitRule(f base.FixedDataGrid) *DecisionTreeRule {
-func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f base.FixedDataGrid) base.Attribute {
+
+    var consideredAttributes []base.Attribute
 
     // First step is to generate the random attributes that we'll consider
     allAttributes := base.AttributeDifferenceReferences(f.AllAttributes(), f.AllClassAttributes())
     maximumAttribute := len(allAttributes)
-    consideredAttributes := make([]base.Attribute, 0)
 
     attrCounter := 0
     for {
@@ -42,7 +43,7 @@ func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f base.FixedDataGrid) b
         attrCounter++
     }
 
-    return r.internalRule.GetSplitAttributeFromSelection(consideredAttributes, f)
+    return r.internalRule.GetSplitRuleFromSelection(consideredAttributes, f)
 }
 
 // RandomTree builds a decision tree by considering a fixed number
diff --git a/trees/tree_test.go b/trees/tree_test.go
index d855fb2..06b10b1 100644
--- a/trees/tree_test.go
+++ b/trees/tree_test.go
@@ -4,12 +4,118 @@ import (
     "github.com/sjwhitworth/golearn/base"
     "github.com/sjwhitworth/golearn/evaluation"
     "github.com/sjwhitworth/golearn/filters"
-    "testing"
-
"github.com/smartystreets/goconvey/convey" + "math/rand" + "testing" ) -func TestRandomTreeClassification(t *testing.T) { +func verifyTreeClassification(trainData, testData base.FixedDataGrid) { + rand.Seed(44414515) + Convey("Using InferID3Tree to create the tree and do the fitting", func() { + Convey("Using a RandomTreeRule", func() { + randomTreeRuleGenerator := new(RandomTreeRuleGenerator) + randomTreeRuleGenerator.Attributes = 2 + root := InferID3Tree(trainData, randomTreeRuleGenerator) + + Convey("Predicting with the tree", func() { + predictions, err := root.Predict(testData) + So(err, ShouldBeNil) + + confusionMatrix, err := evaluation.GetConfusionMatrix(testData, predictions) + So(err, ShouldBeNil) + + Convey("Predictions should be somewhat accurate", func() { + So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.5) + }) + }) + }) + + Convey("Using a InformationGainRule", func() { + informationGainRuleGenerator := new(InformationGainRuleGenerator) + root := InferID3Tree(trainData, informationGainRuleGenerator) + + Convey("Predicting with the tree", func() { + predictions, err := root.Predict(testData) + So(err, ShouldBeNil) + + confusionMatrix, err := evaluation.GetConfusionMatrix(testData, predictions) + So(err, ShouldBeNil) + + Convey("Predictions should be somewhat accurate", func() { + So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.5) + }) + }) + }) + Convey("Using a GiniCoefficientRuleGenerator", func() { + gRuleGen := new(GiniCoefficientRuleGenerator) + root := InferID3Tree(trainData, gRuleGen) + Convey("Predicting with the tree", func() { + predictions, err := root.Predict(testData) + So(err, ShouldBeNil) + + confusionMatrix, err := evaluation.GetConfusionMatrix(testData, predictions) + So(err, ShouldBeNil) + + Convey("Predictions should be somewhat accurate", func() { + So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.5) + }) + }) + }) + Convey("Using a InformationGainRatioRuleGenerator", func() { + gRuleGen := new(InformationGainRatioRuleGenerator) + root := InferID3Tree(trainData, gRuleGen) + Convey("Predicting with the tree", func() { + predictions, err := root.Predict(testData) + So(err, ShouldBeNil) + + confusionMatrix, err := evaluation.GetConfusionMatrix(testData, predictions) + So(err, ShouldBeNil) + + Convey("Predictions should be somewhat accurate", func() { + So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.5) + }) + }) + }) + + }) + + Convey("Using NewRandomTree to create the tree", func() { + root := NewRandomTree(2) + + Convey("Fitting with the tree", func() { + err := root.Fit(trainData) + So(err, ShouldBeNil) + + Convey("Predicting with the tree, *without* pruning first", func() { + predictions, err := root.Predict(testData) + So(err, ShouldBeNil) + + confusionMatrix, err := evaluation.GetConfusionMatrix(testData, predictions) + So(err, ShouldBeNil) + + Convey("Predictions should be somewhat accurate", func() { + So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.5) + }) + }) + + Convey("Predicting with the tree, pruning first", func() { + root.Prune(testData) + + predictions, err := root.Predict(testData) + So(err, ShouldBeNil) + + confusionMatrix, err := evaluation.GetConfusionMatrix(testData, predictions) + So(err, ShouldBeNil) + + Convey("Predictions should be somewhat accurate", func() { + So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.4) + }) + }) + }) + }) +} + +func TestRandomTreeClassificationAfterDiscretisation(t *testing.T) { 
Convey("Predictions on filtered data with a Random Tree", t, func() { instances, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) So(err, ShouldBeNil) @@ -23,78 +129,18 @@ func TestRandomTreeClassification(t *testing.T) { filter.Train() filteredTrainData := base.NewLazilyFilteredInstances(trainData, filter) filteredTestData := base.NewLazilyFilteredInstances(testData, filter) + verifyTreeClassification(filteredTrainData, filteredTestData) + }) +} - Convey("Using InferID3Tree to create the tree and do the fitting", func() { - Convey("Using a RandomTreeRule", func() { - randomTreeRuleGenerator := new(RandomTreeRuleGenerator) - randomTreeRuleGenerator.Attributes = 2 - root := InferID3Tree(filteredTrainData, randomTreeRuleGenerator) +func TestRandomTreeClassificationWithoutDiscretisation(t *testing.T) { + Convey("Predictions on filtered data with a Random Tree", t, func() { + instances, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) + So(err, ShouldBeNil) - Convey("Predicting with the tree", func() { - predictions, err := root.Predict(filteredTestData) - So(err, ShouldBeNil) + trainData, testData := base.InstancesTrainTestSplit(instances, 0.6) - confusionMatrix, err := evaluation.GetConfusionMatrix(filteredTestData, predictions) - So(err, ShouldBeNil) - - Convey("Predictions should be somewhat accurate", func() { - So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.5) - }) - }) - }) - - Convey("Using a InformationGainRule", func() { - informationGainRuleGenerator := new(InformationGainRuleGenerator) - root := InferID3Tree(filteredTrainData, informationGainRuleGenerator) - - Convey("Predicting with the tree", func() { - predictions, err := root.Predict(filteredTestData) - So(err, ShouldBeNil) - - confusionMatrix, err := evaluation.GetConfusionMatrix(filteredTestData, predictions) - So(err, ShouldBeNil) - - Convey("Predictions should be somewhat accurate", func() { - So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.5) - }) - }) - }) - }) - - Convey("Using NewRandomTree to create the tree", func() { - root := NewRandomTree(2) - - Convey("Fitting with the tree", func() { - err = root.Fit(filteredTrainData) - So(err, ShouldBeNil) - - Convey("Predicting with the tree, *without* pruning first", func() { - predictions, err := root.Predict(filteredTestData) - So(err, ShouldBeNil) - - confusionMatrix, err := evaluation.GetConfusionMatrix(filteredTestData, predictions) - So(err, ShouldBeNil) - - Convey("Predictions should be somewhat accurate", func() { - So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.5) - }) - }) - - Convey("Predicting with the tree, pruning first", func() { - root.Prune(filteredTestData) - - predictions, err := root.Predict(filteredTestData) - So(err, ShouldBeNil) - - confusionMatrix, err := evaluation.GetConfusionMatrix(filteredTestData, predictions) - So(err, ShouldBeNil) - - Convey("Predictions should be somewhat accurate", func() { - So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.4) - }) - }) - }) - }) + verifyTreeClassification(trainData, testData) }) } @@ -136,9 +182,24 @@ func TestID3Inference(t *testing.T) { }) } +func TestPRIVATEgetNumericAttributeEntropy(t *testing.T) { + Convey("Checking a particular split...", t, func() { + instances, err := base.ParseCSVToInstances("../examples/datasets/c45-numeric.csv", true) + So(err, ShouldBeNil) + Convey("Fetching the right Attribute", func() { + attr := base.GetAttributeByName(instances, "Attribute2") 
 
 func itBuildsTheCorrectDecisionTree(root *DecisionTreeNode) {
     Convey("The root should be 'outlook'", func() {
-        So(root.SplitAttr.GetName(), ShouldEqual, "outlook")
+        So(root.SplitRule.SplitAttr.GetName(), ShouldEqual, "outlook")
     })
 
     sunny := root.Children["sunny"]
@@ -146,13 +207,13 @@ func itBuildsTheCorrectDecisionTree(root *DecisionTreeNode) {
     overcast := root.Children["overcast"]
     rainy := root.Children["rainy"]
 
     Convey("After the 'sunny' node, the decision should split on 'humidity'", func() {
-        So(sunny.SplitAttr.GetName(), ShouldEqual, "humidity")
+        So(sunny.SplitRule.SplitAttr.GetName(), ShouldEqual, "humidity")
     })
     Convey("After the 'rainy' node, the decision should split on 'windy'", func() {
-        So(rainy.SplitAttr.GetName(), ShouldEqual, "windy")
+        So(rainy.SplitRule.SplitAttr.GetName(), ShouldEqual, "windy")
     })
     Convey("There should be no splits after the 'overcast' node", func() {
-        So(overcast.SplitAttr, ShouldBeNil)
+        So(overcast.SplitRule.SplitAttr, ShouldBeNil)
     })
 
     highHumidity := sunny.Children["high"]