1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-28 13:48:56 +08:00

Adding Changes

This commit is contained in:
Ayush 2020-07-22 14:34:59 +05:30
parent 08529c42cf
commit c083759523
2 changed files with 68 additions and 61 deletions

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"math" "math"
"sort" "sort"
"strconv"
"strings" "strings"
"github.com/sjwhitworth/golearn/base" "github.com/sjwhitworth/golearn/base"
@ -23,8 +24,8 @@ type CNode struct {
maxDepth int64 maxDepth int64
} }
// CTree: Tree struct for Decision Tree Classifier // CARTDecisionTreeClassifier: Tree struct for Decision Tree Classifier
type CTree struct { type CARTDecisionTreeClassifier struct {
RootNode *CNode RootNode *CNode
criterion string criterion string
maxDepth int64 maxDepth int64
@ -135,8 +136,8 @@ func cgetFeature(data [][]float64, feature int64) []float64 {
} }
// Function to Create New Decision Tree Classifier // Function to Create New Decision Tree Classifier
func NewDecisionTreeClassifier(criterion string, maxDepth int64, labels []int64) *CTree { func NewDecisionTreeClassifier(criterion string, maxDepth int64, labels []int64) *CARTDecisionTreeClassifier {
var tree CTree var tree CARTDecisionTreeClassifier
tree.criterion = strings.ToLower(criterion) tree.criterion = strings.ToLower(criterion)
tree.maxDepth = maxDepth tree.maxDepth = maxDepth
tree.labels = labels tree.labels = labels
@ -210,7 +211,7 @@ func cupdateSplit(left [][]float64, lefty []int64, right [][]float64, righty []i
} }
// Fit - Method visible to user to train tree // Fit - Method visible to user to train tree
func (tree *CTree) Fit(X base.FixedDataGrid) { func (tree *CARTDecisionTreeClassifier) Fit(X base.FixedDataGrid) {
var emptyNode CNode var emptyNode CNode
data := classifierConvertInstancesToProblemVec(X) data := classifierConvertInstancesToProblemVec(X)
@ -221,7 +222,7 @@ func (tree *CTree) Fit(X base.FixedDataGrid) {
} }
// Iteratively find and record the best split - recursive function // Iteratively find and record the best split - recursive function
func cbestSplit(tree CTree, data [][]float64, y []int64, labels []int64, upperNode CNode, criterion string, maxDepth int64, depth int64) CNode { func cbestSplit(tree CARTDecisionTreeClassifier, data [][]float64, y []int64, labels []int64, upperNode CNode, criterion string, maxDepth int64, depth int64) CNode {
// Ensure that we have not reached maxDepth. maxDepth =-1 means split until nodes are pure // Ensure that we have not reached maxDepth. maxDepth =-1 means split until nodes are pure
depth++ depth++
@ -358,41 +359,43 @@ func cbestSplit(tree CTree, data [][]float64, y []int64, labels []int64, upperNo
} }
// PrintTree : this function prints out entire tree for visualization - visible to user // PrintTree : this function prints out entire tree for visualization - visible to user
func (tree *CTree) PrintTree() { func (tree *CARTDecisionTreeClassifier) String() string {
rootNode := *tree.RootNode rootNode := *tree.RootNode
cprintTreeFromNode(rootNode, "") return cprintTreeFromNode(rootNode, "")
} }
// Tree struct has root node. That is used to print tree - invisible to user but called from PrintTree func cprintTreeFromNode(tree CNode, spacing string) string {
func cprintTreeFromNode(tree CNode, spacing string) float64 { returnString := ""
returnString += spacing + "Feature "
fmt.Print(spacing + "Feature ") returnString += strconv.FormatInt(tree.Feature, 10)
fmt.Print(tree.Feature) returnString += " < "
fmt.Print(" < ") returnString += fmt.Sprintf("%.3f", tree.Threshold)
fmt.Println(tree.Threshold) returnString += "\n"
if tree.Left == nil { if tree.Left == nil {
fmt.Println(spacing + "---> True") returnString += spacing + "---> True" + "\n"
fmt.Print(" " + spacing + "PREDICT ") returnString += " " + spacing + "PREDICT "
fmt.Println(tree.LeftLabel) returnString += strconv.FormatInt(tree.LeftLabel, 10) + "\n"
} }
if tree.Right == nil { if tree.Right == nil {
fmt.Println(spacing + "---> FALSE")
fmt.Print(" " + spacing + "PREDICT ") returnString += spacing + "---> False" + "\n"
fmt.Println(tree.RightLabel) returnString += " " + spacing + "PREDICT "
returnString += strconv.FormatInt(tree.RightLabel, 10) + "\n"
} }
if tree.Left != nil { if tree.Left != nil {
fmt.Println(spacing + "---> True") returnString += spacing + "---> True" + "\n"
cprintTreeFromNode(*tree.Left, spacing+" ") returnString += cprintTreeFromNode(*tree.Left, spacing+" ")
} }
if tree.Right != nil { if tree.Right != nil {
fmt.Println(spacing + "---> False") returnString += spacing + "---> False" + "\n"
cprintTreeFromNode(*tree.Right, spacing+" ") returnString += cprintTreeFromNode(*tree.Right, spacing+" ")
} }
return 0.0 return returnString
} }
// Predict a single data point by traversing the entire tree // Predict a single data point by traversing the entire tree
@ -413,7 +416,7 @@ func cpredictSingle(tree CNode, instance []float64) int64 {
} }
// Predict is visible to user. Given test data, they receive predictions for every datapoint. // Predict is visible to user. Given test data, they receive predictions for every datapoint.
func (tree *CTree) Predict(X_test base.FixedDataGrid) []int64 { func (tree *CARTDecisionTreeClassifier) Predict(X_test base.FixedDataGrid) []int64 {
root := *tree.RootNode root := *tree.RootNode
test := classifierConvertInstancesToProblemVec(X_test) test := classifierConvertInstancesToProblemVec(X_test)
return cpredictFromNode(root, test) return cpredictFromNode(root, test)
@ -430,7 +433,7 @@ func cpredictFromNode(tree CNode, test [][]float64) []int64 {
} }
// Given Test data and label, return the accuracy of the classifier. Data has to be in float slice format before feeding. // Given Test data and label, return the accuracy of the classifier. Data has to be in float slice format before feeding.
func (tree *CTree) Evaluate(test base.FixedDataGrid) float64 { func (tree *CARTDecisionTreeClassifier) Evaluate(test base.FixedDataGrid) float64 {
rootNode := *tree.RootNode rootNode := *tree.RootNode
xTest := classifierConvertInstancesToProblemVec(test) xTest := classifierConvertInstancesToProblemVec(test)
yTest := classifierConvertInstancesToLabelVec(test) yTest := classifierConvertInstancesToLabelVec(test)

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"math" "math"
"sort" "sort"
"strconv"
"strings" "strings"
"github.com/sjwhitworth/golearn/base" "github.com/sjwhitworth/golearn/base"
@ -22,8 +23,8 @@ type RNode struct {
Use_not bool Use_not bool
} }
// RTree - Tree struct for Decision Tree Regressor // CARTDecisionTreeRegressor - Tree struct for Decision Tree Regressor
type RTree struct { type CARTDecisionTreeRegressor struct {
RootNode *RNode RootNode *RNode
criterion string criterion string
maxDepth int64 maxDepth int64
@ -125,8 +126,8 @@ func rgetFeature(data [][]float64, feature int64) []float64 {
} }
// Interface for creating new Decision Tree Regressor - calls rbestSplit() // Interface for creating new Decision Tree Regressor - calls rbestSplit()
func NewDecisionTreeRegressor(criterion string, maxDepth int64) *RTree { func NewDecisionTreeRegressor(criterion string, maxDepth int64) *CARTDecisionTreeRegressor {
var tree RTree var tree CARTDecisionTreeRegressor
tree.maxDepth = maxDepth tree.maxDepth = maxDepth
tree.criterion = strings.ToLower(criterion) tree.criterion = strings.ToLower(criterion)
return &tree return &tree
@ -198,7 +199,7 @@ func rupdateSplit(left [][]float64, lefty []float64, right [][]float64, righty [
} }
// Extra Method for creating simple to use interface. Many params are either redundant for user but are needed only for recursive logic. // Extra Method for creating simple to use interface. Many params are either redundant for user but are needed only for recursive logic.
func (tree *RTree) Fit(X base.FixedDataGrid) { func (tree *CARTDecisionTreeRegressor) Fit(X base.FixedDataGrid) {
var emptyNode RNode var emptyNode RNode
data := regressorConvertInstancesToProblemVec(X) data := regressorConvertInstancesToProblemVec(X)
y := regressorConvertInstancesToLabelVec(X) y := regressorConvertInstancesToLabelVec(X)
@ -209,7 +210,7 @@ func (tree *RTree) Fit(X base.FixedDataGrid) {
} }
// Essentially the Fit Method - Implements recursive logic // Essentially the Fit Method - Implements recursive logic
func rbestSplit(tree RTree, data [][]float64, y []float64, upperNode RNode, criterion string, maxDepth int64, depth int64) RNode { func rbestSplit(tree CARTDecisionTreeRegressor, data [][]float64, y []float64, upperNode RNode, criterion string, maxDepth int64, depth int64) RNode {
depth++ depth++
@ -334,72 +335,75 @@ func rbestSplit(tree RTree, data [][]float64, y []float64, upperNode RNode, crit
} }
// Print Tree for Visualization - calls printTreeFromNode() // Print Tree for Visualization - calls printTreeFromNode()
func (tree *RTree) PrintTree() { func (tree *CARTDecisionTreeRegressor) String() string {
rootNode := *tree.RootNode rootNode := *tree.RootNode
printTreeFromNode(rootNode, "") return rprintTreeFromNode(rootNode, "")
} }
// Use tree's root node to print out entire tree func rprintTreeFromNode(tree RNode, spacing string) string {
func printTreeFromNode(tree RNode, spacing string) float64 { returnString := ""
returnString += spacing + "Feature "
fmt.Print(spacing + "Feature ") returnString += strconv.FormatInt(tree.Feature, 10)
fmt.Print(tree.Feature) returnString += " < "
fmt.Print(" < ") returnString += fmt.Sprintf("%.3f", tree.Threshold)
fmt.Println(tree.Threshold) returnString += "\n"
if tree.Left == nil { if tree.Left == nil {
fmt.Println(spacing + "---> True") returnString += spacing + "---> True" + "\n"
fmt.Print(" " + spacing + "PREDICT ") returnString += " " + spacing + "PREDICT "
fmt.Println(tree.LeftPred) returnString += fmt.Sprintf("%.3f", tree.LeftPred) + "\n"
} }
if tree.Right == nil { if tree.Right == nil {
fmt.Println(spacing + "---> FALSE")
fmt.Print(" " + spacing + "PREDICT ") returnString += spacing + "---> False" + "\n"
fmt.Println(tree.RightPred) returnString += " " + spacing + "PREDICT "
returnString += fmt.Sprintf("%.3f", tree.RightPred) + "\n"
} }
if tree.Left != nil { if tree.Left != nil {
fmt.Println(spacing + "---> True") // fmt.Println(spacing + "---> True")
printTreeFromNode(*tree.Left, spacing+" ") returnString += spacing + "---> True" + "\n"
returnString += rprintTreeFromNode(*tree.Left, spacing+" ")
} }
if tree.Right != nil { if tree.Right != nil {
fmt.Println(spacing + "---> False") // fmt.Println(spacing + "---> False")
printTreeFromNode(*tree.Right, spacing+" ") returnString += spacing + "---> False" + "\n"
returnString += rprintTreeFromNode(*tree.Right, spacing+" ")
} }
return 0.0 return returnString
} }
// Predict a single data point // Predict a single data point
func predictSingle(tree RNode, instance []float64) float64 { func rpredictSingle(tree RNode, instance []float64) float64 {
if instance[tree.Feature] < tree.Threshold { if instance[tree.Feature] < tree.Threshold {
if tree.Left == nil { if tree.Left == nil {
return tree.LeftPred return tree.LeftPred
} else { } else {
return predictSingle(*tree.Left, instance) return rpredictSingle(*tree.Left, instance)
} }
} else { } else {
if tree.Right == nil { if tree.Right == nil {
return tree.RightPred return tree.RightPred
} else { } else {
return predictSingle(*tree.Right, instance) return rpredictSingle(*tree.Right, instance)
} }
} }
} }
// Predict method for multiple data points. Calls predictFromNode() // Predict method for multiple data points. Calls predictFromNode()
func (tree *RTree) Predict(X_test base.FixedDataGrid) []float64 { func (tree *CARTDecisionTreeRegressor) Predict(X_test base.FixedDataGrid) []float64 {
root := *tree.RootNode root := *tree.RootNode
test := regressorConvertInstancesToProblemVec(X_test) test := regressorConvertInstancesToProblemVec(X_test)
return predictFromNode(root, test) return rpredictFromNode(root, test)
} }
// Use tree's root node to predict every data point in test // Use tree's root node to predict every data point in test
func predictFromNode(tree RNode, test [][]float64) []float64 { func rpredictFromNode(tree RNode, test [][]float64) []float64 {
var preds []float64 var preds []float64
for i := range test { for i := range test {
i_pred := predictSingle(tree, test[i]) i_pred := rpredictSingle(tree, test[i])
preds = append(preds, i_pred) preds = append(preds, i_pred)
} }
return preds return preds