mirror of https://github.com/sjwhitworth/golearn.git

commit c2d040af30 (parent 3b3b23a221)
trees: merge from v2-instances

This commit migrates the trees package (the ID3 implementation and the random tree) from the concrete *base.Instances type to the base.FixedDataGrid interface: methods that lived on Instances become package-level helpers in base, and the tests move to the v2 filter API (AddAttribute/Train plus lazily filtered views).
@@ -17,35 +17,40 @@ type InformationGainRuleGenerator struct {
 //
 // IMPORTANT: passing a base.Instances with no Attributes other than the class
 // variable will panic()
-func (r *InformationGainRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
-	allAttributes := make([]int, 0)
-	for i := 0; i < f.Cols; i++ {
-		if i != f.ClassIndex {
-			allAttributes = append(allAttributes, i)
-		}
-	}
-	return r.GetSplitAttributeFromSelection(allAttributes, f)
+func (r *InformationGainRuleGenerator) GenerateSplitAttribute(f base.FixedDataGrid) base.Attribute {
+
+	attrs := f.AllAttributes()
+	classAttrs := f.AllClassAttributes()
+	candidates := base.AttributeDifferenceReferences(attrs, classAttrs)
+
+	return r.GetSplitAttributeFromSelection(candidates, f)
 }
 
 // GetSplitAttributeFromSelection returns the class Attribute which maximises
 // the information gain amongst consideredAttributes
 //
 // IMPORTANT: passing a zero-length consideredAttributes parameter will panic()
-func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(consideredAttributes []int, f *base.Instances) base.Attribute {
+func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(consideredAttributes []base.Attribute, f base.FixedDataGrid) base.Attribute {
+
+	var selectedAttribute base.Attribute
 
 	// Parameter check
 	if len(consideredAttributes) == 0 {
 		panic("More Attributes should be considered")
 	}
 
 	// Next step is to compute the information gain at this node
 	// for each randomly chosen attribute, and pick the one
 	// which maximises it
 	maxGain := math.Inf(-1)
-	selectedAttribute := -1
 
 	// Compute the base entropy
-	classDist := f.GetClassDistribution()
+	classDist := base.GetClassDistribution(f)
 	baseEntropy := getBaseEntropy(classDist)
 
 	// Compute the information gain for each attribute
 	for _, s := range consideredAttributes {
-		proposedClassDist := f.GetClassDistributionAfterSplit(f.GetAttr(s))
+		proposedClassDist := base.GetClassDistributionAfterSplit(f, s)
 		localEntropy := getSplitEntropy(proposedClassDist)
 		informationGain := baseEntropy - localEntropy
 		if informationGain > maxGain {
@@ -55,7 +60,7 @@ func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(considered
 	}
 
 	// Pick the one which maximises IG
-	return f.GetAttr(selectedAttribute)
+	return selectedAttribute
 }
 
 //
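The getBaseEntropy and getSplitEntropy helpers used above are outside this diff. For orientation, they implement the standard Shannon-entropy comparison behind information gain; below is a minimal self-contained sketch of that computation using plain map types (the names and structures are illustrative assumptions, not golearn's distribution API):

package main

import (
	"fmt"
	"math"
)

// entropy computes Shannon entropy (in bits) over a class-count map.
func entropy(dist map[string]int) float64 {
	total := 0
	for _, c := range dist {
		total += c
	}
	e := 0.0
	for _, c := range dist {
		if c == 0 {
			continue
		}
		p := float64(c) / float64(total)
		e -= p * math.Log2(p)
	}
	return e
}

// splitEntropy weights each branch's entropy by the fraction of rows
// falling into it, mirroring the role of getSplitEntropy above.
func splitEntropy(split map[string]map[string]int) float64 {
	total := 0
	for _, dist := range split {
		for _, c := range dist {
			total += c
		}
	}
	e := 0.0
	for _, dist := range split {
		n := 0
		for _, c := range dist {
			n += c
		}
		e += float64(n) / float64(total) * entropy(dist)
	}
	return e
}

func main() {
	before := map[string]int{"yes": 9, "no": 5}
	after := map[string]map[string]int{
		"sunny":    {"yes": 2, "no": 3},
		"overcast": {"yes": 4},
		"rainy":    {"yes": 3, "no": 2},
	}
	// Information gain = base entropy - split entropy (~0.247 here).
	fmt.Printf("gain: %.3f\n", entropy(before)-splitEntropy(after))
}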
trees/id3.go
@@ -21,7 +21,7 @@ const (
 // RuleGenerator implementations analyse instances and determine
 // the best value to split on
 type RuleGenerator interface {
-	GenerateSplitAttribute(*base.Instances) base.Attribute
+	GenerateSplitAttribute(base.FixedDataGrid) base.Attribute
 }
 
 // DecisionTreeNode represents a given portion of a decision tree
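With the interface widened, any type with a GenerateSplitAttribute method over base.FixedDataGrid plugs into the tree builder. A hypothetical (deliberately trivial) implementation, using only helpers that appear elsewhere in this diff:

package rules

import "github.com/sjwhitworth/golearn/base"

// firstAttributeRule is a toy RuleGenerator: instead of maximising
// information gain, it always splits on the first non-class Attribute.
type firstAttributeRule struct{}

func (r *firstAttributeRule) GenerateSplitAttribute(f base.FixedDataGrid) base.Attribute {
	// Candidate attributes are everything except the class attributes.
	candidates := base.AttributeDifferenceReferences(f.AllAttributes(), f.AllClassAttributes())
	if len(candidates) == 0 {
		return nil
	}
	return candidates[0]
}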
@@ -31,14 +31,19 @@ type DecisionTreeNode struct {
 	SplitAttr base.Attribute
 	ClassDist map[string]int
 	Class     string
-	ClassAttr *base.Attribute
+	ClassAttr base.Attribute
 }
 
+func getClassAttr(from base.FixedDataGrid) base.Attribute {
+	allClassAttrs := from.AllClassAttributes()
+	return allClassAttrs[0]
+}
+
 // InferID3Tree builds a decision tree using a RuleGenerator
 // from a set of Instances (implements the ID3 algorithm)
-func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
+func InferID3Tree(from base.FixedDataGrid, with RuleGenerator) *DecisionTreeNode {
 	// Count the number of classes at this node
-	classes := from.CountClassValues()
+	classes := base.GetClassDistribution(from)
 	// If there's only one class, return a DecisionTreeLeaf with
 	// the only class available
 	if len(classes) == 1 {
@@ -52,7 +57,7 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
 			nil,
 			classes,
 			maxClass,
-			from.GetClassAttrPtr(),
+			getClassAttr(from),
 		}
 		return ret
 	}
@@ -69,28 +74,29 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
 
 	// If there are no more Attributes left to split on,
 	// return a DecisionTreeLeaf with the majority class
-	if from.GetAttributeCount() == 2 {
+	cols, _ := from.Size()
+	if cols == 2 {
 		ret := &DecisionTreeNode{
 			LeafNode,
 			nil,
 			nil,
 			classes,
 			maxClass,
-			from.GetClassAttrPtr(),
+			getClassAttr(from),
 		}
 		return ret
 	}
 
 	// Generate a return structure
 	ret := &DecisionTreeNode{
 		RuleNode,
 		nil,
 		nil,
 		classes,
 		maxClass,
-		from.GetClassAttrPtr(),
+		getClassAttr(from),
 	}
 
-	// Generate a return structure
+	// Generate the splitting attribute
 	splitOnAttribute := with.GenerateSplitAttribute(from)
 	if splitOnAttribute == nil {
@@ -98,7 +104,7 @@ func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
 		return ret
 	}
 	// Split the attributes based on this attribute's value
-	splitInstances := from.DecomposeOnAttributeValues(splitOnAttribute)
+	splitInstances := base.DecomposeOnAttributeValues(from, splitOnAttribute)
 	// Create new children from these attributes
 	ret.Children = make(map[string]*DecisionTreeNode)
 	for k := range splitInstances {
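InferID3Tree above is the classic ID3 recursion: stop when the node is pure or no split attributes remain, otherwise split on the generated attribute and recurse into each partition. A self-contained sketch of the same control flow over a toy dataset representation (all names here are illustrative, not golearn's API):

package main

import "fmt"

type row map[string]string

type node struct {
	splitAttr string
	children  map[string]*node
	class     string
}

// infer mirrors InferID3Tree's structure: pure node or exhausted
// attributes -> leaf with majority class; otherwise split and recurse.
func infer(rows []row, attrs []string, class string, choose func([]row, []string, string) string) *node {
	counts := map[string]int{}
	for _, r := range rows {
		counts[r[class]]++
	}
	maj, majN := "", -1
	for k, n := range counts {
		if n > majN {
			maj, majN = k, n
		}
	}
	if len(counts) == 1 || len(attrs) == 0 {
		return &node{class: maj}
	}
	split := choose(rows, attrs, class)
	ret := &node{splitAttr: split, class: maj, children: map[string]*node{}}
	// Partition rows on the chosen attribute's value and recurse,
	// removing that attribute from further consideration.
	parts := map[string][]row{}
	for _, r := range rows {
		parts[r[split]] = append(parts[r[split]], r)
	}
	rest := []string{}
	for _, a := range attrs {
		if a != split {
			rest = append(rest, a)
		}
	}
	for v, p := range parts {
		ret.children[v] = infer(p, rest, class, choose)
	}
	return ret
}

func main() {
	data := []row{
		{"outlook": "sunny", "play": "no"},
		{"outlook": "overcast", "play": "yes"},
		{"outlook": "rainy", "play": "yes"},
	}
	t := infer(data, []string{"outlook"}, "play", func(_ []row, attrs []string, _ string) string {
		return attrs[0] // trivial chooser; golearn plugs a RuleGenerator in here
	})
	fmt.Println(t.splitAttr, len(t.children))
}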
@@ -146,13 +152,13 @@ func (d *DecisionTreeNode) String() string {
 }
 
 // computeAccuracy is a helper method for Prune()
-func computeAccuracy(predictions *base.Instances, from *base.Instances) float64 {
+func computeAccuracy(predictions base.FixedDataGrid, from base.FixedDataGrid) float64 {
 	cf := eval.GetConfusionMatrix(from, predictions)
 	return eval.GetAccuracy(cf)
 }
 
 // Prune eliminates branches which hurt accuracy
-func (d *DecisionTreeNode) Prune(using *base.Instances) {
+func (d *DecisionTreeNode) Prune(using base.FixedDataGrid) {
 	// If you're a leaf, you're already pruned
 	if d.Children == nil {
 		return
@@ -162,11 +168,15 @@ func (d *DecisionTreeNode) Prune(using *base.Instances) {
 	}
 
 	// Recursively prune children of this node
-	sub := using.DecomposeOnAttributeValues(d.SplitAttr)
+	sub := base.DecomposeOnAttributeValues(using, d.SplitAttr)
 	for k := range d.Children {
 		if sub[k] == nil {
			continue
 		}
+		subH, subV := sub[k].Size()
+		if subH == 0 || subV == 0 {
+			continue
+		}
 		d.Children[k].Prune(sub[k])
 	}
 
@@ -185,24 +195,30 @@ func (d *DecisionTreeNode) Prune(using *base.Instances) {
 }
 
 // Predict outputs a base.Instances containing predictions from this tree
-func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances {
-	outputAttrs := make([]base.Attribute, 1)
-	outputAttrs[0] = what.GetClassAttr()
-	predictions := base.NewInstances(outputAttrs, what.Rows)
-	for i := 0; i < what.Rows; i++ {
+func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) base.FixedDataGrid {
+	predictions := base.GeneratePredictionVector(what)
+	classAttr := getClassAttr(predictions)
+	classAttrSpec, err := predictions.GetAttribute(classAttr)
+	if err != nil {
+		panic(err)
+	}
+	predAttrs := base.AttributeDifferenceReferences(what.AllAttributes(), predictions.AllClassAttributes())
+	predAttrSpecs := base.ResolveAllAttributes(what, predAttrs)
+	what.MapOverRows(predAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
 		cur := d
 		for {
 			if cur.Children == nil {
-				predictions.SetAttrStr(i, 0, cur.Class)
+				predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
 				break
 			} else {
 				at := cur.SplitAttr
-				j := what.GetAttrIndex(at)
-				if j == -1 {
-					predictions.SetAttrStr(i, 0, cur.Class)
+				ats, err := what.GetAttribute(at)
+				if err != nil {
+					predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
 					break
 				}
-				classVar := at.GetStringFromSysVal(what.Get(i, j))
+
+				classVar := ats.GetAttribute().GetStringFromSysVal(what.Get(ats, rowNo))
 				if next, ok := cur.Children[classVar]; ok {
 					cur = next
 				} else {
@@ -217,7 +233,8 @@ func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances {
 			}
 		}
 	}
-	}
+		return true, nil
+	})
 	return predictions
 }
 
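Predict now walks rows through FixedDataGrid's MapOverRows callback instead of indexing rows directly. A toy self-contained illustration of that iteration contract (return true to continue, false or an error to stop); the grid type here is a stand-in, not golearn's:

package main

import "fmt"

// grid is a toy row-major table mirroring the callback pattern above.
type grid struct {
	rows [][]string
}

// mapOverRows invokes fn for every row; a false return (or an error)
// stops iteration early, as with base.FixedDataGrid.MapOverRows.
func (g *grid) mapOverRows(fn func(row []string, rowNo int) (bool, error)) error {
	for i, r := range g.rows {
		ok, err := fn(r, i)
		if err != nil {
			return err
		}
		if !ok {
			break
		}
	}
	return nil
}

func main() {
	g := &grid{rows: [][]string{{"sunny", "hot"}, {"rainy", "mild"}}}
	_ = g.mapOverRows(func(row []string, rowNo int) (bool, error) {
		fmt.Println(rowNo, row)
		return true, nil
	})
}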
@@ -245,7 +262,7 @@ func NewID3DecisionTree(prune float64) *ID3DecisionTree {
 }
 
 // Fit builds the ID3 decision tree
-func (t *ID3DecisionTree) Fit(on *base.Instances) {
+func (t *ID3DecisionTree) Fit(on base.FixedDataGrid) {
 	rule := new(InformationGainRuleGenerator)
 	if t.PruneSplit > 0.001 {
 		trainData, testData := base.InstancesTrainTestSplit(on, t.PruneSplit)
@@ -257,7 +274,7 @@ func (t *ID3DecisionTree) Fit(on *base.Instances) {
 }
 
 // Predict outputs predictions from the ID3 decision tree
-func (t *ID3DecisionTree) Predict(what *base.Instances) *base.Instances {
+func (t *ID3DecisionTree) Predict(what base.FixedDataGrid) base.FixedDataGrid {
 	return t.Root.Predict(what)
 }
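Taken together, the public surface keeps its shape; only the types widen. A hypothetical end-to-end usage sketch assembled from calls that appear in this diff (the CSV path is a placeholder, and the trees import path is assumed from the file location):

package main

import (
	"fmt"

	"github.com/sjwhitworth/golearn/base"
	"github.com/sjwhitworth/golearn/trees"
)

func main() {
	inst, err := base.ParseCSVToInstances("dataset.csv", true) // placeholder path
	if err != nil {
		panic(err)
	}
	trainData, testData := base.InstancesTrainTestSplit(inst, 0.70)

	// 0.6 here is the prune-split parameter accepted by NewID3DecisionTree.
	tree := trees.NewID3DecisionTree(0.6)
	tree.Fit(trainData)
	predictions := tree.Predict(testData)
	fmt.Println(predictions)
}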
@@ -14,32 +14,32 @@ type RandomTreeRuleGenerator struct {
 
 // GenerateSplitAttribute returns the best attribute out of those randomly chosen
 // which maximises Information Gain
-func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
+func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f base.FixedDataGrid) base.Attribute {
 
 	// First step is to generate the random attributes that we'll consider
-	maximumAttribute := f.GetAttributeCount()
-	consideredAttributes := make([]int, r.Attributes)
+	allAttributes := base.AttributeDifferenceReferences(f.AllAttributes(), f.AllClassAttributes())
+	maximumAttribute := len(allAttributes)
+	consideredAttributes := make([]base.Attribute, 0)
 
 	attrCounter := 0
 	for {
 		if len(consideredAttributes) >= r.Attributes {
 			break
 		}
-		selectedAttribute := rand.Intn(maximumAttribute)
-		base.Logger.Println(selectedAttribute, attrCounter, consideredAttributes, len(consideredAttributes))
-		if selectedAttribute != f.ClassIndex {
-			matched := false
-			for _, a := range consideredAttributes {
-				if a == selectedAttribute {
-					matched = true
-					break
-				}
-			}
-			if matched {
-				continue
-			}
-			consideredAttributes = append(consideredAttributes, selectedAttribute)
-			attrCounter++
-		}
+		selectedAttrIndex := rand.Intn(maximumAttribute)
+		selectedAttribute := allAttributes[selectedAttrIndex]
+		matched := false
+		for _, a := range consideredAttributes {
+			if a.Equals(selectedAttribute) {
+				matched = true
+				break
+			}
+		}
+		if matched {
+			continue
+		}
+		consideredAttributes = append(consideredAttributes, selectedAttribute)
+		attrCounter++
 	}
 
 	return r.internalRule.GetSplitAttributeFromSelection(consideredAttributes, f)
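The loop above samples candidates with replacement and retries on duplicates until r.Attributes distinct attributes are collected. An alternative sketch that avoids the retry loop by shuffling a copy of the candidates first (illustrative only, using strings in place of base.Attribute):

package main

import (
	"fmt"
	"math/rand"
)

// sampleAttributes returns k distinct elements by shuffling a copy,
// sidestepping the duplicate-check-and-retry pattern in the diff.
func sampleAttributes(candidates []string, k int) []string {
	if k > len(candidates) {
		k = len(candidates)
	}
	cp := append([]string(nil), candidates...)
	rand.Shuffle(len(cp), func(i, j int) { cp[i], cp[j] = cp[j], cp[i] })
	return cp[:k]
}

func main() {
	fmt.Println(sampleAttributes([]string{"outlook", "temp", "humidity", "windy"}, 2))
}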
@@ -67,12 +67,12 @@ func NewRandomTree(attrs int) *RandomTree {
 }
 
 // Fit builds a RandomTree suitable for prediction
-func (rt *RandomTree) Fit(from *base.Instances) {
+func (rt *RandomTree) Fit(from base.FixedDataGrid) {
 	rt.Root = InferID3Tree(from, rt.Rule)
 }
 
 // Predict returns a set of Instances containing predictions
-func (rt *RandomTree) Predict(from *base.Instances) *base.Instances {
+func (rt *RandomTree) Predict(from base.FixedDataGrid) base.FixedDataGrid {
 	return rt.Root.Predict(from)
 }
 
@@ -83,6 +83,6 @@ func (rt *RandomTree) String() string {
 
 // Prune removes nodes from the tree which are detrimental
 // to determining the accuracy of the test set (with)
-func (rt *RandomTree) Prune(with *base.Instances) {
+func (rt *RandomTree) Prune(with base.FixedDataGrid) {
 	rt.Root.Prune(with)
 }
@@ -14,15 +14,17 @@ func TestRandomTree(testEnv *testing.T) {
 	if err != nil {
 		panic(err)
 	}
 
 	filt := filters.NewChiMergeFilter(inst, 0.90)
-	filt.AddAllNumericAttributes()
-	filt.Build()
-	filt.Run(inst)
-	fmt.Println(inst)
+	for _, a := range base.NonClassFloatAttributes(inst) {
+		filt.AddAttribute(a)
+	}
+	filt.Train()
+	instf := base.NewLazilyFilteredInstances(inst, filt)
 
 	r := new(RandomTreeRuleGenerator)
 	r.Attributes = 2
-	root := InferID3Tree(inst, r)
+	fmt.Println(instf)
+	root := InferID3Tree(instf, r)
 	fmt.Println(root)
 }
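The tests now follow the v2 filter lifecycle seen here: register the attributes to discretise, Train() the filter, then wrap the source grid in a lazily filtered view rather than mutating it in place. A stand-alone distillation of that pattern using only calls from this diff (the CSV path is a placeholder):

package main

import (
	"fmt"

	"github.com/sjwhitworth/golearn/base"
	"github.com/sjwhitworth/golearn/filters"
)

func main() {
	// "dataset.csv" is a placeholder path, not from the commit.
	inst, err := base.ParseCSVToInstances("dataset.csv", true)
	if err != nil {
		panic(err)
	}
	// Register every non-class float attribute, train the filter,
	// then view the original grid through it lazily.
	filt := filters.NewChiMergeFilter(inst, 0.90)
	for _, a := range base.NonClassFloatAttributes(inst) {
		filt.AddAttribute(a)
	}
	filt.Train()
	instf := base.NewLazilyFilteredInstances(inst, filt)
	fmt.Println(instf)
}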
@@ -33,18 +35,20 @@ func TestRandomTreeClassification(testEnv *testing.T) {
 	}
 	trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
 	filt := filters.NewChiMergeFilter(inst, 0.90)
-	filt.AddAllNumericAttributes()
-	filt.Build()
-	filt.Run(trainData)
-	filt.Run(testData)
-	fmt.Println(inst)
+	for _, a := range base.NonClassFloatAttributes(inst) {
+		filt.AddAttribute(a)
+	}
+	filt.Train()
+	trainDataF := base.NewLazilyFilteredInstances(trainData, filt)
+	testDataF := base.NewLazilyFilteredInstances(testData, filt)
 
 	r := new(RandomTreeRuleGenerator)
 	r.Attributes = 2
-	root := InferID3Tree(trainData, r)
+	root := InferID3Tree(trainDataF, r)
 	fmt.Println(root)
-	predictions := root.Predict(testData)
+	predictions := root.Predict(testDataF)
 	fmt.Println(predictions)
-	confusionMat := eval.GetConfusionMatrix(testData, predictions)
+	confusionMat := eval.GetConfusionMatrix(testDataF, predictions)
 	fmt.Println(confusionMat)
 	fmt.Println(eval.GetMacroPrecision(confusionMat))
 	fmt.Println(eval.GetMacroRecall(confusionMat))
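GetMacroPrecision and GetMacroRecall average per-class precision and recall over all classes. A self-contained sketch of that arithmetic over a map-based confusion matrix, illustrating the metric rather than eval's actual implementation:

package main

import "fmt"

// macroAverages computes macro precision/recall from cm[actual][predicted].
func macroAverages(cm map[string]map[string]int) (precision, recall float64) {
	classes := map[string]bool{}
	for a, row := range cm {
		classes[a] = true
		for p := range row {
			classes[p] = true
		}
	}
	n := 0
	for c := range classes {
		tp := cm[c][c]
		predicted, actual := 0, 0
		for other := range classes {
			predicted += cm[other][c] // column sum: everything predicted as c
			actual += cm[c][other]    // row sum: everything actually c
		}
		if predicted > 0 {
			precision += float64(tp) / float64(predicted)
		}
		if actual > 0 {
			recall += float64(tp) / float64(actual)
		}
		n++
	}
	return precision / float64(n), recall / float64(n)
}

func main() {
	cm := map[string]map[string]int{
		"yes": {"yes": 8, "no": 2},
		"no":  {"yes": 1, "no": 4},
	}
	p, r := macroAverages(cm)
	fmt.Printf("macro precision %.3f, macro recall %.3f\n", p, r)
}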
@@ -58,17 +62,19 @@ func TestRandomTreeClassification2(testEnv *testing.T) {
 	}
 	trainData, testData := base.InstancesTrainTestSplit(inst, 0.4)
 	filt := filters.NewChiMergeFilter(inst, 0.90)
-	filt.AddAllNumericAttributes()
-	filt.Build()
-	fmt.Println(testData)
-	filt.Run(testData)
-	filt.Run(trainData)
+	for _, a := range base.NonClassFloatAttributes(inst) {
+		filt.AddAttribute(a)
+	}
+	filt.Train()
+	trainDataF := base.NewLazilyFilteredInstances(trainData, filt)
+	testDataF := base.NewLazilyFilteredInstances(testData, filt)
 
 	root := NewRandomTree(2)
-	root.Fit(trainData)
+	root.Fit(trainDataF)
 	fmt.Println(root)
-	predictions := root.Predict(testData)
+	predictions := root.Predict(testDataF)
 	fmt.Println(predictions)
-	confusionMat := eval.GetConfusionMatrix(testData, predictions)
+	confusionMat := eval.GetConfusionMatrix(testDataF, predictions)
 	fmt.Println(confusionMat)
 	fmt.Println(eval.GetMacroPrecision(confusionMat))
 	fmt.Println(eval.GetMacroRecall(confusionMat))
@@ -82,19 +88,21 @@ func TestPruning(testEnv *testing.T) {
 	}
 	trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
 	filt := filters.NewChiMergeFilter(inst, 0.90)
-	filt.AddAllNumericAttributes()
-	filt.Build()
-	fmt.Println(testData)
-	filt.Run(testData)
-	filt.Run(trainData)
+	for _, a := range base.NonClassFloatAttributes(inst) {
+		filt.AddAttribute(a)
+	}
+	filt.Train()
+	trainDataF := base.NewLazilyFilteredInstances(trainData, filt)
+	testDataF := base.NewLazilyFilteredInstances(testData, filt)
 
 	root := NewRandomTree(2)
-	fittrainData, fittestData := base.InstancesTrainTestSplit(trainData, 0.6)
+	fittrainData, fittestData := base.InstancesTrainTestSplit(trainDataF, 0.6)
 	root.Fit(fittrainData)
 	root.Prune(fittestData)
 	fmt.Println(root)
-	predictions := root.Predict(testData)
+	predictions := root.Predict(testDataF)
 	fmt.Println(predictions)
-	confusionMat := eval.GetConfusionMatrix(testData, predictions)
+	confusionMat := eval.GetConfusionMatrix(testDataF, predictions)
 	fmt.Println(confusionMat)
 	fmt.Println(eval.GetMacroPrecision(confusionMat))
 	fmt.Println(eval.GetMacroRecall(confusionMat))
@@ -142,6 +150,7 @@ func TestID3Inference(testEnv *testing.T) {
 		testEnv.Error(sunnyChild)
 	}
 	if rainyChild.SplitAttr.GetName() != "windy" {
+		fmt.Println(rainyChild.SplitAttr)
 		testEnv.Error(rainyChild)
 	}
 	if overcastChild.SplitAttr != nil {
@@ -156,7 +165,6 @@ func TestID3Inference(testEnv *testing.T) {
 	if sunnyLeafNormal.Class != "yes" {
 		testEnv.Error(sunnyLeafNormal)
 	}
-
 	windyLeafFalse := rainyChild.Children["false"]
 	windyLeafTrue := rainyChild.Children["true"]
 	if windyLeafFalse.Class != "yes" {
@@ -176,12 +184,18 @@ func TestID3Classification(testEnv *testing.T) {
 	if err != nil {
 		panic(err)
 	}
-	filt := filters.NewBinningFilter(inst, 10)
-	filt.AddAllNumericAttributes()
-	filt.Build()
-	filt.Run(inst)
-	fmt.Println(inst)
-	trainData, testData := base.InstancesTrainTestSplit(inst, 0.70)
+	filt := filters.NewBinningFilter(inst, 10)
+	for _, a := range base.NonClassFloatAttributes(inst) {
+		filt.AddAttribute(a)
+	}
+	filt.Train()
+	fmt.Println(filt)
+	instf := base.NewLazilyFilteredInstances(inst, filt)
+	fmt.Println("INSTFA", instf.AllAttributes())
+	fmt.Println("INSTF", instf)
+	trainData, testData := base.InstancesTrainTestSplit(instf, 0.70)
 
 	// Build the decision tree
 	rule := new(InformationGainRuleGenerator)
 	root := InferID3Tree(trainData, rule)
@@ -199,6 +213,7 @@ func TestID3(testEnv *testing.T) {
 
 	// Import the "PlayTennis" dataset
 	inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
+	fmt.Println(inst)
 	if err != nil {
 		panic(err)
 	}