// golearn/trees/tree.go

package trees

import (
	"fmt"
	base "github.com/sjwhitworth/golearn/base"
	"strings"
)

// NodeType determines whether a DecisionTreeNode is a leaf or not
type NodeType int

const (
	// LeafNode means there are no children
	LeafNode NodeType = 1
	// RuleNode means we should look at the next attribute value
	RuleNode NodeType = 2
)

// RuleGenerator implementations analyse instances and determine
// the best attribute to split on
type RuleGenerator interface {
	GenerateSplitAttribute(*base.Instances) base.Attribute
}
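
// What follows is a minimal, illustrative RuleGenerator sketch, not part
// of the original file: it always proposes the first attribute as the
// split. A real generator (e.g. one scoring candidates by information
// gain) would evaluate every attribute; this exists only to show the
// interface contract, and it assumes attribute 0 is not the class
// attribute.
type firstAttributeGenerator struct{}

// GenerateSplitAttribute naively returns the first attribute.
func (f *firstAttributeGenerator) GenerateSplitAttribute(from *base.Instances) base.Attribute {
	return from.GetAttr(0)
}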

// DecisionTreeNode represents a given portion of a decision tree
type DecisionTreeNode struct {
	Type      NodeType                     // Whether this is a leaf or a rule node
	Children  map[string]*DecisionTreeNode // Child nodes, keyed by split attribute value
	SplitAttr base.Attribute               // The attribute this node splits on (rule nodes only)
	ClassDist map[string]int               // Count of training instances per class at this node
	Class     string                       // The majority class at this node
	ClassAttr *base.Attribute              // Pointer to the class attribute
}

// InferDecisionTree builds a decision tree using a RuleGenerator
// from a set of Instances
func InferDecisionTree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
	// Count the instances of each class at this node
	classes := from.CountClassValues()
	// If only one class is present, return a leaf node
	// predicting that class
	if len(classes) == 1 {
		maxClass := ""
		for i := range classes {
			maxClass = i
		}
		return &DecisionTreeNode{
			Type:      LeafNode,
			ClassDist: classes,
			Class:     maxClass,
			ClassAttr: from.GetClassAttrPtr(),
		}
	}
	// Otherwise, determine the majority class at this node
	maxVal := 0
	maxClass := ""
	for i := range classes {
		if classes[i] > maxVal {
			maxClass = i
			maxVal = classes[i]
		}
	}
	// If the class attribute is the only attribute left to split on,
	// return a leaf predicting the majority class
	if from.GetAttributeCount() == 1 {
		return &DecisionTreeNode{
			Type:      LeafNode,
			ClassDist: classes,
			Class:     maxClass,
			ClassAttr: from.GetClassAttrPtr(),
		}
	}
	// Otherwise, create a rule node
	ret := &DecisionTreeNode{
		Type:      RuleNode,
		Children:  make(map[string]*DecisionTreeNode),
		ClassDist: classes,
		Class:     maxClass,
		ClassAttr: from.GetClassAttrPtr(),
	}
	// Ask the RuleGenerator for the best attribute to split on
	splitOnAttribute := with.GenerateSplitAttribute(from)
	// Partition the instances on that attribute's values
	splitInstances := from.DecomposeOnAttributeValues(splitOnAttribute)
	// Recursively build a child subtree from each partition
	for k := range splitInstances {
		ret.Children[k] = InferDecisionTree(splitInstances[k], with)
	}
	ret.SplitAttr = splitOnAttribute
	return ret
}
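
// A minimal usage sketch (illustrative; `inst` and `gen` are hypothetical
// names for a *base.Instances loaded elsewhere and any RuleGenerator,
// such as the firstAttributeGenerator sketch above):
//
//	tree := InferDecisionTree(inst, gen)
//	fmt.Println(tree) // dump the induced rules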

// String returns a human-readable representation of a given node
// and its children
func (d *DecisionTreeNode) String() string {
	if d.Children != nil {
		children := make([]string, 0)
		for k := range d.Children {
			childStr := fmt.Sprintf("Rule(%s -> %s)", k, d.Children[k])
			children = append(children, childStr)
		}
		return fmt.Sprintf("(%s(%s))", d.SplitAttr, strings.Join(children, "\n\t"))
	}
	// %v formats the class-distribution map; %s would not
	return fmt.Sprintf("Leaf(%s (%v))", d.Class, d.ClassDist)
}
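
// For illustration, a two-level tree over a hypothetical "outlook"
// attribute might print along these lines (the exact attribute text
// depends on base.Attribute's string form):
//
//	(outlook(Rule(sunny -> Leaf(no (map[no:3])))
//		Rule(rain -> Leaf(yes (map[yes:4])))))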

// Predict classifies each row of what by walking the tree from the
// root and returns the results as a new set of Instances
func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances {
	outputAttrs := make([]base.Attribute, 1)
	outputAttrs[0] = what.GetClassAttr()
	predictions := base.NewInstances(outputAttrs, what.Rows)
	for i := 0; i < what.Rows; i++ {
		cur := d
		for j := 0; j < what.Cols; j++ {
			// Stop descending once we reach a leaf
			if cur.Children == nil {
				break
			}
			// NOTE: this walk assumes column j holds the attribute
			// cur splits on at depth j; it does not look up
			// cur.SplitAttr's column explicitly
			at := what.GetAttr(j)
			classVar := at.GetStringFromSysVal(what.Get(i, j))
			if next, ok := cur.Children[classVar]; ok {
				cur = next
			}
			// If no child matches this value, stay put and fall
			// back to this node's majority class
		}
		// Emit the class of whichever node the walk ended on
		predictions.SetAttrStr(i, 0, cur.Class)
	}
	return predictions
}
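
// A minimal prediction sketch (illustrative; assumes `tree` was built by
// InferDecisionTree and `test` is a *base.Instances whose attribute
// layout matches the training data):
//
//	results := tree.Predict(test)
//	fmt.Println(results)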