mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-28 13:48:56 +08:00
Adding Changes
This commit is contained in:
parent
08529c42cf
commit
c083759523
@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
@ -23,8 +24,8 @@ type CNode struct {
|
|||||||
maxDepth int64
|
maxDepth int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// CTree: Tree struct for Decision Tree Classifier
|
// CARTDecisionTreeClassifier: Tree struct for Decision Tree Classifier
|
||||||
type CTree struct {
|
type CARTDecisionTreeClassifier struct {
|
||||||
RootNode *CNode
|
RootNode *CNode
|
||||||
criterion string
|
criterion string
|
||||||
maxDepth int64
|
maxDepth int64
|
||||||
@ -135,8 +136,8 @@ func cgetFeature(data [][]float64, feature int64) []float64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Function to Create New Decision Tree Classifier
|
// Function to Create New Decision Tree Classifier
|
||||||
func NewDecisionTreeClassifier(criterion string, maxDepth int64, labels []int64) *CTree {
|
func NewDecisionTreeClassifier(criterion string, maxDepth int64, labels []int64) *CARTDecisionTreeClassifier {
|
||||||
var tree CTree
|
var tree CARTDecisionTreeClassifier
|
||||||
tree.criterion = strings.ToLower(criterion)
|
tree.criterion = strings.ToLower(criterion)
|
||||||
tree.maxDepth = maxDepth
|
tree.maxDepth = maxDepth
|
||||||
tree.labels = labels
|
tree.labels = labels
|
||||||
@ -210,7 +211,7 @@ func cupdateSplit(left [][]float64, lefty []int64, right [][]float64, righty []i
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fit - Method visible to user to train tree
|
// Fit - Method visible to user to train tree
|
||||||
func (tree *CTree) Fit(X base.FixedDataGrid) {
|
func (tree *CARTDecisionTreeClassifier) Fit(X base.FixedDataGrid) {
|
||||||
var emptyNode CNode
|
var emptyNode CNode
|
||||||
|
|
||||||
data := classifierConvertInstancesToProblemVec(X)
|
data := classifierConvertInstancesToProblemVec(X)
|
||||||
@ -221,7 +222,7 @@ func (tree *CTree) Fit(X base.FixedDataGrid) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Iterativly find and record the best split - recursive function
|
// Iterativly find and record the best split - recursive function
|
||||||
func cbestSplit(tree CTree, data [][]float64, y []int64, labels []int64, upperNode CNode, criterion string, maxDepth int64, depth int64) CNode {
|
func cbestSplit(tree CARTDecisionTreeClassifier, data [][]float64, y []int64, labels []int64, upperNode CNode, criterion string, maxDepth int64, depth int64) CNode {
|
||||||
|
|
||||||
// Ensure that we have not reached maxDepth. maxDepth =-1 means split until nodes are pure
|
// Ensure that we have not reached maxDepth. maxDepth =-1 means split until nodes are pure
|
||||||
depth++
|
depth++
|
||||||
@ -358,41 +359,43 @@ func cbestSplit(tree CTree, data [][]float64, y []int64, labels []int64, upperNo
|
|||||||
}
|
}
|
||||||
|
|
||||||
// PrintTree : this function prints out entire tree for visualization - visible to user
|
// PrintTree : this function prints out entire tree for visualization - visible to user
|
||||||
func (tree *CTree) PrintTree() {
|
func (tree *CARTDecisionTreeClassifier) String() string {
|
||||||
rootNode := *tree.RootNode
|
rootNode := *tree.RootNode
|
||||||
cprintTreeFromNode(rootNode, "")
|
return cprintTreeFromNode(rootNode, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tree struct has root node. That is used to print tree - invisible to user but called from PrintTree
|
func cprintTreeFromNode(tree CNode, spacing string) string {
|
||||||
func cprintTreeFromNode(tree CNode, spacing string) float64 {
|
returnString := ""
|
||||||
|
returnString += spacing + "Feature "
|
||||||
fmt.Print(spacing + "Feature ")
|
returnString += strconv.FormatInt(tree.Feature, 10)
|
||||||
fmt.Print(tree.Feature)
|
returnString += " < "
|
||||||
fmt.Print(" < ")
|
returnString += fmt.Sprintf("%.3f", tree.Threshold)
|
||||||
fmt.Println(tree.Threshold)
|
returnString += "\n"
|
||||||
|
|
||||||
if tree.Left == nil {
|
if tree.Left == nil {
|
||||||
fmt.Println(spacing + "---> True")
|
returnString += spacing + "---> True" + "\n"
|
||||||
fmt.Print(" " + spacing + "PREDICT ")
|
returnString += " " + spacing + "PREDICT "
|
||||||
fmt.Println(tree.LeftLabel)
|
returnString += strconv.FormatInt(tree.LeftLabel, 10) + "\n"
|
||||||
|
|
||||||
}
|
}
|
||||||
if tree.Right == nil {
|
if tree.Right == nil {
|
||||||
fmt.Println(spacing + "---> FALSE")
|
|
||||||
fmt.Print(" " + spacing + "PREDICT ")
|
returnString += spacing + "---> False" + "\n"
|
||||||
fmt.Println(tree.RightLabel)
|
returnString += " " + spacing + "PREDICT "
|
||||||
|
returnString += strconv.FormatInt(tree.RightLabel, 10) + "\n"
|
||||||
}
|
}
|
||||||
|
|
||||||
if tree.Left != nil {
|
if tree.Left != nil {
|
||||||
fmt.Println(spacing + "---> True")
|
returnString += spacing + "---> True" + "\n"
|
||||||
cprintTreeFromNode(*tree.Left, spacing+" ")
|
returnString += cprintTreeFromNode(*tree.Left, spacing+" ")
|
||||||
}
|
}
|
||||||
|
|
||||||
if tree.Right != nil {
|
if tree.Right != nil {
|
||||||
fmt.Println(spacing + "---> False")
|
returnString += spacing + "---> False" + "\n"
|
||||||
cprintTreeFromNode(*tree.Right, spacing+" ")
|
returnString += cprintTreeFromNode(*tree.Right, spacing+" ")
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0.0
|
return returnString
|
||||||
}
|
}
|
||||||
|
|
||||||
// Predict a single data point by traversing the entire tree
|
// Predict a single data point by traversing the entire tree
|
||||||
@ -413,7 +416,7 @@ func cpredictSingle(tree CNode, instance []float64) int64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Predict is visible to user. Given test data, they receive predictions for every datapoint.
|
// Predict is visible to user. Given test data, they receive predictions for every datapoint.
|
||||||
func (tree *CTree) Predict(X_test base.FixedDataGrid) []int64 {
|
func (tree *CARTDecisionTreeClassifier) Predict(X_test base.FixedDataGrid) []int64 {
|
||||||
root := *tree.RootNode
|
root := *tree.RootNode
|
||||||
test := classifierConvertInstancesToProblemVec(X_test)
|
test := classifierConvertInstancesToProblemVec(X_test)
|
||||||
return cpredictFromNode(root, test)
|
return cpredictFromNode(root, test)
|
||||||
@ -430,7 +433,7 @@ func cpredictFromNode(tree CNode, test [][]float64) []int64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Given Test data and label, return the accuracy of the classifier. Data has to be in float slice format before feeding.
|
// Given Test data and label, return the accuracy of the classifier. Data has to be in float slice format before feeding.
|
||||||
func (tree *CTree) Evaluate(test base.FixedDataGrid) float64 {
|
func (tree *CARTDecisionTreeClassifier) Evaluate(test base.FixedDataGrid) float64 {
|
||||||
rootNode := *tree.RootNode
|
rootNode := *tree.RootNode
|
||||||
xTest := classifierConvertInstancesToProblemVec(test)
|
xTest := classifierConvertInstancesToProblemVec(test)
|
||||||
yTest := classifierConvertInstancesToLabelVec(test)
|
yTest := classifierConvertInstancesToLabelVec(test)
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
@ -22,8 +23,8 @@ type RNode struct {
|
|||||||
Use_not bool
|
Use_not bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// RTree - Tree struct for Decision Tree Regressor
|
// CARTDecisionTreeRegressor - Tree struct for Decision Tree Regressor
|
||||||
type RTree struct {
|
type CARTDecisionTreeRegressor struct {
|
||||||
RootNode *RNode
|
RootNode *RNode
|
||||||
criterion string
|
criterion string
|
||||||
maxDepth int64
|
maxDepth int64
|
||||||
@ -125,8 +126,8 @@ func rgetFeature(data [][]float64, feature int64) []float64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Interface for creating new Decision Tree Regressor - cals rbestSplit()
|
// Interface for creating new Decision Tree Regressor - cals rbestSplit()
|
||||||
func NewDecisionTreeRegressor(criterion string, maxDepth int64) *RTree {
|
func NewDecisionTreeRegressor(criterion string, maxDepth int64) *CARTDecisionTreeRegressor {
|
||||||
var tree RTree
|
var tree CARTDecisionTreeRegressor
|
||||||
tree.maxDepth = maxDepth
|
tree.maxDepth = maxDepth
|
||||||
tree.criterion = strings.ToLower(criterion)
|
tree.criterion = strings.ToLower(criterion)
|
||||||
return &tree
|
return &tree
|
||||||
@ -198,7 +199,7 @@ func rupdateSplit(left [][]float64, lefty []float64, right [][]float64, righty [
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Extra Method for creating simple to use interface. Many params are either redundant for user but are needed only for recursive logic.
|
// Extra Method for creating simple to use interface. Many params are either redundant for user but are needed only for recursive logic.
|
||||||
func (tree *RTree) Fit(X base.FixedDataGrid) {
|
func (tree *CARTDecisionTreeRegressor) Fit(X base.FixedDataGrid) {
|
||||||
var emptyNode RNode
|
var emptyNode RNode
|
||||||
data := regressorConvertInstancesToProblemVec(X)
|
data := regressorConvertInstancesToProblemVec(X)
|
||||||
y := regressorConvertInstancesToLabelVec(X)
|
y := regressorConvertInstancesToLabelVec(X)
|
||||||
@ -209,7 +210,7 @@ func (tree *RTree) Fit(X base.FixedDataGrid) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Essentially the Fit Method - Impelements recursive logic
|
// Essentially the Fit Method - Impelements recursive logic
|
||||||
func rbestSplit(tree RTree, data [][]float64, y []float64, upperNode RNode, criterion string, maxDepth int64, depth int64) RNode {
|
func rbestSplit(tree CARTDecisionTreeRegressor, data [][]float64, y []float64, upperNode RNode, criterion string, maxDepth int64, depth int64) RNode {
|
||||||
|
|
||||||
depth++
|
depth++
|
||||||
|
|
||||||
@ -334,72 +335,75 @@ func rbestSplit(tree RTree, data [][]float64, y []float64, upperNode RNode, crit
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Print Tree for Visualtion - calls printTreeFromNode()
|
// Print Tree for Visualtion - calls printTreeFromNode()
|
||||||
func (tree *RTree) PrintTree() {
|
func (tree *CARTDecisionTreeRegressor) String() string {
|
||||||
rootNode := *tree.RootNode
|
rootNode := *tree.RootNode
|
||||||
printTreeFromNode(rootNode, "")
|
return rprintTreeFromNode(rootNode, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use tree's root node to print out entire tree
|
func rprintTreeFromNode(tree RNode, spacing string) string {
|
||||||
func printTreeFromNode(tree RNode, spacing string) float64 {
|
returnString := ""
|
||||||
|
returnString += spacing + "Feature "
|
||||||
fmt.Print(spacing + "Feature ")
|
returnString += strconv.FormatInt(tree.Feature, 10)
|
||||||
fmt.Print(tree.Feature)
|
returnString += " < "
|
||||||
fmt.Print(" < ")
|
returnString += fmt.Sprintf("%.3f", tree.Threshold)
|
||||||
fmt.Println(tree.Threshold)
|
returnString += "\n"
|
||||||
|
|
||||||
if tree.Left == nil {
|
if tree.Left == nil {
|
||||||
fmt.Println(spacing + "---> True")
|
returnString += spacing + "---> True" + "\n"
|
||||||
fmt.Print(" " + spacing + "PREDICT ")
|
returnString += " " + spacing + "PREDICT "
|
||||||
fmt.Println(tree.LeftPred)
|
returnString += fmt.Sprintf("%.3f", tree.LeftPred) + "\n"
|
||||||
}
|
}
|
||||||
if tree.Right == nil {
|
if tree.Right == nil {
|
||||||
fmt.Println(spacing + "---> FALSE")
|
|
||||||
fmt.Print(" " + spacing + "PREDICT ")
|
returnString += spacing + "---> False" + "\n"
|
||||||
fmt.Println(tree.RightPred)
|
returnString += " " + spacing + "PREDICT "
|
||||||
|
returnString += fmt.Sprintf("%.3f", tree.RightPred) + "\n"
|
||||||
}
|
}
|
||||||
|
|
||||||
if tree.Left != nil {
|
if tree.Left != nil {
|
||||||
fmt.Println(spacing + "---> True")
|
// fmt.Println(spacing + "---> True")
|
||||||
printTreeFromNode(*tree.Left, spacing+" ")
|
returnString += spacing + "---> True" + "\n"
|
||||||
|
returnString += rprintTreeFromNode(*tree.Left, spacing+" ")
|
||||||
}
|
}
|
||||||
|
|
||||||
if tree.Right != nil {
|
if tree.Right != nil {
|
||||||
fmt.Println(spacing + "---> False")
|
// fmt.Println(spacing + "---> False")
|
||||||
printTreeFromNode(*tree.Right, spacing+" ")
|
returnString += spacing + "---> False" + "\n"
|
||||||
|
returnString += rprintTreeFromNode(*tree.Right, spacing+" ")
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0.0
|
return returnString
|
||||||
}
|
}
|
||||||
|
|
||||||
// Predict a single data point
|
// Predict a single data point
|
||||||
func predictSingle(tree RNode, instance []float64) float64 {
|
func rpredictSingle(tree RNode, instance []float64) float64 {
|
||||||
if instance[tree.Feature] < tree.Threshold {
|
if instance[tree.Feature] < tree.Threshold {
|
||||||
if tree.Left == nil {
|
if tree.Left == nil {
|
||||||
return tree.LeftPred
|
return tree.LeftPred
|
||||||
} else {
|
} else {
|
||||||
return predictSingle(*tree.Left, instance)
|
return rpredictSingle(*tree.Left, instance)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if tree.Right == nil {
|
if tree.Right == nil {
|
||||||
return tree.RightPred
|
return tree.RightPred
|
||||||
} else {
|
} else {
|
||||||
return predictSingle(*tree.Right, instance)
|
return rpredictSingle(*tree.Right, instance)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Predict method for multiple data points. Calls predictFromNode()
|
// Predict method for multiple data points. Calls predictFromNode()
|
||||||
func (tree *RTree) Predict(X_test base.FixedDataGrid) []float64 {
|
func (tree *CARTDecisionTreeRegressor) Predict(X_test base.FixedDataGrid) []float64 {
|
||||||
root := *tree.RootNode
|
root := *tree.RootNode
|
||||||
test := regressorConvertInstancesToProblemVec(X_test)
|
test := regressorConvertInstancesToProblemVec(X_test)
|
||||||
return predictFromNode(root, test)
|
return rpredictFromNode(root, test)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use tree's root node to print out entire tree
|
// Use tree's root node to print out entire tree
|
||||||
func predictFromNode(tree RNode, test [][]float64) []float64 {
|
func rpredictFromNode(tree RNode, test [][]float64) []float64 {
|
||||||
var preds []float64
|
var preds []float64
|
||||||
for i := range test {
|
for i := range test {
|
||||||
i_pred := predictSingle(tree, test[i])
|
i_pred := rpredictSingle(tree, test[i])
|
||||||
preds = append(preds, i_pred)
|
preds = append(preds, i_pred)
|
||||||
}
|
}
|
||||||
return preds
|
return preds
|
||||||
|
Loading…
x
Reference in New Issue
Block a user