1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00

Adding Changes

This commit is contained in:
Ayush 2020-07-22 14:34:59 +05:30
parent 08529c42cf
commit c083759523
2 changed files with 68 additions and 61 deletions

View File

@ -4,6 +4,7 @@ import (
"fmt"
"math"
"sort"
"strconv"
"strings"
"github.com/sjwhitworth/golearn/base"
@ -23,8 +24,8 @@ type CNode struct {
maxDepth int64
}
// CTree: Tree struct for Decision Tree Classifier
type CTree struct {
// CARTDecisionTreeClassifier: Tree struct for Decision Tree Classifier
type CARTDecisionTreeClassifier struct {
RootNode *CNode
criterion string
maxDepth int64
@ -135,8 +136,8 @@ func cgetFeature(data [][]float64, feature int64) []float64 {
}
// Function to Create New Decision Tree Classifier
func NewDecisionTreeClassifier(criterion string, maxDepth int64, labels []int64) *CTree {
var tree CTree
func NewDecisionTreeClassifier(criterion string, maxDepth int64, labels []int64) *CARTDecisionTreeClassifier {
var tree CARTDecisionTreeClassifier
tree.criterion = strings.ToLower(criterion)
tree.maxDepth = maxDepth
tree.labels = labels
@ -210,7 +211,7 @@ func cupdateSplit(left [][]float64, lefty []int64, right [][]float64, righty []i
}
// Fit - Method visible to user to train tree
func (tree *CTree) Fit(X base.FixedDataGrid) {
func (tree *CARTDecisionTreeClassifier) Fit(X base.FixedDataGrid) {
var emptyNode CNode
data := classifierConvertInstancesToProblemVec(X)
@ -221,7 +222,7 @@ func (tree *CTree) Fit(X base.FixedDataGrid) {
}
// Iterativly find and record the best split - recursive function
func cbestSplit(tree CTree, data [][]float64, y []int64, labels []int64, upperNode CNode, criterion string, maxDepth int64, depth int64) CNode {
func cbestSplit(tree CARTDecisionTreeClassifier, data [][]float64, y []int64, labels []int64, upperNode CNode, criterion string, maxDepth int64, depth int64) CNode {
// Ensure that we have not reached maxDepth. maxDepth =-1 means split until nodes are pure
depth++
@ -358,41 +359,43 @@ func cbestSplit(tree CTree, data [][]float64, y []int64, labels []int64, upperNo
}
// PrintTree : this function prints out entire tree for visualization - visible to user
func (tree *CTree) PrintTree() {
func (tree *CARTDecisionTreeClassifier) String() string {
rootNode := *tree.RootNode
cprintTreeFromNode(rootNode, "")
return cprintTreeFromNode(rootNode, "")
}
// Tree struct has root node. That is used to print tree - invisible to user but called from PrintTree
func cprintTreeFromNode(tree CNode, spacing string) float64 {
fmt.Print(spacing + "Feature ")
fmt.Print(tree.Feature)
fmt.Print(" < ")
fmt.Println(tree.Threshold)
func cprintTreeFromNode(tree CNode, spacing string) string {
returnString := ""
returnString += spacing + "Feature "
returnString += strconv.FormatInt(tree.Feature, 10)
returnString += " < "
returnString += fmt.Sprintf("%.3f", tree.Threshold)
returnString += "\n"
if tree.Left == nil {
fmt.Println(spacing + "---> True")
fmt.Print(" " + spacing + "PREDICT ")
fmt.Println(tree.LeftLabel)
returnString += spacing + "---> True" + "\n"
returnString += " " + spacing + "PREDICT "
returnString += strconv.FormatInt(tree.LeftLabel, 10) + "\n"
}
if tree.Right == nil {
fmt.Println(spacing + "---> FALSE")
fmt.Print(" " + spacing + "PREDICT ")
fmt.Println(tree.RightLabel)
returnString += spacing + "---> False" + "\n"
returnString += " " + spacing + "PREDICT "
returnString += strconv.FormatInt(tree.RightLabel, 10) + "\n"
}
if tree.Left != nil {
fmt.Println(spacing + "---> True")
cprintTreeFromNode(*tree.Left, spacing+" ")
returnString += spacing + "---> True" + "\n"
returnString += cprintTreeFromNode(*tree.Left, spacing+" ")
}
if tree.Right != nil {
fmt.Println(spacing + "---> False")
cprintTreeFromNode(*tree.Right, spacing+" ")
returnString += spacing + "---> False" + "\n"
returnString += cprintTreeFromNode(*tree.Right, spacing+" ")
}
return 0.0
return returnString
}
// Predict a single data point by traversing the entire tree
@ -413,7 +416,7 @@ func cpredictSingle(tree CNode, instance []float64) int64 {
}
// Predict is visible to user. Given test data, they receive predictions for every datapoint.
func (tree *CTree) Predict(X_test base.FixedDataGrid) []int64 {
func (tree *CARTDecisionTreeClassifier) Predict(X_test base.FixedDataGrid) []int64 {
root := *tree.RootNode
test := classifierConvertInstancesToProblemVec(X_test)
return cpredictFromNode(root, test)
@ -430,7 +433,7 @@ func cpredictFromNode(tree CNode, test [][]float64) []int64 {
}
// Given Test data and label, return the accuracy of the classifier. Data has to be in float slice format before feeding.
func (tree *CTree) Evaluate(test base.FixedDataGrid) float64 {
func (tree *CARTDecisionTreeClassifier) Evaluate(test base.FixedDataGrid) float64 {
rootNode := *tree.RootNode
xTest := classifierConvertInstancesToProblemVec(test)
yTest := classifierConvertInstancesToLabelVec(test)

View File

@ -4,6 +4,7 @@ import (
"fmt"
"math"
"sort"
"strconv"
"strings"
"github.com/sjwhitworth/golearn/base"
@ -22,8 +23,8 @@ type RNode struct {
Use_not bool
}
// RTree - Tree struct for Decision Tree Regressor
type RTree struct {
// CARTDecisionTreeRegressor - Tree struct for Decision Tree Regressor
type CARTDecisionTreeRegressor struct {
RootNode *RNode
criterion string
maxDepth int64
@ -125,8 +126,8 @@ func rgetFeature(data [][]float64, feature int64) []float64 {
}
// Interface for creating new Decision Tree Regressor - cals rbestSplit()
func NewDecisionTreeRegressor(criterion string, maxDepth int64) *RTree {
var tree RTree
func NewDecisionTreeRegressor(criterion string, maxDepth int64) *CARTDecisionTreeRegressor {
var tree CARTDecisionTreeRegressor
tree.maxDepth = maxDepth
tree.criterion = strings.ToLower(criterion)
return &tree
@ -198,7 +199,7 @@ func rupdateSplit(left [][]float64, lefty []float64, right [][]float64, righty [
}
// Extra Method for creating simple to use interface. Many params are either redundant for user but are needed only for recursive logic.
func (tree *RTree) Fit(X base.FixedDataGrid) {
func (tree *CARTDecisionTreeRegressor) Fit(X base.FixedDataGrid) {
var emptyNode RNode
data := regressorConvertInstancesToProblemVec(X)
y := regressorConvertInstancesToLabelVec(X)
@ -209,7 +210,7 @@ func (tree *RTree) Fit(X base.FixedDataGrid) {
}
// Essentially the Fit Method - Impelements recursive logic
func rbestSplit(tree RTree, data [][]float64, y []float64, upperNode RNode, criterion string, maxDepth int64, depth int64) RNode {
func rbestSplit(tree CARTDecisionTreeRegressor, data [][]float64, y []float64, upperNode RNode, criterion string, maxDepth int64, depth int64) RNode {
depth++
@ -334,72 +335,75 @@ func rbestSplit(tree RTree, data [][]float64, y []float64, upperNode RNode, crit
}
// Print Tree for Visualtion - calls printTreeFromNode()
func (tree *RTree) PrintTree() {
func (tree *CARTDecisionTreeRegressor) String() string {
rootNode := *tree.RootNode
printTreeFromNode(rootNode, "")
return rprintTreeFromNode(rootNode, "")
}
// Use tree's root node to print out entire tree
func printTreeFromNode(tree RNode, spacing string) float64 {
fmt.Print(spacing + "Feature ")
fmt.Print(tree.Feature)
fmt.Print(" < ")
fmt.Println(tree.Threshold)
func rprintTreeFromNode(tree RNode, spacing string) string {
returnString := ""
returnString += spacing + "Feature "
returnString += strconv.FormatInt(tree.Feature, 10)
returnString += " < "
returnString += fmt.Sprintf("%.3f", tree.Threshold)
returnString += "\n"
if tree.Left == nil {
fmt.Println(spacing + "---> True")
fmt.Print(" " + spacing + "PREDICT ")
fmt.Println(tree.LeftPred)
returnString += spacing + "---> True" + "\n"
returnString += " " + spacing + "PREDICT "
returnString += fmt.Sprintf("%.3f", tree.LeftPred) + "\n"
}
if tree.Right == nil {
fmt.Println(spacing + "---> FALSE")
fmt.Print(" " + spacing + "PREDICT ")
fmt.Println(tree.RightPred)
returnString += spacing + "---> False" + "\n"
returnString += " " + spacing + "PREDICT "
returnString += fmt.Sprintf("%.3f", tree.RightPred) + "\n"
}
if tree.Left != nil {
fmt.Println(spacing + "---> True")
printTreeFromNode(*tree.Left, spacing+" ")
// fmt.Println(spacing + "---> True")
returnString += spacing + "---> True" + "\n"
returnString += rprintTreeFromNode(*tree.Left, spacing+" ")
}
if tree.Right != nil {
fmt.Println(spacing + "---> False")
printTreeFromNode(*tree.Right, spacing+" ")
// fmt.Println(spacing + "---> False")
returnString += spacing + "---> False" + "\n"
returnString += rprintTreeFromNode(*tree.Right, spacing+" ")
}
return 0.0
return returnString
}
// Predict a single data point
func predictSingle(tree RNode, instance []float64) float64 {
func rpredictSingle(tree RNode, instance []float64) float64 {
if instance[tree.Feature] < tree.Threshold {
if tree.Left == nil {
return tree.LeftPred
} else {
return predictSingle(*tree.Left, instance)
return rpredictSingle(*tree.Left, instance)
}
} else {
if tree.Right == nil {
return tree.RightPred
} else {
return predictSingle(*tree.Right, instance)
return rpredictSingle(*tree.Right, instance)
}
}
}
// Predict method for multiple data points. Calls predictFromNode()
func (tree *RTree) Predict(X_test base.FixedDataGrid) []float64 {
func (tree *CARTDecisionTreeRegressor) Predict(X_test base.FixedDataGrid) []float64 {
root := *tree.RootNode
test := regressorConvertInstancesToProblemVec(X_test)
return predictFromNode(root, test)
return rpredictFromNode(root, test)
}
// Use tree's root node to print out entire tree
func predictFromNode(tree RNode, test [][]float64) []float64 {
func rpredictFromNode(tree RNode, test [][]float64) []float64 {
var preds []float64
for i := range test {
i_pred := predictSingle(tree, test[i])
i_pred := rpredictSingle(tree, test[i])
preds = append(preds, i_pred)
}
return preds