mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
Adding Changes
This commit is contained in:
parent
08529c42cf
commit
c083759523
@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
@ -23,8 +24,8 @@ type CNode struct {
|
||||
maxDepth int64
|
||||
}
|
||||
|
||||
// CTree: Tree struct for Decision Tree Classifier
|
||||
type CTree struct {
|
||||
// CARTDecisionTreeClassifier: Tree struct for Decision Tree Classifier
|
||||
type CARTDecisionTreeClassifier struct {
|
||||
RootNode *CNode
|
||||
criterion string
|
||||
maxDepth int64
|
||||
@ -135,8 +136,8 @@ func cgetFeature(data [][]float64, feature int64) []float64 {
|
||||
}
|
||||
|
||||
// Function to Create New Decision Tree Classifier
|
||||
func NewDecisionTreeClassifier(criterion string, maxDepth int64, labels []int64) *CTree {
|
||||
var tree CTree
|
||||
func NewDecisionTreeClassifier(criterion string, maxDepth int64, labels []int64) *CARTDecisionTreeClassifier {
|
||||
var tree CARTDecisionTreeClassifier
|
||||
tree.criterion = strings.ToLower(criterion)
|
||||
tree.maxDepth = maxDepth
|
||||
tree.labels = labels
|
||||
@ -210,7 +211,7 @@ func cupdateSplit(left [][]float64, lefty []int64, right [][]float64, righty []i
|
||||
}
|
||||
|
||||
// Fit - Method visible to user to train tree
|
||||
func (tree *CTree) Fit(X base.FixedDataGrid) {
|
||||
func (tree *CARTDecisionTreeClassifier) Fit(X base.FixedDataGrid) {
|
||||
var emptyNode CNode
|
||||
|
||||
data := classifierConvertInstancesToProblemVec(X)
|
||||
@ -221,7 +222,7 @@ func (tree *CTree) Fit(X base.FixedDataGrid) {
|
||||
}
|
||||
|
||||
// Iteratively find and record the best split - recursive function
|
||||
func cbestSplit(tree CTree, data [][]float64, y []int64, labels []int64, upperNode CNode, criterion string, maxDepth int64, depth int64) CNode {
|
||||
func cbestSplit(tree CARTDecisionTreeClassifier, data [][]float64, y []int64, labels []int64, upperNode CNode, criterion string, maxDepth int64, depth int64) CNode {
|
||||
|
||||
// Ensure that we have not reached maxDepth. maxDepth =-1 means split until nodes are pure
|
||||
depth++
|
||||
@ -358,41 +359,43 @@ func cbestSplit(tree CTree, data [][]float64, y []int64, labels []int64, upperNo
|
||||
}
|
||||
|
||||
// PrintTree : this function prints out entire tree for visualization - visible to user
|
||||
func (tree *CTree) PrintTree() {
|
||||
func (tree *CARTDecisionTreeClassifier) String() string {
|
||||
rootNode := *tree.RootNode
|
||||
cprintTreeFromNode(rootNode, "")
|
||||
return cprintTreeFromNode(rootNode, "")
|
||||
}
|
||||
|
||||
// Tree struct has root node. That is used to print tree - invisible to user but called from PrintTree
|
||||
func cprintTreeFromNode(tree CNode, spacing string) float64 {
|
||||
|
||||
fmt.Print(spacing + "Feature ")
|
||||
fmt.Print(tree.Feature)
|
||||
fmt.Print(" < ")
|
||||
fmt.Println(tree.Threshold)
|
||||
func cprintTreeFromNode(tree CNode, spacing string) string {
|
||||
returnString := ""
|
||||
returnString += spacing + "Feature "
|
||||
returnString += strconv.FormatInt(tree.Feature, 10)
|
||||
returnString += " < "
|
||||
returnString += fmt.Sprintf("%.3f", tree.Threshold)
|
||||
returnString += "\n"
|
||||
|
||||
if tree.Left == nil {
|
||||
fmt.Println(spacing + "---> True")
|
||||
fmt.Print(" " + spacing + "PREDICT ")
|
||||
fmt.Println(tree.LeftLabel)
|
||||
returnString += spacing + "---> True" + "\n"
|
||||
returnString += " " + spacing + "PREDICT "
|
||||
returnString += strconv.FormatInt(tree.LeftLabel, 10) + "\n"
|
||||
|
||||
}
|
||||
if tree.Right == nil {
|
||||
fmt.Println(spacing + "---> FALSE")
|
||||
fmt.Print(" " + spacing + "PREDICT ")
|
||||
fmt.Println(tree.RightLabel)
|
||||
|
||||
returnString += spacing + "---> False" + "\n"
|
||||
returnString += " " + spacing + "PREDICT "
|
||||
returnString += strconv.FormatInt(tree.RightLabel, 10) + "\n"
|
||||
}
|
||||
|
||||
if tree.Left != nil {
|
||||
fmt.Println(spacing + "---> True")
|
||||
cprintTreeFromNode(*tree.Left, spacing+" ")
|
||||
returnString += spacing + "---> True" + "\n"
|
||||
returnString += cprintTreeFromNode(*tree.Left, spacing+" ")
|
||||
}
|
||||
|
||||
if tree.Right != nil {
|
||||
fmt.Println(spacing + "---> False")
|
||||
cprintTreeFromNode(*tree.Right, spacing+" ")
|
||||
returnString += spacing + "---> False" + "\n"
|
||||
returnString += cprintTreeFromNode(*tree.Right, spacing+" ")
|
||||
}
|
||||
|
||||
return 0.0
|
||||
return returnString
|
||||
}
|
||||
|
||||
// Predict a single data point by traversing the entire tree
|
||||
@ -413,7 +416,7 @@ func cpredictSingle(tree CNode, instance []float64) int64 {
|
||||
}
|
||||
|
||||
// Predict is visible to user. Given test data, they receive predictions for every datapoint.
|
||||
func (tree *CTree) Predict(X_test base.FixedDataGrid) []int64 {
|
||||
func (tree *CARTDecisionTreeClassifier) Predict(X_test base.FixedDataGrid) []int64 {
|
||||
root := *tree.RootNode
|
||||
test := classifierConvertInstancesToProblemVec(X_test)
|
||||
return cpredictFromNode(root, test)
|
||||
@ -430,7 +433,7 @@ func cpredictFromNode(tree CNode, test [][]float64) []int64 {
|
||||
}
|
||||
|
||||
// Given Test data and label, return the accuracy of the classifier. Data has to be in float slice format before feeding.
|
||||
func (tree *CTree) Evaluate(test base.FixedDataGrid) float64 {
|
||||
func (tree *CARTDecisionTreeClassifier) Evaluate(test base.FixedDataGrid) float64 {
|
||||
rootNode := *tree.RootNode
|
||||
xTest := classifierConvertInstancesToProblemVec(test)
|
||||
yTest := classifierConvertInstancesToLabelVec(test)
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
@ -22,8 +23,8 @@ type RNode struct {
|
||||
Use_not bool
|
||||
}
|
||||
|
||||
// RTree - Tree struct for Decision Tree Regressor
|
||||
type RTree struct {
|
||||
// CARTDecisionTreeRegressor - Tree struct for Decision Tree Regressor
|
||||
type CARTDecisionTreeRegressor struct {
|
||||
RootNode *RNode
|
||||
criterion string
|
||||
maxDepth int64
|
||||
@ -125,8 +126,8 @@ func rgetFeature(data [][]float64, feature int64) []float64 {
|
||||
}
|
||||
|
||||
// Interface for creating new Decision Tree Regressor - calls rbestSplit()
|
||||
func NewDecisionTreeRegressor(criterion string, maxDepth int64) *RTree {
|
||||
var tree RTree
|
||||
func NewDecisionTreeRegressor(criterion string, maxDepth int64) *CARTDecisionTreeRegressor {
|
||||
var tree CARTDecisionTreeRegressor
|
||||
tree.maxDepth = maxDepth
|
||||
tree.criterion = strings.ToLower(criterion)
|
||||
return &tree
|
||||
@ -198,7 +199,7 @@ func rupdateSplit(left [][]float64, lefty []float64, right [][]float64, righty [
|
||||
}
|
||||
|
||||
// Extra method for creating a simple-to-use interface. Many params are redundant for the user and are needed only for the recursive logic.
|
||||
func (tree *RTree) Fit(X base.FixedDataGrid) {
|
||||
func (tree *CARTDecisionTreeRegressor) Fit(X base.FixedDataGrid) {
|
||||
var emptyNode RNode
|
||||
data := regressorConvertInstancesToProblemVec(X)
|
||||
y := regressorConvertInstancesToLabelVec(X)
|
||||
@ -209,7 +210,7 @@ func (tree *RTree) Fit(X base.FixedDataGrid) {
|
||||
}
|
||||
|
||||
// Essentially the Fit Method - Implements recursive logic
|
||||
func rbestSplit(tree RTree, data [][]float64, y []float64, upperNode RNode, criterion string, maxDepth int64, depth int64) RNode {
|
||||
func rbestSplit(tree CARTDecisionTreeRegressor, data [][]float64, y []float64, upperNode RNode, criterion string, maxDepth int64, depth int64) RNode {
|
||||
|
||||
depth++
|
||||
|
||||
@ -334,72 +335,75 @@ func rbestSplit(tree RTree, data [][]float64, y []float64, upperNode RNode, crit
|
||||
}
|
||||
|
||||
// Print Tree for Visualization - calls printTreeFromNode()
|
||||
func (tree *RTree) PrintTree() {
|
||||
func (tree *CARTDecisionTreeRegressor) String() string {
|
||||
rootNode := *tree.RootNode
|
||||
printTreeFromNode(rootNode, "")
|
||||
return rprintTreeFromNode(rootNode, "")
|
||||
}
|
||||
|
||||
// Use tree's root node to print out entire tree
|
||||
func printTreeFromNode(tree RNode, spacing string) float64 {
|
||||
|
||||
fmt.Print(spacing + "Feature ")
|
||||
fmt.Print(tree.Feature)
|
||||
fmt.Print(" < ")
|
||||
fmt.Println(tree.Threshold)
|
||||
func rprintTreeFromNode(tree RNode, spacing string) string {
|
||||
returnString := ""
|
||||
returnString += spacing + "Feature "
|
||||
returnString += strconv.FormatInt(tree.Feature, 10)
|
||||
returnString += " < "
|
||||
returnString += fmt.Sprintf("%.3f", tree.Threshold)
|
||||
returnString += "\n"
|
||||
|
||||
if tree.Left == nil {
|
||||
fmt.Println(spacing + "---> True")
|
||||
fmt.Print(" " + spacing + "PREDICT ")
|
||||
fmt.Println(tree.LeftPred)
|
||||
returnString += spacing + "---> True" + "\n"
|
||||
returnString += " " + spacing + "PREDICT "
|
||||
returnString += fmt.Sprintf("%.3f", tree.LeftPred) + "\n"
|
||||
}
|
||||
if tree.Right == nil {
|
||||
fmt.Println(spacing + "---> FALSE")
|
||||
fmt.Print(" " + spacing + "PREDICT ")
|
||||
fmt.Println(tree.RightPred)
|
||||
|
||||
returnString += spacing + "---> False" + "\n"
|
||||
returnString += " " + spacing + "PREDICT "
|
||||
returnString += fmt.Sprintf("%.3f", tree.RightPred) + "\n"
|
||||
}
|
||||
|
||||
if tree.Left != nil {
|
||||
fmt.Println(spacing + "---> True")
|
||||
printTreeFromNode(*tree.Left, spacing+" ")
|
||||
// fmt.Println(spacing + "---> True")
|
||||
returnString += spacing + "---> True" + "\n"
|
||||
returnString += rprintTreeFromNode(*tree.Left, spacing+" ")
|
||||
}
|
||||
|
||||
if tree.Right != nil {
|
||||
fmt.Println(spacing + "---> False")
|
||||
printTreeFromNode(*tree.Right, spacing+" ")
|
||||
// fmt.Println(spacing + "---> False")
|
||||
returnString += spacing + "---> False" + "\n"
|
||||
returnString += rprintTreeFromNode(*tree.Right, spacing+" ")
|
||||
}
|
||||
|
||||
return 0.0
|
||||
return returnString
|
||||
}
|
||||
|
||||
// Predict a single data point
|
||||
func predictSingle(tree RNode, instance []float64) float64 {
|
||||
func rpredictSingle(tree RNode, instance []float64) float64 {
|
||||
if instance[tree.Feature] < tree.Threshold {
|
||||
if tree.Left == nil {
|
||||
return tree.LeftPred
|
||||
} else {
|
||||
return predictSingle(*tree.Left, instance)
|
||||
return rpredictSingle(*tree.Left, instance)
|
||||
}
|
||||
} else {
|
||||
if tree.Right == nil {
|
||||
return tree.RightPred
|
||||
} else {
|
||||
return predictSingle(*tree.Right, instance)
|
||||
return rpredictSingle(*tree.Right, instance)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Predict method for multiple data points. Calls predictFromNode()
|
||||
func (tree *RTree) Predict(X_test base.FixedDataGrid) []float64 {
|
||||
func (tree *CARTDecisionTreeRegressor) Predict(X_test base.FixedDataGrid) []float64 {
|
||||
root := *tree.RootNode
|
||||
test := regressorConvertInstancesToProblemVec(X_test)
|
||||
return predictFromNode(root, test)
|
||||
return rpredictFromNode(root, test)
|
||||
}
|
||||
|
||||
// Use tree's root node to predict every data point in test
|
||||
func predictFromNode(tree RNode, test [][]float64) []float64 {
|
||||
func rpredictFromNode(tree RNode, test [][]float64) []float64 {
|
||||
var preds []float64
|
||||
for i := range test {
|
||||
i_pred := predictSingle(tree, test[i])
|
||||
i_pred := rpredictSingle(tree, test[i])
|
||||
preds = append(preds, i_pred)
|
||||
}
|
||||
return preds
|
||||
|
Loading…
x
Reference in New Issue
Block a user