From c2d040af3097390472111c9dddf25424d8a5604e Mon Sep 17 00:00:00 2001
From: Richard Townsend
Date: Sat, 2 Aug 2014 16:22:15 +0100
Subject: [PATCH] trees: merge from v2-instances

---
 trees/entropy.go   | 31 ++++++++++-------
 trees/id3.go       | 69 ++++++++++++++++++++++--------------
 trees/random.go    | 40 ++++++++++-----------
 trees/tree_test.go | 87 +++++++++++++++++++++++++++-------------------
 4 files changed, 132 insertions(+), 95 deletions(-)

diff --git a/trees/entropy.go b/trees/entropy.go
index 958107a..1d7d254 100644
--- a/trees/entropy.go
+++ b/trees/entropy.go
@@ -17,35 +17,40 @@ type InformationGainRuleGenerator struct {
 //
 // IMPORTANT: passing a base.Instances with no Attributes other than the class
 // variable will panic()
-func (r *InformationGainRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
-    allAttributes := make([]int, 0)
-    for i := 0; i < f.Cols; i++ {
-        if i != f.ClassIndex {
-            allAttributes = append(allAttributes, i)
-        }
-    }
-    return r.GetSplitAttributeFromSelection(allAttributes, f)
+func (r *InformationGainRuleGenerator) GenerateSplitAttribute(f base.FixedDataGrid) base.Attribute {
+
+    attrs := f.AllAttributes()
+    classAttrs := f.AllClassAttributes()
+    candidates := base.AttributeDifferenceReferences(attrs, classAttrs)
+
+    return r.GetSplitAttributeFromSelection(candidates, f)
 }
 
 // GetSplitAttributeFromSelection returns the class Attribute which maximises
 // the information gain amongst consideredAttributes
 //
 // IMPORTANT: passing a zero-length consideredAttributes parameter will panic()
-func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(consideredAttributes []int, f *base.Instances) base.Attribute {
+func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(consideredAttributes []base.Attribute, f base.FixedDataGrid) base.Attribute {
+
+    var selectedAttribute base.Attribute
+
+    // Parameter check
+    if len(consideredAttributes) == 0 {
+        panic("More Attributes should be considered")
+    }
 
     // Next step is to compute the information gain at this node
     // for each randomly chosen attribute, and pick the one
     // which maximises it
     maxGain := math.Inf(-1)
-    selectedAttribute := -1
 
     // Compute the base entropy
-    classDist := f.GetClassDistribution()
+    classDist := base.GetClassDistribution(f)
     baseEntropy := getBaseEntropy(classDist)
 
     // Compute the information gain for each attribute
     for _, s := range consideredAttributes {
-        proposedClassDist := f.GetClassDistributionAfterSplit(f.GetAttr(s))
+        proposedClassDist := base.GetClassDistributionAfterSplit(f, s)
         localEntropy := getSplitEntropy(proposedClassDist)
         informationGain := baseEntropy - localEntropy
         if informationGain > maxGain {
@@ -55,7 +60,7 @@ func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(considered
     }
 
     // Pick the one which maximises IG
-    return f.GetAttr(selectedAttribute)
+    return selectedAttribute
 }
 
 //
diff --git a/trees/id3.go b/trees/id3.go
index 5a88faa..af59494 100644
--- a/trees/id3.go
+++ b/trees/id3.go
@@ -21,7 +21,7 @@ const (
 // RuleGenerator implementations analyse instances and determine
 // the best value to split on
 type RuleGenerator interface {
-    GenerateSplitAttribute(*base.Instances) base.Attribute
+    GenerateSplitAttribute(base.FixedDataGrid) base.Attribute
 }
 
 // DecisionTreeNode represents a given portion of a decision tree
@@ -31,14 +31,19 @@ type DecisionTreeNode struct {
     SplitAttr base.Attribute
     ClassDist map[string]int
     Class     string
-    ClassAttr *base.Attribute
+    ClassAttr base.Attribute
+}
+
+func getClassAttr(from base.FixedDataGrid) base.Attribute {
+    allClassAttrs := from.AllClassAttributes()
+    return allClassAttrs[0]
 }
 
 // InferID3Tree builds a decision tree using a RuleGenerator
 // from a set of Instances (implements the ID3 algorithm)
-func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
+func InferID3Tree(from base.FixedDataGrid, with RuleGenerator) *DecisionTreeNode {
     // Count the number of classes at this node
-    classes := from.CountClassValues()
+    classes := base.GetClassDistribution(from)
     // If there's only one class, return a DecisionTreeLeaf with
     // the only class available
     if len(classes) == 1 {
@@ -52,7 +57,7 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
             nil,
             classes,
             maxClass,
-            from.GetClassAttrPtr(),
+            getClassAttr(from),
         }
         return ret
     }
@@ -69,28 +74,29 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
 
     // If there are no more Attributes left to split on,
     // return a DecisionTreeLeaf with the majority class
-    if from.GetAttributeCount() == 2 {
+    cols, _ := from.Size()
+    if cols == 2 {
         ret := &DecisionTreeNode{
             LeafNode,
             nil,
             nil,
             classes,
             maxClass,
-            from.GetClassAttrPtr(),
+            getClassAttr(from),
         }
         return ret
     }
 
+    // Generate a return structure
     ret := &DecisionTreeNode{
         RuleNode,
         nil,
         nil,
         classes,
         maxClass,
-        from.GetClassAttrPtr(),
+        getClassAttr(from),
     }
 
-    // Generate a return structure
     // Generate the splitting attribute
     splitOnAttribute := with.GenerateSplitAttribute(from)
     if splitOnAttribute == nil {
@@ -98,7 +104,7 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
         return ret
     }
     // Split the attributes based on this attribute's value
-    splitInstances := from.DecomposeOnAttributeValues(splitOnAttribute)
+    splitInstances := base.DecomposeOnAttributeValues(from, splitOnAttribute)
     // Create new children from these attributes
     ret.Children = make(map[string]*DecisionTreeNode)
     for k := range splitInstances {
@@ -146,13 +152,13 @@ func (d *DecisionTreeNode) String() string {
 }
 
 // computeAccuracy is a helper method for Prune()
-func computeAccuracy(predictions *base.Instances, from *base.Instances) float64 {
+func computeAccuracy(predictions base.FixedDataGrid, from base.FixedDataGrid) float64 {
     cf := eval.GetConfusionMatrix(from, predictions)
     return eval.GetAccuracy(cf)
 }
 
 // Prune eliminates branches which hurt accuracy
-func (d *DecisionTreeNode) Prune(using *base.Instances) {
+func (d *DecisionTreeNode) Prune(using base.FixedDataGrid) {
     // If you're a leaf, you're already pruned
     if d.Children == nil {
         return
@@ -162,11 +168,15 @@ func (d *DecisionTreeNode) Prune(using *base.Instances) {
     }
 
     // Recursively prune children of this node
-    sub := using.DecomposeOnAttributeValues(d.SplitAttr)
+    sub := base.DecomposeOnAttributeValues(using, d.SplitAttr)
     for k := range d.Children {
         if sub[k] == nil {
             continue
         }
+        subH, subV := sub[k].Size()
+        if subH == 0 || subV == 0 {
+            continue
+        }
         d.Children[k].Prune(sub[k])
     }
 
@@ -185,24 +195,30 @@ func (d *DecisionTreeNode) Prune(using *base.Instances) {
 }
 
 // Predict outputs a base.Instances containing predictions from this tree
-func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances {
-    outputAttrs := make([]base.Attribute, 1)
-    outputAttrs[0] = what.GetClassAttr()
-    predictions := base.NewInstances(outputAttrs, what.Rows)
-    for i := 0; i < what.Rows; i++ {
+func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) base.FixedDataGrid {
+    predictions := base.GeneratePredictionVector(what)
+    classAttr := getClassAttr(predictions)
+    classAttrSpec, err := predictions.GetAttribute(classAttr)
+    if err != nil {
+        panic(err)
+    }
+    predAttrs := base.AttributeDifferenceReferences(what.AllAttributes(), predictions.AllClassAttributes())
+    predAttrSpecs := base.ResolveAllAttributes(what, predAttrs)
+    what.MapOverRows(predAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
         cur := d
         for {
             if cur.Children == nil {
-                predictions.SetAttrStr(i, 0, cur.Class)
+                predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
                 break
             } else {
                 at := cur.SplitAttr
-                j := what.GetAttrIndex(at)
-                if j == -1 {
-                    predictions.SetAttrStr(i, 0, cur.Class)
+                ats, err := what.GetAttribute(at)
+                if err != nil {
+                    predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
                     break
                 }
-                classVar := at.GetStringFromSysVal(what.Get(i, j))
+
+                classVar := ats.GetAttribute().GetStringFromSysVal(what.Get(ats, rowNo))
                 if next, ok := cur.Children[classVar]; ok {
                     cur = next
                 } else {
@@ -217,7 +233,8 @@ func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances {
             }
         }
     }
-    }
+        return true, nil
+    })
     return predictions
 }
 
@@ -245,7 +262,7 @@ func NewID3DecisionTree(prune float64) *ID3DecisionTree {
 }
 
 // Fit builds the ID3 decision tree
-func (t *ID3DecisionTree) Fit(on *base.Instances) {
+func (t *ID3DecisionTree) Fit(on base.FixedDataGrid) {
     rule := new(InformationGainRuleGenerator)
     if t.PruneSplit > 0.001 {
         trainData, testData := base.InstancesTrainTestSplit(on, t.PruneSplit)
@@ -257,7 +274,7 @@ func (t *ID3DecisionTree) Fit(on *base.Instances) {
 }
 
 // Predict outputs predictions from the ID3 decision tree
-func (t *ID3DecisionTree) Predict(what *base.Instances) *base.Instances {
+func (t *ID3DecisionTree) Predict(what base.FixedDataGrid) base.FixedDataGrid {
     return t.Root.Predict(what)
 }
diff --git a/trees/random.go b/trees/random.go
index dfe9261..9b6d1d5 100644
--- a/trees/random.go
+++ b/trees/random.go
@@ -14,32 +14,32 @@ type RandomTreeRuleGenerator struct {
 
 // GenerateSplitAttribute returns the best attribute out of those randomly chosen
 // which maximises Information Gain
-func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
+func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f base.FixedDataGrid) base.Attribute {
 
     // First step is to generate the random attributes that we'll consider
-    maximumAttribute := f.GetAttributeCount()
-    consideredAttributes := make([]int, r.Attributes)
+    allAttributes := base.AttributeDifferenceReferences(f.AllAttributes(), f.AllClassAttributes())
+    maximumAttribute := len(allAttributes)
+    consideredAttributes := make([]base.Attribute, 0)
+
     attrCounter := 0
     for {
         if len(consideredAttributes) >= r.Attributes {
             break
         }
-        selectedAttribute := rand.Intn(maximumAttribute)
-        base.Logger.Println(selectedAttribute, attrCounter, consideredAttributes, len(consideredAttributes))
-        if selectedAttribute != f.ClassIndex {
-            matched := false
-            for _, a := range consideredAttributes {
-                if a == selectedAttribute {
-                    matched = true
-                    break
-                }
+        selectedAttrIndex := rand.Intn(maximumAttribute)
+        selectedAttribute := allAttributes[selectedAttrIndex]
+        matched := false
+        for _, a := range consideredAttributes {
+            if a.Equals(selectedAttribute) {
+                matched = true
+                break
             }
-            if matched {
-                continue
-            }
-            consideredAttributes = append(consideredAttributes, selectedAttribute)
-            attrCounter++
         }
+        if matched {
+            continue
+        }
+        consideredAttributes = append(consideredAttributes, selectedAttribute)
+        attrCounter++
     }
 
     return r.internalRule.GetSplitAttributeFromSelection(consideredAttributes, f)
@@ -67,12 +67,12 @@ func NewRandomTree(attrs int) *RandomTree {
 }
 
 // Fit builds a RandomTree suitable for prediction
-func (rt *RandomTree) Fit(from *base.Instances) {
+func (rt *RandomTree) Fit(from base.FixedDataGrid) {
     rt.Root = InferID3Tree(from, rt.Rule)
 }
 
 // Predict returns a set of Instances containing predictions
-func (rt *RandomTree) Predict(from *base.Instances) *base.Instances {
+func (rt *RandomTree) Predict(from base.FixedDataGrid) base.FixedDataGrid {
     return rt.Root.Predict(from)
 }
 
@@ -83,6 +83,6 @@ func (rt *RandomTree) String() string {
 
 // Prune removes nodes from the tree which are detrimental
 // to determining the accuracy of the test set (with)
-func (rt *RandomTree) Prune(with *base.Instances) {
+func (rt *RandomTree) Prune(with base.FixedDataGrid) {
     rt.Root.Prune(with)
 }
diff --git a/trees/tree_test.go b/trees/tree_test.go
index e7ca49b..26b643f 100644
--- a/trees/tree_test.go
+++ b/trees/tree_test.go
@@ -14,15 +14,17 @@ func TestRandomTree(testEnv *testing.T) {
     if err != nil {
         panic(err)
     }
     filt := filters.NewChiMergeFilter(inst, 0.90)
-    filt.AddAllNumericAttributes()
-    filt.Build()
-    filt.Run(inst)
-    fmt.Println(inst)
+    for _, a := range base.NonClassFloatAttributes(inst) {
+        filt.AddAttribute(a)
+    }
+    filt.Train()
+    instf := base.NewLazilyFilteredInstances(inst, filt)
+
     r := new(RandomTreeRuleGenerator)
     r.Attributes = 2
-    root := InferID3Tree(inst, r)
+    fmt.Println(instf)
+    root := InferID3Tree(instf, r)
     fmt.Println(root)
 }
 
@@ -33,18 +35,20 @@ func TestRandomTreeClassification(testEnv *testing.T) {
     }
     trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
     filt := filters.NewChiMergeFilter(inst, 0.90)
-    filt.AddAllNumericAttributes()
-    filt.Build()
-    filt.Run(trainData)
-    filt.Run(testData)
-    fmt.Println(inst)
+    for _, a := range base.NonClassFloatAttributes(inst) {
+        filt.AddAttribute(a)
+    }
+    filt.Train()
+    trainDataF := base.NewLazilyFilteredInstances(trainData, filt)
+    testDataF := base.NewLazilyFilteredInstances(testData, filt)
+
     r := new(RandomTreeRuleGenerator)
     r.Attributes = 2
-    root := InferID3Tree(trainData, r)
+    root := InferID3Tree(trainDataF, r)
     fmt.Println(root)
-    predictions := root.Predict(testData)
+    predictions := root.Predict(testDataF)
     fmt.Println(predictions)
-    confusionMat := eval.GetConfusionMatrix(testData, predictions)
+    confusionMat := eval.GetConfusionMatrix(testDataF, predictions)
     fmt.Println(confusionMat)
     fmt.Println(eval.GetMacroPrecision(confusionMat))
     fmt.Println(eval.GetMacroRecall(confusionMat))
@@ -58,17 +62,19 @@ func TestRandomTreeClassification2(testEnv *testing.T) {
     }
     trainData, testData := base.InstancesTrainTestSplit(inst, 0.4)
     filt := filters.NewChiMergeFilter(inst, 0.90)
-    filt.AddAllNumericAttributes()
-    filt.Build()
-    fmt.Println(testData)
-    filt.Run(testData)
-    filt.Run(trainData)
+    for _, a := range base.NonClassFloatAttributes(inst) {
+        filt.AddAttribute(a)
+    }
+    filt.Train()
+    trainDataF := base.NewLazilyFilteredInstances(trainData, filt)
+    testDataF := base.NewLazilyFilteredInstances(testData, filt)
+
     root := NewRandomTree(2)
-    root.Fit(trainData)
+    root.Fit(trainDataF)
     fmt.Println(root)
-    predictions := root.Predict(testData)
+    predictions := root.Predict(testDataF)
     fmt.Println(predictions)
-    confusionMat := eval.GetConfusionMatrix(testData, predictions)
+    confusionMat := eval.GetConfusionMatrix(testDataF, predictions)
     fmt.Println(confusionMat)
     fmt.Println(eval.GetMacroPrecision(confusionMat))
     fmt.Println(eval.GetMacroRecall(confusionMat))
@@ -82,19 +88,21 @@ func TestPruning(testEnv *testing.T) {
     }
     trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
     filt := filters.NewChiMergeFilter(inst, 0.90)
-    filt.AddAllNumericAttributes()
-    filt.Build()
-    fmt.Println(testData)
-    filt.Run(testData)
-    filt.Run(trainData)
+    for _, a := range base.NonClassFloatAttributes(inst) {
+        filt.AddAttribute(a)
+    }
+    filt.Train()
+    trainDataF := base.NewLazilyFilteredInstances(trainData, filt)
+    testDataF := base.NewLazilyFilteredInstances(testData, filt)
+
     root := NewRandomTree(2)
-    fittrainData, fittestData := base.InstancesTrainTestSplit(trainData, 0.6)
+    fittrainData, fittestData := base.InstancesTrainTestSplit(trainDataF, 0.6)
     root.Fit(fittrainData)
     root.Prune(fittestData)
     fmt.Println(root)
-    predictions := root.Predict(testData)
+    predictions := root.Predict(testDataF)
     fmt.Println(predictions)
-    confusionMat := eval.GetConfusionMatrix(testData, predictions)
+    confusionMat := eval.GetConfusionMatrix(testDataF, predictions)
     fmt.Println(confusionMat)
     fmt.Println(eval.GetMacroPrecision(confusionMat))
     fmt.Println(eval.GetMacroRecall(confusionMat))
@@ -142,6 +150,7 @@ func TestID3Inference(testEnv *testing.T) {
         testEnv.Error(sunnyChild)
     }
     if rainyChild.SplitAttr.GetName() != "windy" {
+        fmt.Println(rainyChild.SplitAttr)
         testEnv.Error(rainyChild)
     }
     if overcastChild.SplitAttr != nil {
@@ -156,7 +165,6 @@ func TestID3Inference(testEnv *testing.T) {
     if sunnyLeafNormal.Class != "yes" {
         testEnv.Error(sunnyLeafNormal)
     }
-
     windyLeafFalse := rainyChild.Children["false"]
     windyLeafTrue := rainyChild.Children["true"]
     if windyLeafFalse.Class != "yes" {
@@ -176,12 +184,18 @@ func TestID3Classification(testEnv *testing.T) {
     if err != nil {
         panic(err)
     }
-    filt := filters.NewBinningFilter(inst, 10)
-    filt.AddAllNumericAttributes()
-    filt.Build()
-    filt.Run(inst)
     fmt.Println(inst)
-    trainData, testData := base.InstancesTrainTestSplit(inst, 0.70)
+    filt := filters.NewBinningFilter(inst, 10)
+    for _, a := range base.NonClassFloatAttributes(inst) {
+        filt.AddAttribute(a)
+    }
+    filt.Train()
+    fmt.Println(filt)
+    instf := base.NewLazilyFilteredInstances(inst, filt)
+    fmt.Println("INSTFA", instf.AllAttributes())
+    fmt.Println("INSTF", instf)
+    trainData, testData := base.InstancesTrainTestSplit(instf, 0.70)
+
     // Build the decision tree
     rule := new(InformationGainRuleGenerator)
     root := InferID3Tree(trainData, rule)
@@ -199,6 +213,7 @@ func TestID3(testEnv *testing.T) {
 
     // Import the "PlayTennis" dataset
     inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
+    fmt.Println(inst)
     if err != nil {
         panic(err)
     }
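
For reference, below is a minimal end-to-end sketch (not part of the patch) of how the refactored, FixedDataGrid-based API introduced here is driven, mirroring TestRandomTreeClassification2 above: load a dataset, discretise the non-class float Attributes, lazily filter the train/test views, fit a RandomTree, and evaluate the predictions. The golearn import paths, the eval alias for the evaluation package, and the dataset location are assumptions rather than something this diff specifies.

package main

import (
    "fmt"

    "github.com/sjwhitworth/golearn/base"
    eval "github.com/sjwhitworth/golearn/evaluation"
    "github.com/sjwhitworth/golearn/filters"
    "github.com/sjwhitworth/golearn/trees"
)

func main() {
    // Load a headered CSV into a FixedDataGrid implementation
    // (dataset path is a placeholder)
    inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
    if err != nil {
        panic(err)
    }

    // Split the rows into training and test sets, as in the tests above
    trainData, testData := base.InstancesTrainTestSplit(inst, 0.4)

    // Discretise the non-class float Attributes with ChiMerge; the filter is
    // trained once and then applied lazily to both splits
    filt := filters.NewChiMergeFilter(inst, 0.90)
    for _, a := range base.NonClassFloatAttributes(inst) {
        filt.AddAttribute(a)
    }
    filt.Train()
    trainDataF := base.NewLazilyFilteredInstances(trainData, filt)
    testDataF := base.NewLazilyFilteredInstances(testData, filt)

    // Fit a RandomTree that considers 2 candidate Attributes per split
    root := trees.NewRandomTree(2)
    root.Fit(trainDataF)

    // Predict on the held-out rows and summarise performance
    predictions := root.Predict(testDataF)
    confusionMat := eval.GetConfusionMatrix(testDataF, predictions)
    fmt.Println(confusionMat)
    fmt.Println(eval.GetMacroPrecision(confusionMat))
    fmt.Println(eval.GetMacroRecall(confusionMat))
}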