1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00
golearn/trees/id3.go
2017-08-07 17:26:11 +01:00

659 lines
17 KiB
Go

package trees
import (
"bytes"
"fmt"
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/evaluation"
"encoding/json"
"sort"
"log"
)
// NodeType determines whether a DecisionTreeNode is a leaf or not.
type NodeType int

const (
	// LeafNode means there are no children: the node predicts its Class directly.
	LeafNode NodeType = 1
	// RuleNode means we should look at the next attribute value and descend
	// into the matching child.
	RuleNode NodeType = 2
)
// RuleGenerator implementations analyse instances and determine
// the best value to split on.
type RuleGenerator interface {
	// GenerateSplitRule returns the best split for the given instances,
	// or a rule with a nil SplitAttr if none can be determined.
	GenerateSplitRule(base.FixedDataGrid) *DecisionTreeRule
}
// DecisionTreeRule represents the "decision" in "decision tree".
type DecisionTreeRule struct {
	// SplitAttr is the Attribute this rule splits on.
	SplitAttr base.Attribute `json:"split_attribute"`
	// SplitVal is the threshold used when SplitAttr is numeric
	// (rows with value <= SplitVal route one way, greater the other).
	SplitVal float64 `json:"split_val"`
}
// MarshalJSON serialises this DecisionTreeRule to JSON, nesting the
// split Attribute's own JSON representation under "split_attribute".
func (d *DecisionTreeRule) MarshalJSON() ([]byte, error) {
	ret := make(map[string]interface{})
	// Re-use the Attribute's serialisation, then decode it back into a
	// generic map so it nests cleanly inside this rule's JSON object.
	marshaledSplitAttrRaw, err := d.SplitAttr.MarshalJSON()
	if err != nil {
		return nil, err
	}
	marshaledSplitAttr := make(map[string]interface{})
	// Return the error rather than panicking (as the original did):
	// a malformed Attribute serialisation is recoverable by the caller.
	if err := json.Unmarshal(marshaledSplitAttrRaw, &marshaledSplitAttr); err != nil {
		return nil, err
	}
	ret["split_attribute"] = marshaledSplitAttr
	ret["split_val"] = d.SplitVal
	return json.Marshal(ret)
}
// unmarshalJSON does the actual decoding work for UnmarshalJSON. It
// restores the split value and round-trips the "split_attribute" field
// through json.Marshal so base.DeserializeAttribute can rebuild the
// concrete Attribute type.
func (d *DecisionTreeRule) unmarshalJSON(data []byte) error {
	var jsonMap map[string]interface{}
	if err := json.Unmarshal(data, &jsonMap); err != nil {
		return err
	}
	if splitVal, ok := jsonMap["split_val"]; ok {
		// JSON numbers decode as float64, which matches SplitVal's type.
		d.SplitVal = splitVal.(float64)
	}
	// Re-marshal the attribute sub-object back to bytes so it can be
	// handed to the Attribute deserialiser.
	splitBytes, err := json.Marshal(jsonMap["split_attribute"])
	if err != nil {
		return err
	}
	d.SplitAttr, err = base.DeserializeAttribute(splitBytes)
	if err != nil {
		return err
	}
	if d.SplitAttr == nil {
		// Report corrupt input as an error instead of panicking (the
		// original panicked and had an unreachable return after it).
		return fmt.Errorf("base.DeserializeAttribute returned nil")
	}
	return nil
}
// UnmarshalJSON restores a DecisionTreeRule from its JSON representation.
// (The original logged its raw input and formatted the returned error
// with %s — which prints "%!s(<nil>)" on success — on every call; that
// debug noise is removed.)
func (d *DecisionTreeRule) UnmarshalJSON(data []byte) error {
	return d.unmarshalJSON(data)
}
// String prints a human-readable summary of this rule.
func (d *DecisionTreeRule) String() string {
	if d.SplitAttr == nil {
		// Constant string: no Sprintf needed.
		return "INVALID:DecisionTreeRule(SplitAttr is nil)"
	}
	// Numeric attributes split on a threshold; everything else splits
	// on discrete attribute values.
	if _, ok := d.SplitAttr.(*base.FloatAttribute); ok {
		return fmt.Sprintf("DecisionTreeRule(%s <= %f)", d.SplitAttr.GetName(), d.SplitVal)
	}
	return fmt.Sprintf("DecisionTreeRule(%s)", d.SplitAttr.GetName())
}
// DecisionTreeNode represents a given portion of a decision tree.
type DecisionTreeNode struct {
	// Type distinguishes leaves (LeafNode) from internal rule nodes.
	Type NodeType `json:"node_type"`
	// Children maps split outcomes (attribute values, or "0"/"1" for
	// numeric threshold splits) to subtrees; nil for leaves.
	Children map[string]*DecisionTreeNode `json:"children"`
	// ClassDist counts the training instances per class at this node.
	ClassDist map[string]int `json:"class_dist"`
	// Class is the (majority) class predicted at this node.
	Class string `json:"class_string"`
	// ClassAttr is the Attribute being predicted.
	ClassAttr base.Attribute `json:"class_attribute"`
	// SplitRule describes how instances are routed to Children; nil
	// (or with a nil SplitAttr) for leaves.
	SplitRule *DecisionTreeRule `json:"decision_tree_rule"`
}
// getClassAttr returns the first class Attribute of the given grid.
// NOTE(review): assumes the grid has at least one class attribute;
// panics (index out of range) otherwise — confirm all callers guarantee this.
func getClassAttr(from base.FixedDataGrid) base.Attribute {
	allClassAttrs := from.AllClassAttributes()
	return allClassAttrs[0]
}
// MarshalJSON returns a JSON representation of this node (and,
// recursively, its children) for serialisation.
func (d *DecisionTreeNode) MarshalJSON() ([]byte, error) {
	ret := map[string]interface{}{
		"type":       d.Type,
		"class_dist": d.ClassDist,
		"class":      d.Class,
	}
	// Only emit a split rule if there is genuinely one to describe.
	if d.SplitRule != nil && d.SplitRule.SplitAttr != nil {
		rawDRule, err := d.SplitRule.MarshalJSON()
		if err != nil {
			return nil, err
		}
		var dRule map[string]interface{}
		// Return decode errors instead of panicking.
		if err := json.Unmarshal(rawDRule, &dRule); err != nil {
			return nil, err
		}
		ret["split_rule"] = dRule
	}
	rawClassAttr, err := d.ClassAttr.MarshalJSON()
	if err != nil {
		return nil, err
	}
	var classAttr map[string]interface{}
	// The original silently ignored this error; a bad Attribute
	// serialisation must abort the whole marshal.
	if err := json.Unmarshal(rawClassAttr, &classAttr); err != nil {
		return nil, err
	}
	ret["class_attr"] = classAttr
	if len(d.Children) > 0 {
		children := make(map[string]interface{})
		for k := range d.Children {
			cur, err := d.Children[k].MarshalJSON()
			if err != nil {
				return nil, err
			}
			var child map[string]interface{}
			if err := json.Unmarshal(cur, &child); err != nil {
				return nil, err
			}
			children[k] = child
		}
		ret["children"] = children
	}
	return json.Marshal(ret)
}
// UnmarshalJSON reads a JSON representation of this node and,
// recursively, its children.
func (d *DecisionTreeNode) UnmarshalJSON(data []byte) error {
	log.Printf("Input JSON: %s", string(data))
	jsonMap := make(map[string]interface{})
	err := json.Unmarshal(data, &jsonMap)
	if err != nil {
		return err
	}
	// The node type is serialised as a JSON number; validate that it
	// maps onto a known NodeType.
	rawType := int(jsonMap["type"].(float64))
	switch rawType {
	case 1:
		d.Type = LeafNode
	case 2:
		d.Type = RuleNode
	default:
		return fmt.Errorf("Unknown nodeType: %d", rawType)
	}
	// Convert the class distribution back (JSON numbers arrive as float64).
	classDist := jsonMap["class_dist"].(map[string]interface{})
	d.ClassDist = make(map[string]int)
	for k := range classDist {
		d.ClassDist[k] = int(classDist[k].(float64))
	}
	d.Class = jsonMap["class"].(string)
	// Decode the class attribute: temporarily re-marshal the field back
	// to bytes so base.DeserializeAttribute can rebuild the concrete type.
	rawClassAttrBytes, err := json.Marshal(jsonMap["class_attr"])
	if err != nil {
		return err
	}
	classAttr, err := base.DeserializeAttribute(rawClassAttrBytes)
	if err != nil {
		return err
	}
	d.ClassAttr = classAttr
	d.SplitRule = nil
	if splitRule, ok := jsonMap["split_rule"]; ok {
		d.SplitRule = &DecisionTreeRule{}
		splitRuleBytes, err := json.Marshal(splitRule)
		if err != nil {
			return err
		}
		if err := d.SplitRule.UnmarshalJSON(splitRuleBytes); err != nil {
			return err
		}
	}
	// Children are decoded independently of the split rule. The original
	// only read "children" inside the split_rule branch and asserted the
	// key's presence, so a pruned rule node (split rule serialised, no
	// children) panicked with a nil type-assertion on load.
	if rawChildren, ok := jsonMap["children"]; ok {
		childMap, ok := rawChildren.(map[string]interface{})
		if !ok {
			return fmt.Errorf("\"children\" is not a JSON object")
		}
		d.Children = make(map[string]*DecisionTreeNode)
		for k := range childMap {
			cur := &DecisionTreeNode{}
			childBytes, err := json.Marshal(childMap[k])
			if err != nil {
				return err
			}
			if err := cur.UnmarshalJSON(childBytes); err != nil {
				return err
			}
			d.Children[k] = cur
		}
	}
	return nil
}
// Save writes the classification tree to the given output file.
func (d *DecisionTreeNode) Save(filePath string) error {
	metadata := base.ClassifierMetadataV1{
		FormatVersion:      1,
		ClassifierName:     "test",
		ClassifierVersion:  "1",
		ClassifierMetadata: nil,
	}
	serializer, err := base.CreateSerializedClassifierStub(filePath, metadata)
	if err != nil {
		return err
	}
	// Close on every path: the original leaked the serializer when
	// Marshal or WriteBytesForKey failed.
	defer serializer.Close()
	b, err := json.Marshal(d)
	if err != nil {
		return err
	}
	return serializer.WriteBytesForKey("tree", b)
}
// Load restores the classifier from a file previously written by Save.
func (d *DecisionTreeNode) Load(filePath string) error {
	reader, err := base.ReadSerializedClassifierStub(filePath)
	if err != nil {
		return err
	}
	// Release the underlying file once decoding is done.
	defer reader.Close()
	raw, err := reader.GetBytesForKey("tree")
	if err != nil {
		return err
	}
	return json.Unmarshal(raw, d)
}
// InferID3Tree builds a decision tree from a set of Instances using a
// RuleGenerator to choose each split (implements the ID3 algorithm).
func InferID3Tree(from base.FixedDataGrid, with RuleGenerator) *DecisionTreeNode {
	// Class distribution at this node.
	classes := base.GetClassDistribution(from)

	// Determine the majority class (trivially the only class when the
	// node is pure).
	majorityClass := ""
	majorityCount := 0
	for class, count := range classes {
		if count > majorityCount {
			majorityClass = class
			majorityCount = count
		}
	}

	// Emit a leaf predicting the majority class when the node is pure,
	// or when only the class attribute remains to split on.
	attrCount, _ := from.Size()
	if len(classes) == 1 || attrCount == 2 {
		return &DecisionTreeNode{
			Type:      LeafNode,
			ClassDist: classes,
			Class:     majorityClass,
			ClassAttr: getClassAttr(from),
			SplitRule: &DecisionTreeRule{},
		}
	}

	// Otherwise build an internal rule node.
	node := &DecisionTreeNode{
		Type:      RuleNode,
		ClassDist: classes,
		Class:     majorityClass,
		ClassAttr: getClassAttr(from),
	}

	// Ask the rule generator for the best attribute to split on here.
	rule := with.GenerateSplitRule(from)
	if rule == nil || rule.SplitAttr == nil {
		// Can't determine a split: return the node as-is.
		return node
	}

	// Partition the instances on the chosen attribute: numeric attributes
	// split on a threshold, all others on each distinct value.
	var partitions map[string]base.FixedDataGrid
	if _, ok := rule.SplitAttr.(*base.FloatAttribute); ok {
		partitions = base.DecomposeOnNumericAttributeThreshold(from,
			rule.SplitAttr, rule.SplitVal)
	} else {
		partitions = base.DecomposeOnAttributeValues(from, rule.SplitAttr)
	}

	// Recurse into each partition to build the subtrees.
	node.Children = make(map[string]*DecisionTreeNode)
	for val, subset := range partitions {
		node.Children[val] = InferID3Tree(subset, with)
	}
	node.SplitRule = rule
	return node
}
// getNestedString renders node d prefixed by level tab characters,
// recursively including its children (sorted by key for deterministic
// output).
func (d *DecisionTreeNode) getNestedString(level int) string {
	indent := bytes.NewBuffer(nil)
	for i := 0; i < level; i++ {
		indent.WriteString("\t")
	}
	out := bytes.NewBuffer(nil)
	out.WriteString(indent.String())
	if d.Children == nil {
		// Leaves just report their predicted class.
		out.WriteString(fmt.Sprintf("Leaf(%s)", d.Class))
		return out.String()
	}
	out.WriteString(fmt.Sprintf("Rule(%s)", d.SplitRule))
	// Sort the child keys so sibling order is stable.
	keys := make([]string, 0, len(d.Children))
	for k := range d.Children {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	for _, k := range keys {
		out.WriteString("\n")
		out.WriteString(indent.String())
		out.WriteString("\t")
		out.WriteString(k)
		out.WriteString("\n")
		out.WriteString(d.Children[k].getNestedString(level + 1))
	}
	return out.String()
}
// String returns a human-readable representation of this node and its
// children.
func (d *DecisionTreeNode) String() string {
	const rootLevel = 0
	return d.getNestedString(rootLevel)
}
// computeAccuracy is a helper method for Prune(): it scores predictions
// against the ground truth held in from.
// NOTE(review): the error from GetConfusionMatrix is discarded; on
// failure the resulting accuracy may be meaningless — consider
// surfacing the error to Prune.
func computeAccuracy(predictions base.FixedDataGrid, from base.FixedDataGrid) float64 {
	cf, _ := evaluation.GetConfusionMatrix(from, predictions)
	return evaluation.GetAccuracy(cf)
}
// Prune eliminates branches which hurt accuracy, evaluated against the
// held-out instances in using. It mutates the tree in place.
func (d *DecisionTreeNode) Prune(using base.FixedDataGrid) {
	// If you're a leaf, you're already pruned
	if d.Children == nil {
		return
	}
	// Without a split rule we cannot route 'using' to the children.
	if d.SplitRule == nil {
		return
	}
	// Recursively prune children of this node, handing each child only
	// the rows that would be routed to it.
	// NOTE(review): this always decomposes on discrete attribute values,
	// whereas numeric splits elsewhere use a threshold decomposition —
	// confirm pruning behaves sensibly for FloatAttribute splits.
	sub := base.DecomposeOnAttributeValues(using, d.SplitRule.SplitAttr)
	for k := range d.Children {
		// Skip children with no matching held-out data.
		if sub[k] == nil {
			continue
		}
		subH, subV := sub[k].Size()
		if subH == 0 || subV == 0 {
			continue
		}
		d.Children[k].Prune(sub[k])
	}
	// Get a baseline accuracy with the subtree intact.
	predictions, _ := d.Predict(using)
	baselineAccuracy := computeAccuracy(predictions, using)
	// Speculatively remove the children and re-evaluate this node as a
	// leaf predicting its majority class.
	tmpChildren := d.Children
	d.Children = nil
	predictions, _ = d.Predict(using)
	newAccuracy := computeAccuracy(predictions, using)
	// Keep the children removed only if accuracy did not get worse;
	// otherwise restore the subtree.
	if newAccuracy < baselineAccuracy {
		d.Children = tmpChildren
	}
}
// Predict outputs a base.Instances containing predictions from this
// tree for every row of what.
func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) (base.FixedDataGrid, error) {
	predictions := base.GeneratePredictionVector(what)
	classAttr := getClassAttr(predictions)
	classAttrSpec, err := predictions.GetAttribute(classAttr)
	if err != nil {
		// NOTE(review): panics rather than returning the error.
		panic(err)
	}
	// Attributes used for routing = all attributes minus the class.
	predAttrs := base.AttributeDifferenceReferences(what.AllAttributes(), predictions.AllClassAttributes())
	predAttrSpecs := base.ResolveAttributes(what, predAttrs)
	what.MapOverRows(predAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
		// Walk from the root to a leaf, following split decisions.
		cur := d
		for {
			if cur.Children == nil {
				// Reached a leaf: emit its class for this row.
				predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
				break
			} else {
				splitVal := cur.SplitRule.SplitVal
				at := cur.SplitRule.SplitAttr
				ats, err := what.GetAttribute(at)
				if err != nil {
					//predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
					//break
					// NOTE(review): panicking here aborts prediction for
					// all rows; the commented-out fallback above would
					// have predicted the current node's class instead.
					panic(err)
				}
				var classVar string
				if _, ok := ats.GetAttribute().(*base.FloatAttribute); ok {
					// Numeric split: child "1" holds values above the
					// threshold, child "0" the rest (mirrors the
					// encoding used when the tree was built).
					classVal := base.UnpackBytesToFloat(what.Get(ats, rowNo))
					if classVal > splitVal {
						classVar = "1"
					} else {
						classVar = "0"
					}
				} else {
					// Discrete split: the child key is the attribute's
					// string value for this row.
					classVar = ats.GetAttribute().GetStringFromSysVal(what.Get(ats, rowNo))
				}
				if next, ok := cur.Children[classVar]; ok {
					cur = next
				} else {
					// No child for this exact value (unseen at training
					// time): fall back to an arbitrary child.
					// NOTE(review): map iteration order is random, so
					// this fallback is non-deterministic — "Suspicious
					// of this" per the original author.
					var bestChild string
					for c := range cur.Children {
						bestChild = c
						if c > classVar {
							break
						}
					}
					cur = cur.Children[bestChild]
				}
			}
		}
		return true, nil
	})
	return predictions, nil
}
// ClassProba pairs a class value with its predicted probability.
type ClassProba struct {
	Probability float64
	ClassValue  string
}

// ClassesProba is a sortable collection of ClassProba values; sorting
// with sort.Sort orders them by descending probability.
type ClassesProba []ClassProba

// Len returns the number of entries.
func (p ClassesProba) Len() int { return len(p) }

// Swap exchanges the entries at i and j.
func (p ClassesProba) Swap(i, j int) { p[i], p[j] = p[j], p[i] }

// Less reports whether entry i sorts before entry j (higher
// probabilities come first).
func (p ClassesProba) Less(i, j int) bool { return p[j].Probability < p[i].Probability }
// PredictProba predicts class probabilities for the single input sample
// in what. It returns an array of (class, probability) pairs sorted by
// descending probability, computed from the class distribution at the
// leaf the sample reaches. Panics if what contains more than one row.
func (t *ID3DecisionTree) PredictProba(what base.FixedDataGrid) (ClassesProba, error) {
	d := t.Root
	predictions := base.GeneratePredictionVector(what)
	// Attributes used for routing = all attributes minus the class.
	predAttrs := base.AttributeDifferenceReferences(what.AllAttributes(), predictions.AllClassAttributes())
	predAttrSpecs := base.ResolveAttributes(what, predAttrs)
	_, rowCount := what.Size()
	if rowCount > 1 {
		panic("PredictProba supports only 1 row predictions")
	}
	var results ClassesProba
	what.MapOverRows(predAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
		// Walk from the root to a leaf, mirroring the logic in Predict.
		cur := d
		for {
			if cur.Children == nil {
				// Leaf reached: normalise the class counts into
				// probabilities, then sort best-first.
				totalDist := 0
				for _, dist := range cur.ClassDist {
					totalDist += dist
				}
				for class, dist := range cur.ClassDist {
					classProba := ClassProba{ClassValue: class, Probability: float64(float64(dist) / float64(totalDist))}
					results = append(results, classProba)
				}
				sort.Sort(results)
				break
			} else {
				splitVal := cur.SplitRule.SplitVal
				at := cur.SplitRule.SplitAttr
				ats, err := what.GetAttribute(at)
				if err != nil {
					//predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
					//break
					// NOTE(review): panics when the split attribute is
					// missing from the input grid.
					panic(err)
				}
				var classVar string
				if _, ok := ats.GetAttribute().(*base.FloatAttribute); ok {
					// Numeric split: child "1" holds values above the
					// threshold, child "0" the rest.
					classVal := base.UnpackBytesToFloat(what.Get(ats, rowNo))
					if classVal > splitVal {
						classVar = "1"
					} else {
						classVar = "0"
					}
				} else {
					// Discrete split: the child key is the attribute's
					// string value for this row.
					classVar = ats.GetAttribute().GetStringFromSysVal(what.Get(ats, rowNo))
				}
				if next, ok := cur.Children[classVar]; ok {
					cur = next
				} else {
					// No child for this exact value: fall back to an
					// arbitrary child ("Suspicious of this" per the
					// original author — map order is random).
					var bestChild string
					for c := range cur.Children {
						bestChild = c
						if c > classVar {
							break
						}
					}
					cur = cur.Children[bestChild]
				}
			}
		}
		return true, nil
	})
	return results, nil
}
//
// ID3 Tree type
//

// ID3DecisionTree represents an ID3-based decision tree
// using the Information Gain metric to select which attributes
// to split on at each node.
type ID3DecisionTree struct {
	base.BaseClassifier
	// Root is the root of the fitted tree (nil until Fit is called).
	Root *DecisionTreeNode
	// PruneSplit is the train/test split ratio used to hold out data
	// for pruning; values <= 0.001 disable pruning (see Fit).
	PruneSplit float64
	// Rule chooses the best split at each node during Fit.
	Rule RuleGenerator
}
// NewID3DecisionTree returns a new ID3DecisionTree with the specified
// test-prune ratio and InformationGain as the rule generator.
// If the ratio is less than 0.001, the tree isn't pruned.
func NewID3DecisionTree(prune float64) *ID3DecisionTree {
	return &ID3DecisionTree{
		PruneSplit: prune,
		Rule:       new(InformationGainRuleGenerator),
	}
}
// NewID3DecisionTreeFromRule returns a new ID3DecisionTree with the
// specified test-prune ratio and the given rule generator.
func NewID3DecisionTreeFromRule(prune float64, rule RuleGenerator) *ID3DecisionTree {
	return &ID3DecisionTree{
		PruneSplit: prune,
		Rule:       rule,
	}
}
// Fit builds the ID3 decision tree from the supplied training data.
func (t *ID3DecisionTree) Fit(on base.FixedDataGrid) error {
	// A prune ratio at or below 0.001 means "don't prune": fit on
	// everything and return.
	if t.PruneSplit <= 0.001 {
		t.Root = InferID3Tree(on, t.Rule)
		return nil
	}
	// Otherwise hold out part of the data, fit on the remainder, then
	// prune the fitted tree against the held-out portion.
	trainData, testData := base.InstancesTrainTestSplit(on, t.PruneSplit)
	t.Root = InferID3Tree(trainData, t.Rule)
	t.Root.Prune(testData)
	return nil
}
// Predict outputs predictions from the fitted ID3 decision tree.
func (t *ID3DecisionTree) Predict(what base.FixedDataGrid) (base.FixedDataGrid, error) {
	root := t.Root
	return root.Predict(what)
}
// String returns a human-readable version of this ID3 tree.
func (t *ID3DecisionTree) String() string {
	root := t.Root
	return fmt.Sprintf("ID3DecisionTree(%s\n)", root)
}