package trees

import (
	"fmt"
	"math"
	"sort"
	"strconv"
	"strings"

	"github.com/sjwhitworth/golearn/base"
)

const (
	GINI    string = "gini"
	ENTROPY string = "entropy"
)

// classifierNode is the node struct for the Decision Tree Classifier.
// It holds the information for each split: which feature to compare, the
// threshold to compare it against, and the label to assign on each side
// of the split.
type classifierNode struct {
	Left         *classifierNode
	Right        *classifierNode
	Threshold    float64
	Feature      int64
	LeftLabel    int64
	RightLabel   int64
	isNodeNeeded bool
}

// CARTDecisionTreeClassifier is the tree struct for the Decision Tree
// Classifier. It contains the root node and the hyperparameters chosen by
// the user, and it keeps track of all splits tried at the tree level.
type CARTDecisionTreeClassifier struct {
	RootNode    *classifierNode
	criterion   string
	maxDepth    int64
	labels      []int64
	triedSplits [][]float64
}

// computeGiniImpurityAndModeLabel calculates the Gini impurity of the
// target labels and returns it along with the most frequent (mode) label.
func computeGiniImpurityAndModeLabel(y []int64, labels []int64) (float64, int64) {
	nInstances := len(y)
	gini := 0.0
	maxLabelCount := 0
	var maxLabel int64 = 0
	for label := range labels {
		numLabel := 0
		for target := range y {
			if y[target] == labels[label] {
				numLabel++
			}
		}
		p := float64(numLabel) / float64(nInstances)
		gini += p * (1 - p)
		if numLabel > maxLabelCount {
			maxLabel = labels[label]
			maxLabelCount = numLabel
		}
	}
	return gini, maxLabel
}
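
// A minimal worked example of the Gini computation above (an illustrative
// helper, not part of the package API): with y = {0, 0, 1, 1, 1} and
// labels {0, 1}, p(0) = 0.4 and p(1) = 0.6, so the impurity is
// 0.4*0.6 + 0.6*0.4 = 0.48 and the mode label is 1.
func exampleGiniImpurity() {
	y := []int64{0, 0, 1, 1, 1}
	gini, mode := computeGiniImpurityAndModeLabel(y, []int64{0, 1})
	fmt.Printf("gini=%.2f mode=%d\n", gini, mode) // gini=0.48 mode=1
}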

// computeEntropyAndModeLabel calculates the entropy of the target labels
// and returns it along with the most frequent (mode) label.
func computeEntropyAndModeLabel(y []int64, labels []int64) (float64, int64) {
	nInstances := len(y)
	entropy := 0.0
	maxLabelCount := 0
	var maxLabel int64 = 0
	for label := range labels {
		numLabel := 0
		for target := range y {
			if y[target] == labels[label] {
				numLabel++
			}
		}
		p := float64(numLabel) / float64(nInstances)

		// Guard against log2(0): a zero-probability class contributes
		// nothing to the entropy.
		logP := 0.0
		if p > 0 {
			logP = math.Log2(p)
		}
		entropy += -p * logP
		if numLabel > maxLabelCount {
			maxLabel = labels[label]
			maxLabelCount = numLabel
		}
	}
	return entropy, maxLabel
}
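
// A minimal worked example of the entropy computation above (an
// illustrative helper, not part of the package API): with two equally
// frequent labels, p = 0.5 for each, so entropy = -2 * 0.5 * log2(0.5) =
// 1.0 bit. With equal counts the first label encountered stays the mode.
func exampleEntropy() {
	y := []int64{0, 0, 1, 1}
	entropy, mode := computeEntropyAndModeLabel(y, []int64{0, 1})
	fmt.Printf("entropy=%.1f mode=%d\n", entropy, mode) // entropy=1.0 mode=0
}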

// calculateClassificationLoss dispatches to the impurity function selected
// by the user-specified criterion.
func calculateClassificationLoss(y []int64, labels []int64, criterion string) (float64, int64) {
	if criterion == GINI {
		return computeGiniImpurityAndModeLabel(y, labels)
	} else if criterion == ENTROPY {
		return computeEntropyAndModeLabel(y, labels)
	} else {
		panic("Invalid impurity function, choose from GINI or ENTROPY")
	}
}

// classifierCreateSplit splits the data into left and right sides based on
// the given feature and threshold.
func classifierCreateSplit(data [][]float64, feature int64, y []int64, threshold float64) ([][]float64, [][]float64, []int64, []int64) {
	var left [][]float64
	var right [][]float64
	var lefty []int64
	var righty []int64

	for i := range data {
		example := data[i]
		if example[feature] < threshold {
			left = append(left, example)
			lefty = append(lefty, y[i])
		} else {
			right = append(right, example)
			righty = append(righty, y[i])
		}
	}

	return left, right, lefty, righty
}
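
// A quick sketch of the split above (an illustrative helper, not part of
// the package API): rows whose feature-0 value falls below the threshold
// go left, everything else goes right.
func exampleCreateSplit() {
	data := [][]float64{{1, 10}, {2, 20}, {3, 30}}
	y := []int64{0, 0, 1}
	left, right, lefty, righty := classifierCreateSplit(data, 0, y, 2.5)
	fmt.Println(left, lefty)   // [[1 10] [2 20]] [0 0]
	fmt.Println(right, righty) // [[3 30]] [1]
}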

// NewDecisionTreeClassifier returns a new Decision Tree Classifier.
// It assigns the hyperparameters chosen by the user to the tree attributes.
func NewDecisionTreeClassifier(criterion string, maxDepth int64, labels []int64) *CARTDecisionTreeClassifier {
	var tree CARTDecisionTreeClassifier
	tree.criterion = strings.ToLower(criterion)
	tree.maxDepth = maxDepth
	tree.labels = labels

	return &tree
}
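
// Typical construction, assuming binary class labels encoded as 0 and 1
// and unlimited depth (maxDepth = -1):
//
//	tree := NewDecisionTreeClassifier("gini", -1, []int64{0, 1})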

// classifierReOrderData reorders the data by the feature being considered.
// This optimizes the code by reducing the number of times we have to loop
// over the data when splitting.
func classifierReOrderData(featureVal []float64, data [][]float64, y []int64) ([][]float64, []int64) {
	s := NewSlice(featureVal)
	sort.Sort(s)

	indexes := s.Idx

	var dataSorted [][]float64
	var ySorted []int64

	for _, index := range indexes {
		dataSorted = append(dataSorted, data[index])
		ySorted = append(ySorted, y[index])
	}

	return dataSorted, ySorted
}
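
// Sketch of the reordering (assuming NewSlice, defined elsewhere in this
// package, sorts ascending and records the original indices in Idx):
//
//	featureVal := []float64{3, 1, 2}
//	data := [][]float64{{3}, {1}, {2}}
//	y := []int64{2, 0, 1}
//	dataSorted, ySorted := classifierReOrderData(featureVal, data, y)
//	// dataSorted = [[1] [2] [3]], ySorted = [0 1 2]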

// classifierUpdateSplit updates the left and right sides of the split when
// the threshold is raised, moving rows from the front of the right side to
// the left side. It relies on the data being sorted by the feature under
// consideration.
func classifierUpdateSplit(left [][]float64, lefty []int64, right [][]float64, righty []int64, feature int64, threshold float64) ([][]float64, []int64, [][]float64, []int64) {

	// Guard against an empty right side to avoid an index-out-of-range
	// panic; with midpoint thresholds the right side never fully drains.
	for len(right) > 0 && right[0][feature] < threshold {
		left = append(left, right[0])
		right = right[1:]
		lefty = append(lefty, righty[0])
		righty = righty[1:]
	}

	return left, lefty, right, righty
}
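
// Sketch of an incremental update (an illustrative helper, not part of the
// package API): raising the threshold from 1.5 to 2.5 on sorted data moves
// the row with feature value 2 from the right side to the left.
func exampleUpdateSplit() {
	left := [][]float64{{1}}
	right := [][]float64{{2}, {3}}
	lefty, righty := []int64{0}, []int64{0, 1}
	left, lefty, right, righty = classifierUpdateSplit(left, lefty, right, righty, 0, 2.5)
	fmt.Println(left, lefty, right, righty) // [[1] [2]] [0 0] [[3]] [1]
}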

// Fit creates an empty root node and trains the tree by calling the
// recursive function classifierBestSplit.
func (tree *CARTDecisionTreeClassifier) Fit(X base.FixedDataGrid) {
	var emptyNode classifierNode

	data := convertInstancesToProblemVec(X)
	y := classifierConvertInstancesToLabelVec(X)
	emptyNode = classifierBestSplit(*tree, data, y, tree.labels, emptyNode, tree.criterion, tree.maxDepth, 0)

	tree.RootNode = &emptyNode
}
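
// End-to-end usage sketch. The CSV path and label encoding are assumptions
// for illustration; base.ParseCSVToInstances and
// base.InstancesTrainTestSplit come from the golearn base package, and the
// class attribute must be numeric (see classifierConvertInstancesToLabelVec):
//
//	rawData, err := base.ParseCSVToInstances("dataset.csv", true)
//	if err != nil {
//		panic(err)
//	}
//	trainData, testData := base.InstancesTrainTestSplit(rawData, 0.5)
//	tree := NewDecisionTreeClassifier("gini", -1, []int64{0, 1})
//	tree.Fit(trainData)
//	fmt.Println(tree.Evaluate(testData))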

// classifierBestSplit iteratively finds and records the best split.
// It stops if the depth reaches maxDepth or the nodes are pure.
func classifierBestSplit(tree CARTDecisionTreeClassifier, data [][]float64, y []int64, labels []int64, upperNode classifierNode, criterion string, maxDepth int64, depth int64) classifierNode {

	// Ensure that we have not reached maxDepth. maxDepth = -1 means split
	// until the nodes are pure.
	depth++

	if maxDepth != -1 && depth > maxDepth {
		return upperNode
	}

	numFeatures := len(data[0])
	var bestGini, origGini float64

	// Calculate the loss based on the criterion specified by the user.
	origGini, upperNode.LeftLabel = calculateClassificationLoss(y, labels, criterion)

	bestGini = origGini

	bestLeft, bestRight, bestLefty, bestRighty := data, data, y, y

	numData := len(data)

	bestLeftGini, bestRightGini := bestGini, bestGini

	upperNode.isNodeNeeded = true

	var leftN, rightN classifierNode

	// Iterate over all features.
	for i := 0; i < numFeatures; i++ {

		featureVal := getFeature(data, int64(i))
		unique := findUnique(featureVal)
		sort.Float64s(unique)

		sortData, sortY := classifierReOrderData(featureVal, data, y)

		firstTime := true

		var left, right [][]float64
		var lefty, righty []int64
		// Iterate over all candidate thresholds for this feature.
		for j := 0; j < len(unique)-1; j++ {

			threshold := (unique[j] + unique[j+1]) / 2
			// Ensure that the same split has not been made before.
			if validate(tree.triedSplits, int64(i), threshold) {
				// We need to split the data from fresh when considering a
				// new feature for the first time. Otherwise, we update the
				// split by moving data points from left to right.
				if firstTime {
					left, right, lefty, righty = classifierCreateSplit(sortData, int64(i), sortY, threshold)
					firstTime = false
				} else {
					left, lefty, right, righty = classifierUpdateSplit(left, lefty, right, righty, int64(i), threshold)
				}

				var leftGini, rightGini float64
				var leftLabels, rightLabels int64

				leftGini, leftLabels = calculateClassificationLoss(lefty, labels, criterion)
				rightGini, rightLabels = calculateClassificationLoss(righty, labels, criterion)

				// Calculate the weighted impurity of the child nodes.
				subGini := (leftGini * float64(len(left)) / float64(numData)) + (rightGini * float64(len(right)) / float64(numData))
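				// For example, with 10 rows split 4 left at impurity 0.5
				// and 6 right at impurity 0.0, the weighted impurity is
				// subGini = 0.5*0.4 + 0.0*0.6 = 0.2.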

				// If we find a split that reduces the impurity, record it.
				if subGini < bestGini {
					bestGini = subGini

					bestLeft, bestRight = left, right

					bestLefty, bestRighty = lefty, righty

					upperNode.Threshold, upperNode.Feature = threshold, int64(i)

					upperNode.LeftLabel, upperNode.RightLabel = leftLabels, rightLabels

					bestLeftGini, bestRightGini = leftGini, rightGini
				}
			}
		}
	}
	// If no split was found, we don't want to use this node, so we flag it.
	if bestGini == origGini {
		upperNode.isNodeNeeded = false
		return upperNode
	}
	// Recurse further while the node is still impure.
	if bestGini > 0 {

		// Split the left side further only if it is still impure.
		if bestLeftGini > 0 {
			tree.triedSplits = append(tree.triedSplits, []float64{float64(upperNode.Feature), upperNode.Threshold})
			// Recursive splitting logic.
			leftN = classifierBestSplit(tree, bestLeft, bestLefty, labels, leftN, criterion, maxDepth, depth)
			if leftN.isNodeNeeded {
				upperNode.Left = &leftN
			}

		}
		// Split the right side further only if it is still impure.
		if bestRightGini > 0 {
			tree.triedSplits = append(tree.triedSplits, []float64{float64(upperNode.Feature), upperNode.Threshold})
			// Recursive splitting logic.
			rightN = classifierBestSplit(tree, bestRight, bestRighty, labels, rightN, criterion, maxDepth, depth)
			if rightN.isNodeNeeded {
				upperNode.Right = &rightN
			}

		}

	}
	// Return the node; it contains all information regarding the feature
	// and threshold.
	return upperNode
}

// String prints out the entire tree for visualization by calling the
// recursive function classifierPrintTreeFromNode.
func (tree *CARTDecisionTreeClassifier) String() string {
	rootNode := *tree.RootNode
	return classifierPrintTreeFromNode(rootNode, "")
}

// classifierPrintTreeFromNode recursively renders the tree rooted at the
// given node, one split per line.
func classifierPrintTreeFromNode(tree classifierNode, spacing string) string {
	returnString := ""
	returnString += spacing + "Feature "
	returnString += strconv.FormatInt(tree.Feature, 10)
	returnString += " < "
	returnString += fmt.Sprintf("%.3f", tree.Threshold)
	returnString += "\n"

	if tree.Left == nil {
		returnString += spacing + "---> True" + "\n"
		returnString += "  " + spacing + "PREDICT "
		returnString += strconv.FormatInt(tree.LeftLabel, 10) + "\n"
	}
	if tree.Right == nil {
		returnString += spacing + "---> False" + "\n"
		returnString += "  " + spacing + "PREDICT "
		returnString += strconv.FormatInt(tree.RightLabel, 10) + "\n"
	}

	if tree.Left != nil {
		returnString += spacing + "---> True" + "\n"
		returnString += classifierPrintTreeFromNode(*tree.Left, spacing+"  ")
	}

	if tree.Right != nil {
		returnString += spacing + "---> False" + "\n"
		returnString += classifierPrintTreeFromNode(*tree.Right, spacing+"  ")
	}

	return returnString
}
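
// For a single stump that splits feature 0 at 2.5, String() renders
// something like:
//
//	Feature 0 < 2.500
//	---> True
//	  PREDICT 0
//	---> False
//	  PREDICT 1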

// classifierPredictSingle predicts a single data point by traversing the
// tree recursively from the root.
func classifierPredictSingle(tree classifierNode, instance []float64) int64 {
	if instance[tree.Feature] < tree.Threshold {
		if tree.Left == nil {
			return tree.LeftLabel
		} else {
			return classifierPredictSingle(*tree.Left, instance)
		}
	} else {
		if tree.Right == nil {
			return tree.RightLabel
		} else {
			return classifierPredictSingle(*tree.Right, instance)
		}
	}
}
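
// Traversal sketch on a hand-built stump (an illustrative helper, not part
// of the package API): feature 0 is compared against the threshold and the
// matching side's label is returned.
func examplePredictSingle() {
	stump := classifierNode{Feature: 0, Threshold: 2.5, LeftLabel: 0, RightLabel: 1}
	fmt.Println(classifierPredictSingle(stump, []float64{1.0})) // 0
	fmt.Println(classifierPredictSingle(stump, []float64{3.0})) // 1
}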

// Predict returns predictions for every data point in the given test data
// by calling classifierPredictFromNode on the root node.
func (tree *CARTDecisionTreeClassifier) Predict(X_test base.FixedDataGrid) []int64 {
	root := *tree.RootNode
	test := convertInstancesToProblemVec(X_test)
	return classifierPredictFromNode(root, test)
}

// classifierPredictFromNode uses the root node passed in from Predict.
// It iterates over every data point, calls the recursive prediction
// function, and collects the results.
func classifierPredictFromNode(tree classifierNode, test [][]float64) []int64 {
	var preds []int64
	for i := range test {
		iPred := classifierPredictSingle(tree, test[i])
		preds = append(preds, iPred)
	}
	return preds
}

// Evaluate returns the accuracy of the classifier on the given labelled
// test data. It first retrieves predictions from the data and then compares
// them with the true labels by calling classifierEvaluateFromNode.
func (tree *CARTDecisionTreeClassifier) Evaluate(test base.FixedDataGrid) float64 {
	rootNode := *tree.RootNode
	xTest := convertInstancesToProblemVec(test)
	yTest := classifierConvertInstancesToLabelVec(test)
	return classifierEvaluateFromNode(rootNode, xTest, yTest)
}
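
// Evaluation usage sketch (testData is assumed to be a labelled
// base.FixedDataGrid prepared as in the Fit example above):
//
//	acc := tree.Evaluate(testData)
//	fmt.Printf("accuracy: %.2f\n", acc)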

// classifierEvaluateFromNode retrieves predictions and then calculates
// accuracy as the fraction of correctly predicted labels.
func classifierEvaluateFromNode(tree classifierNode, xTest [][]float64, yTest []int64) float64 {
	preds := classifierPredictFromNode(tree, xTest)
	accuracy := 0.0
	for i := range preds {
		if preds[i] == yTest[i] {
			accuracy++
		}
	}
	accuracy /= float64(len(yTest))
	return accuracy
}

// classifierConvertInstancesToLabelVec converts a base.FixedDataGrid into
// the label vector format required internally. Called in Fit and Evaluate.
func classifierConvertInstancesToLabelVec(X base.FixedDataGrid) []int64 {
	// Get the class Attributes
	classAttrs := X.AllClassAttributes()
	// Only support 1 class Attribute
	if len(classAttrs) != 1 {
		panic(fmt.Sprintf("%d ClassAttributes (1 expected)", len(classAttrs)))
	}
	// ClassAttribute must be numeric
	if _, ok := classAttrs[0].(*base.FloatAttribute); !ok {
		panic(fmt.Sprintf("%s: ClassAttribute must be a FloatAttribute", classAttrs[0]))
	}
	// Allocate return structure
	_, rows := X.Size()
	labelVec := make([]int64, rows)
	// Resolve class Attribute specification
	classAttrSpecs := base.ResolveAttributes(X, classAttrs)
	X.MapOverRows(classAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
		labelVec[rowNo] = int64(base.UnpackBytesToFloat(row[0]))
		return true, nil
	})
	return labelVec
}