2014-05-14 14:00:22 +01:00
package trees
import (
2014-05-17 17:28:51 +01:00
"bytes"
2014-05-14 14:00:22 +01:00
"fmt"
2014-08-22 07:21:24 +00:00
"github.com/sjwhitworth/golearn/base"
2014-08-22 09:33:42 +00:00
"github.com/sjwhitworth/golearn/evaluation"
2017-08-07 17:26:11 +01:00
"encoding/json"
2014-05-17 20:37:19 +01:00
"sort"
2014-05-14 14:00:22 +01:00
)
2014-10-26 12:07:38 +00:00
// NodeType determines whether a DecisionTreeNode is a leaf or not.
type NodeType int

const (
	// LeafNode means there are no children.
	LeafNode NodeType = iota + 1
	// RuleNode means we should look at the next attribute value
	// (via the node's SplitRule) to pick a child.
	RuleNode
)
// RuleGenerator implementations analyse instances and determine
2014-10-26 12:07:38 +00:00
// the best value to split on.
2014-05-14 14:00:22 +01:00
type RuleGenerator interface {
2014-10-26 12:07:38 +00:00
GenerateSplitRule ( base . FixedDataGrid ) * DecisionTreeRule
2014-05-14 14:00:22 +01:00
}
2014-10-26 12:07:38 +00:00
// DecisionTreeRule represents the "decision" in "decision tree".
type DecisionTreeRule struct {
2017-08-07 17:26:11 +01:00
SplitAttr base . Attribute ` json:"split_attribute" `
SplitVal float64 ` json:"split_val" `
}
// MarshalJSON encodes the rule as a JSON object holding the split Attribute
// (as produced by the Attribute's own MarshalJSON) under "split_attribute"
// and the threshold under "split_val".
//
// The Attribute's JSON is embedded verbatim via json.RawMessage instead of
// being decoded and re-encoded. An error is returned (rather than a panic)
// when SplitAttr is nil or its encoding fails.
func (d *DecisionTreeRule) MarshalJSON() ([]byte, error) {
	if d.SplitAttr == nil {
		return nil, fmt.Errorf("trees: cannot marshal DecisionTreeRule with nil SplitAttr")
	}
	rawSplitAttr, err := d.SplitAttr.MarshalJSON()
	if err != nil {
		return nil, err
	}
	return json.Marshal(map[string]interface{}{
		"split_attribute": json.RawMessage(rawSplitAttr),
		"split_val":       d.SplitVal,
	})
}
// unmarshalJSON decodes the object written by MarshalJSON into d.
//
// "split_val" is optional; when absent, d.SplitVal is left untouched.
// "split_attribute" is required and is rebuilt with
// base.DeserializeAttribute. Malformed input produces an error rather than
// a panic, and d.SplitAttr is only assigned on success.
func (d *DecisionTreeRule) unmarshalJSON(data []byte) error {
	var jsonMap map[string]json.RawMessage
	if err := json.Unmarshal(data, &jsonMap); err != nil {
		return err
	}

	if rawVal, ok := jsonMap["split_val"]; ok {
		if err := json.Unmarshal(rawVal, &d.SplitVal); err != nil {
			return fmt.Errorf("trees: decoding split_val: %v", err)
		}
	}

	rawAttr, ok := jsonMap["split_attribute"]
	if !ok {
		return fmt.Errorf("trees: DecisionTreeRule JSON is missing split_attribute")
	}
	attr, err := base.DeserializeAttribute(rawAttr)
	if err != nil {
		return err
	}
	if attr == nil {
		return fmt.Errorf("base.DeserializeAttribute returned nil")
	}
	d.SplitAttr = attr
	return nil
}
func ( d * DecisionTreeRule ) UnmarshalJSON ( data [ ] byte ) error {
ret := d . unmarshalJSON ( data )
return ret
2014-10-26 12:07:38 +00:00
}
// String prints a human-readable summary of this thing.
func ( d * DecisionTreeRule ) String ( ) string {
2017-08-07 17:26:11 +01:00
if ( d . SplitAttr == nil ) {
return fmt . Sprintf ( "INVALID:DecisionTreeRule(SplitAttr is nil)" )
}
2014-10-26 12:07:38 +00:00
if _ , ok := d . SplitAttr . ( * base . FloatAttribute ) ; ok {
return fmt . Sprintf ( "DecisionTreeRule(%s <= %f)" , d . SplitAttr . GetName ( ) , d . SplitVal )
}
return fmt . Sprintf ( "DecisionTreeRule(%s)" , d . SplitAttr . GetName ( ) )
}
// DecisionTreeNode represents a given portion of a decision tree.
2014-05-14 14:00:22 +01:00
type DecisionTreeNode struct {
2017-08-07 17:26:11 +01:00
Type NodeType ` json:"node_type" `
Children map [ string ] * DecisionTreeNode ` json:"children" `
ClassDist map [ string ] int ` json:"class_dist" `
Class string ` json:"class_string" `
ClassAttr base . Attribute ` json:"class_attribute" `
SplitRule * DecisionTreeRule ` json:"decision_tree_rule" `
2014-08-02 16:22:15 +01:00
}
func getClassAttr ( from base . FixedDataGrid ) base . Attribute {
allClassAttrs := from . AllClassAttributes ( )
return allClassAttrs [ 0 ]
2014-05-14 14:00:22 +01:00
}
2017-08-07 17:26:11 +01:00
// MarshalJSON returns a JSON representation of this node (and, recursively,
// its children) for serialisation.
//
// Keys written: "type", "class_dist", "class", "class_attr", plus
// "split_rule" when the node has a rule with a SplitAttr and "children"
// when it has at least one child. Nested encodings (rule, class Attribute,
// children) are embedded as json.RawMessage so each subtree is encoded
// exactly once instead of being decoded and re-encoded at every level.
// Failures are returned as errors rather than panics.
func (d *DecisionTreeNode) MarshalJSON() ([]byte, error) {
	if d.ClassAttr == nil {
		return nil, fmt.Errorf("trees: cannot marshal DecisionTreeNode with nil ClassAttr")
	}

	ret := map[string]interface{}{
		"type":       d.Type,
		"class_dist": d.ClassDist,
		"class":      d.Class,
	}

	// Leaves carry a rule with a nil SplitAttr; only real rules are written.
	if d.SplitRule != nil && d.SplitRule.SplitAttr != nil {
		rawRule, err := d.SplitRule.MarshalJSON()
		if err != nil {
			return nil, err
		}
		ret["split_rule"] = json.RawMessage(rawRule)
	}

	rawClassAttr, err := d.ClassAttr.MarshalJSON()
	if err != nil {
		return nil, err
	}
	ret["class_attr"] = json.RawMessage(rawClassAttr)

	if len(d.Children) > 0 {
		children := make(map[string]json.RawMessage, len(d.Children))
		for key, child := range d.Children {
			if child == nil {
				return nil, fmt.Errorf("trees: cannot marshal nil child %q", key)
			}
			rawChild, err := child.MarshalJSON()
			if err != nil {
				return nil, err
			}
			children[key] = rawChild
		}
		ret["children"] = children
	}

	return json.Marshal(ret)
}
// UnmarshalJSON rebuilds a DecisionTreeNode (and, recursively, its children)
// from the representation written by MarshalJSON.
//
// Required keys: "type" (1 = LeafNode, 2 = RuleNode), "class_dist", "class"
// and "class_attr". When "split_rule" is present it is decoded. "children"
// is optional: a node whose children were removed by Prune keeps its rule
// but has no "children" key, and decodes to a node with nil Children (which
// Predict treats as a leaf). Malformed input yields an error, never a panic.
func (d *DecisionTreeNode) UnmarshalJSON(data []byte) error {
	var jsonMap map[string]json.RawMessage
	if err := json.Unmarshal(data, &jsonMap); err != nil {
		return err
	}

	// Node type is stored as a JSON number.
	rawTypeJSON, err := rawDecisionTreeNodeField(jsonMap, "type")
	if err != nil {
		return err
	}
	var rawType float64
	if err := json.Unmarshal(rawTypeJSON, &rawType); err != nil {
		return err
	}
	switch nodeType := NodeType(rawType); nodeType {
	case LeafNode, RuleNode:
		d.Type = nodeType
	default:
		return fmt.Errorf("Unknown nodeType: %d", int(rawType))
	}

	// Convert the class distribution back (counts are JSON numbers).
	rawDist, err := rawDecisionTreeNodeField(jsonMap, "class_dist")
	if err != nil {
		return err
	}
	var classDist map[string]float64
	if err := json.Unmarshal(rawDist, &classDist); err != nil {
		return err
	}
	d.ClassDist = make(map[string]int, len(classDist))
	for class, count := range classDist {
		d.ClassDist[class] = int(count)
	}

	rawClass, err := rawDecisionTreeNodeField(jsonMap, "class")
	if err != nil {
		return err
	}
	if err := json.Unmarshal(rawClass, &d.Class); err != nil {
		return err
	}

	// Decode the class attribute.
	rawClassAttr, err := rawDecisionTreeNodeField(jsonMap, "class_attr")
	if err != nil {
		return err
	}
	classAttr, err := base.DeserializeAttribute(rawClassAttr)
	if err != nil {
		return err
	}
	d.ClassAttr = classAttr

	// Reset structural fields so the node reflects only what the JSON holds.
	d.SplitRule = nil
	d.Children = nil

	rawRule, hasRule := jsonMap["split_rule"]
	if !hasRule {
		return nil
	}
	rule := &DecisionTreeRule{}
	if err := rule.UnmarshalJSON(rawRule); err != nil {
		return err
	}
	d.SplitRule = rule

	rawChildren, hasChildren := jsonMap["children"]
	if !hasChildren {
		// Pruned (or split-less) rule node: keep Children nil.
		return nil
	}
	var childMap map[string]json.RawMessage
	if err := json.Unmarshal(rawChildren, &childMap); err != nil {
		return err
	}
	if len(childMap) == 0 {
		return nil
	}
	d.Children = make(map[string]*DecisionTreeNode, len(childMap))
	for key, rawChild := range childMap {
		child := &DecisionTreeNode{}
		if err := child.UnmarshalJSON(rawChild); err != nil {
			return err
		}
		d.Children[key] = child
	}
	return nil
}

// rawDecisionTreeNodeField returns the raw JSON stored under key, or an
// error naming the key when it is absent.
func rawDecisionTreeNodeField(fields map[string]json.RawMessage, key string) (json.RawMessage, error) {
	raw, ok := fields[key]
	if !ok {
		return nil, fmt.Errorf("trees: DecisionTreeNode JSON is missing field %q", key)
	}
	return raw, nil
}
// Save sends the classification tree to an output file
func ( d * DecisionTreeNode ) Save ( filePath string ) error {
metadata := base . ClassifierMetadataV1 {
FormatVersion : 1 ,
ClassifierName : "test" ,
ClassifierVersion : "1" ,
ClassifierMetadata : nil ,
}
serializer , err := base . CreateSerializedClassifierStub ( filePath , metadata )
if err != nil {
return err
}
2017-08-08 12:37:57 +01:00
err = d . SaveWithPrefix ( serializer , "" )
serializer . Close ( )
return err
}
func ( d * DecisionTreeNode ) SaveWithPrefix ( serializer * base . ClassifierSerializer , prefix string ) error {
2017-08-07 17:26:11 +01:00
b , err := json . Marshal ( d )
if err != nil {
return err
}
2017-08-08 12:37:57 +01:00
err = serializer . WriteBytesForKey ( fmt . Sprintf ( "%s%s" , prefix , "tree" ) , b )
2017-08-07 17:26:11 +01:00
if err != nil {
return err
}
return nil
}
// Load reads from the classifier from an output file
func ( d * DecisionTreeNode ) Load ( filePath string ) error {
reader , err := base . ReadSerializedClassifierStub ( filePath )
if err != nil {
return err
}
2017-08-08 12:37:57 +01:00
err = d . LoadWithPrefix ( reader , "" )
reader . Close ( )
return err
}
2017-08-07 17:26:11 +01:00
2017-08-08 12:37:57 +01:00
func ( d * DecisionTreeNode ) LoadWithPrefix ( reader * base . ClassifierDeserializer , prefix string ) error {
b , err := reader . GetBytesForKey ( fmt . Sprintf ( "%s%s" , prefix , "tree" ) )
2017-08-07 17:26:11 +01:00
if err != nil {
return err
}
err = json . Unmarshal ( b , d )
if err != nil {
return err
}
return nil
}
2014-05-17 17:28:51 +01:00
// InferID3Tree builds a decision tree using a RuleGenerator
// from a set of Instances (implements the ID3 algorithm)
2014-08-02 16:22:15 +01:00
func InferID3Tree ( from base . FixedDataGrid , with RuleGenerator ) * DecisionTreeNode {
2014-05-14 14:00:22 +01:00
// Count the number of classes at this node
2014-08-02 16:22:15 +01:00
classes := base . GetClassDistribution ( from )
2014-05-14 14:00:22 +01:00
// If there's only one class, return a DecisionTreeLeaf with
// the only class available
if len ( classes ) == 1 {
maxClass := ""
for i := range classes {
maxClass = i
}
ret := & DecisionTreeNode {
LeafNode ,
nil ,
classes ,
maxClass ,
2014-08-02 16:22:15 +01:00
getClassAttr ( from ) ,
2014-10-26 12:07:38 +00:00
& DecisionTreeRule { nil , 0.0 } ,
2014-05-14 14:00:22 +01:00
}
return ret
}
// Only have the class attribute
maxVal := 0
maxClass := ""
for i := range classes {
if classes [ i ] > maxVal {
maxClass = i
maxVal = classes [ i ]
}
}
2014-05-17 20:37:19 +01:00
// If there are no more Attributes left to split on,
2014-05-14 14:00:22 +01:00
// return a DecisionTreeLeaf with the majority class
2014-08-02 16:22:15 +01:00
cols , _ := from . Size ( )
if cols == 2 {
2014-05-14 14:00:22 +01:00
ret := & DecisionTreeNode {
LeafNode ,
nil ,
classes ,
maxClass ,
2014-08-02 16:22:15 +01:00
getClassAttr ( from ) ,
2014-10-26 12:07:38 +00:00
& DecisionTreeRule { nil , 0.0 } ,
2014-05-14 14:00:22 +01:00
}
return ret
}
2014-08-02 16:22:15 +01:00
// Generate a return structure
2014-05-14 14:00:22 +01:00
ret := & DecisionTreeNode {
RuleNode ,
2014-05-17 20:37:19 +01:00
nil ,
2014-05-14 14:00:22 +01:00
classes ,
maxClass ,
2014-08-02 16:22:15 +01:00
getClassAttr ( from ) ,
2014-10-26 12:07:38 +00:00
nil ,
2014-05-14 14:00:22 +01:00
}
2014-10-26 12:07:38 +00:00
// Generate the splitting rule
splitRule := with . GenerateSplitRule ( from )
2016-06-28 14:36:48 -04:00
if splitRule == nil || splitRule . SplitAttr == nil {
2014-05-17 20:37:19 +01:00
// Can't determine, just return what we have
return ret
}
2014-10-26 12:07:38 +00:00
2014-05-14 14:00:22 +01:00
// Split the attributes based on this attribute's value
2014-10-26 12:07:38 +00:00
var splitInstances map [ string ] base . FixedDataGrid
if _ , ok := splitRule . SplitAttr . ( * base . FloatAttribute ) ; ok {
splitInstances = base . DecomposeOnNumericAttributeThreshold ( from ,
splitRule . SplitAttr , splitRule . SplitVal )
} else {
splitInstances = base . DecomposeOnAttributeValues ( from , splitRule . SplitAttr )
}
2014-05-14 14:00:22 +01:00
// Create new children from these attributes
2014-05-17 20:37:19 +01:00
ret . Children = make ( map [ string ] * DecisionTreeNode )
2014-05-14 14:00:22 +01:00
for k := range splitInstances {
newInstances := splitInstances [ k ]
2014-05-17 17:28:51 +01:00
ret . Children [ k ] = InferID3Tree ( newInstances , with )
2014-05-14 14:00:22 +01:00
}
2014-10-26 12:07:38 +00:00
ret . SplitRule = splitRule
2014-05-14 14:00:22 +01:00
return ret
}
2014-05-19 12:59:11 +01:00
// getNestedString returns the contents of node d
// prefixed by level number of tags (also prints children)
2014-05-17 17:28:51 +01:00
func ( d * DecisionTreeNode ) getNestedString ( level int ) string {
buf := bytes . NewBuffer ( nil )
tmp := bytes . NewBuffer ( nil )
for i := 0 ; i < level ; i ++ {
tmp . WriteString ( "\t" )
}
buf . WriteString ( tmp . String ( ) )
if d . Children == nil {
buf . WriteString ( fmt . Sprintf ( "Leaf(%s)" , d . Class ) )
} else {
2014-10-26 12:07:38 +00:00
var keys [ ] string
buf . WriteString ( fmt . Sprintf ( "Rule(%s)" , d . SplitRule ) )
2014-05-14 14:00:22 +01:00
for k := range d . Children {
2014-05-17 20:37:19 +01:00
keys = append ( keys , k )
}
sort . Strings ( keys )
for _ , k := range keys {
2014-05-17 17:28:51 +01:00
buf . WriteString ( "\n" )
buf . WriteString ( tmp . String ( ) )
buf . WriteString ( "\t" )
buf . WriteString ( k )
buf . WriteString ( "\n" )
buf . WriteString ( d . Children [ k ] . getNestedString ( level + 1 ) )
2014-05-14 14:00:22 +01:00
}
}
2014-05-17 17:28:51 +01:00
return buf . String ( )
}
2014-05-14 14:00:22 +01:00
2014-05-17 17:28:51 +01:00
// String returns a human-readable representation of a given node
// and it's children
func ( d * DecisionTreeNode ) String ( ) string {
return d . getNestedString ( 0 )
2014-05-14 14:00:22 +01:00
}
2014-05-19 12:59:11 +01:00
// computeAccuracy is a helper method for Prune()
2014-08-02 16:22:15 +01:00
func computeAccuracy ( predictions base . FixedDataGrid , from base . FixedDataGrid ) float64 {
2014-08-22 09:33:42 +00:00
cf , _ := evaluation . GetConfusionMatrix ( from , predictions )
return evaluation . GetAccuracy ( cf )
2014-05-17 18:06:01 +01:00
}
// Prune eliminates branches which hurt accuracy
2014-08-02 16:22:15 +01:00
func ( d * DecisionTreeNode ) Prune ( using base . FixedDataGrid ) {
2014-05-17 18:06:01 +01:00
// If you're a leaf, you're already pruned
if d . Children == nil {
return
2014-07-18 13:20:46 +03:00
}
2014-10-26 12:07:38 +00:00
if d . SplitRule == nil {
2014-07-18 13:20:46 +03:00
return
}
// Recursively prune children of this node
2014-10-26 12:07:38 +00:00
sub := base . DecomposeOnAttributeValues ( using , d . SplitRule . SplitAttr )
2014-07-18 13:20:46 +03:00
for k := range d . Children {
if sub [ k ] == nil {
continue
2014-05-17 18:06:01 +01:00
}
2014-08-02 16:22:15 +01:00
subH , subV := sub [ k ] . Size ( )
if subH == 0 || subV == 0 {
continue
}
2014-07-18 13:20:46 +03:00
d . Children [ k ] . Prune ( sub [ k ] )
2014-05-17 18:06:01 +01:00
}
// Get a baseline accuracy
2014-08-20 07:16:11 +00:00
predictions , _ := d . Predict ( using )
baselineAccuracy := computeAccuracy ( predictions , using )
2014-05-17 18:06:01 +01:00
// Speculatively remove the children and re-evaluate
tmpChildren := d . Children
d . Children = nil
2014-08-20 07:16:11 +00:00
predictions , _ = d . Predict ( using )
newAccuracy := computeAccuracy ( predictions , using )
2014-05-17 18:06:01 +01:00
// Keep the children removed if better, else restore
if newAccuracy < baselineAccuracy {
d . Children = tmpChildren
}
}
// Predict outputs a base.Instances containing predictions from this tree
2014-08-20 07:16:11 +00:00
func ( d * DecisionTreeNode ) Predict ( what base . FixedDataGrid ) ( base . FixedDataGrid , error ) {
2014-08-02 16:22:15 +01:00
predictions := base . GeneratePredictionVector ( what )
classAttr := getClassAttr ( predictions )
classAttrSpec , err := predictions . GetAttribute ( classAttr )
if err != nil {
panic ( err )
}
predAttrs := base . AttributeDifferenceReferences ( what . AllAttributes ( ) , predictions . AllClassAttributes ( ) )
2014-08-03 12:31:26 +01:00
predAttrSpecs := base . ResolveAttributes ( what , predAttrs )
2014-08-02 16:22:15 +01:00
what . MapOverRows ( predAttrSpecs , func ( row [ ] [ ] byte , rowNo int ) ( bool , error ) {
2014-05-14 14:00:22 +01:00
cur := d
2014-05-17 20:37:19 +01:00
for {
2014-05-14 14:00:22 +01:00
if cur . Children == nil {
2014-08-02 16:22:15 +01:00
predictions . Set ( classAttrSpec , rowNo , classAttr . GetSysValFromString ( cur . Class ) )
2014-05-17 20:37:19 +01:00
break
2014-05-14 14:00:22 +01:00
} else {
2014-10-26 12:07:38 +00:00
splitVal := cur . SplitRule . SplitVal
at := cur . SplitRule . SplitAttr
2014-08-02 16:22:15 +01:00
ats , err := what . GetAttribute ( at )
if err != nil {
2014-10-26 12:07:38 +00:00
//predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
//break
panic ( err )
2014-05-18 11:49:35 +01:00
}
2014-08-02 16:22:15 +01:00
2014-10-26 12:07:38 +00:00
var classVar string
if _ , ok := ats . GetAttribute ( ) . ( * base . FloatAttribute ) ; ok {
// If it's a numeric Attribute (e.g. FloatAttribute) check that
// the value of the current node is greater than the old one
classVal := base . UnpackBytesToFloat ( what . Get ( ats , rowNo ) )
if classVal > splitVal {
classVar = "1"
} else {
classVar = "0"
}
} else {
classVar = ats . GetAttribute ( ) . GetStringFromSysVal ( what . Get ( ats , rowNo ) )
}
2014-05-14 14:00:22 +01:00
if next , ok := cur . Children [ classVar ] ; ok {
cur = next
} else {
2014-10-26 12:07:38 +00:00
// Suspicious of this
2014-05-17 20:37:19 +01:00
var bestChild string
for c := range cur . Children {
bestChild = c
if c > classVar {
break
}
}
cur = cur . Children [ bestChild ]
2014-05-14 14:00:22 +01:00
}
}
}
2014-08-02 16:22:15 +01:00
return true , nil
} )
2014-08-20 07:16:11 +00:00
return predictions , nil
2014-05-14 14:00:22 +01:00
}
2014-05-17 21:45:26 +01:00
2017-07-17 14:48:38 +03:00
type ClassProba struct {
2017-07-17 15:35:35 +03:00
Probability float64
ClassValue string
2017-07-17 14:48:38 +03:00
}
type ClassesProba [ ] ClassProba
func ( o ClassesProba ) Len ( ) int {
return len ( o )
}
func ( o ClassesProba ) Swap ( i , j int ) {
o [ i ] , o [ j ] = o [ j ] , o [ i ]
}
func ( o ClassesProba ) Less ( i , j int ) bool {
2017-07-17 16:01:49 +03:00
return o [ i ] . Probability > o [ j ] . Probability
2017-07-17 14:48:38 +03:00
}
// Predict class probabilities of the input samples what, returns a sorted array (by probability) of classes, and another array representing it's probabilities
func ( t * ID3DecisionTree ) PredictProba ( what base . FixedDataGrid ) ( ClassesProba , error ) {
d := t . Root
predictions := base . GeneratePredictionVector ( what )
predAttrs := base . AttributeDifferenceReferences ( what . AllAttributes ( ) , predictions . AllClassAttributes ( ) )
predAttrSpecs := base . ResolveAttributes ( what , predAttrs )
2017-07-17 15:35:35 +03:00
_ , rowCount := what . Size ( )
if rowCount > 1 {
panic ( "PredictProba supports only 1 row predictions" )
}
2017-07-17 14:48:38 +03:00
var results ClassesProba
what . MapOverRows ( predAttrSpecs , func ( row [ ] [ ] byte , rowNo int ) ( bool , error ) {
cur := d
for {
if cur . Children == nil {
totalDist := 0
for _ , dist := range cur . ClassDist {
totalDist += dist
}
for class , dist := range cur . ClassDist {
2017-07-17 15:35:35 +03:00
classProba := ClassProba { ClassValue : class , Probability : float64 ( float64 ( dist ) / float64 ( totalDist ) ) }
2017-07-17 14:48:38 +03:00
results = append ( results , classProba )
}
sort . Sort ( results )
break
} else {
splitVal := cur . SplitRule . SplitVal
at := cur . SplitRule . SplitAttr
ats , err := what . GetAttribute ( at )
if err != nil {
//predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
//break
panic ( err )
}
var classVar string
if _ , ok := ats . GetAttribute ( ) . ( * base . FloatAttribute ) ; ok {
// If it's a numeric Attribute (e.g. FloatAttribute) check that
// the value of the current node is greater than the old one
classVal := base . UnpackBytesToFloat ( what . Get ( ats , rowNo ) )
if classVal > splitVal {
classVar = "1"
} else {
classVar = "0"
}
} else {
classVar = ats . GetAttribute ( ) . GetStringFromSysVal ( what . Get ( ats , rowNo ) )
}
if next , ok := cur . Children [ classVar ] ; ok {
cur = next
} else {
// Suspicious of this
var bestChild string
for c := range cur . Children {
bestChild = c
if c > classVar {
break
}
}
cur = cur . Children [ bestChild ]
}
}
}
return true , nil
} )
return results , nil
}
2014-05-17 21:45:26 +01:00
//
// ID3 Tree type
//
// ID3DecisionTree represents an ID3-based decision tree
// using the Information Gain metric to select which attributes
// to split on at each node.
type ID3DecisionTree struct {
base . BaseClassifier
Root * DecisionTreeNode
PruneSplit float64
2014-10-26 12:07:38 +00:00
Rule RuleGenerator
2014-05-17 21:45:26 +01:00
}
2014-07-18 13:48:28 +03:00
// NewID3DecisionTree returns a new ID3DecisionTree with the specified test-prune
2014-10-26 12:07:38 +00:00
// ratio and InformationGain as the rule generator.
// If the ratio is less than 0.001, the tree isn't pruned.
2014-05-17 21:45:26 +01:00
func NewID3DecisionTree ( prune float64 ) * ID3DecisionTree {
return & ID3DecisionTree {
base . BaseClassifier { } ,
nil ,
prune ,
2014-10-26 12:07:38 +00:00
new ( InformationGainRuleGenerator ) ,
}
}
// NewID3DecisionTreeFromRule returns a new ID3DecisionTree with the specified test-prun
// ratio and the given rule gnereator.
func NewID3DecisionTreeFromRule ( prune float64 , rule RuleGenerator ) * ID3DecisionTree {
return & ID3DecisionTree {
base . BaseClassifier { } ,
nil ,
prune ,
rule ,
2014-05-17 21:45:26 +01:00
}
}
// Fit builds the ID3 decision tree
2014-08-20 07:16:11 +00:00
func ( t * ID3DecisionTree ) Fit ( on base . FixedDataGrid ) error {
2014-05-17 21:45:26 +01:00
if t . PruneSplit > 0.001 {
2014-06-06 20:30:24 +02:00
trainData , testData := base . InstancesTrainTestSplit ( on , t . PruneSplit )
2014-10-26 12:07:38 +00:00
t . Root = InferID3Tree ( trainData , t . Rule )
2014-06-06 20:30:24 +02:00
t . Root . Prune ( testData )
2014-05-17 21:45:26 +01:00
} else {
2014-10-26 12:07:38 +00:00
t . Root = InferID3Tree ( on , t . Rule )
2014-05-17 21:45:26 +01:00
}
2014-08-20 07:16:11 +00:00
return nil
2014-05-17 21:45:26 +01:00
}
// Predict outputs predictions from the ID3 decision tree
2014-08-20 07:16:11 +00:00
func ( t * ID3DecisionTree ) Predict ( what base . FixedDataGrid ) ( base . FixedDataGrid , error ) {
2014-05-17 21:45:26 +01:00
return t . Root . Predict ( what )
}
2014-05-19 12:59:11 +01:00
// String returns a human-readable version of this ID3 tree
2014-05-17 21:45:26 +01:00
func ( t * ID3DecisionTree ) String ( ) string {
return fmt . Sprintf ( "ID3DecisionTree(%s\n)" , t . Root )
}