mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
Fixing Comments
This commit is contained in:
parent
abed408f9b
commit
91a27e3ca0
@ -36,7 +36,6 @@ func main() {
|
|||||||
fmt.Println(decTree.Evaluate(testData))
|
fmt.Println(decTree.Evaluate(testData))
|
||||||
|
|
||||||
// Load House Price Data For Regression
|
// Load House Price Data For Regression
|
||||||
|
|
||||||
regressionData, err := base.ParseCSVToInstances("../datasets/boston_house_prices.csv", false)
|
regressionData, err := base.ParseCSVToInstances("../datasets/boston_house_prices.csv", false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
@ -10,9 +10,13 @@ import (
|
|||||||
"github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
)
|
)
|
||||||
|
|
||||||
// The "c" prefix to function names indicates that they were tailored for classification
|
const (
|
||||||
|
GINI string = "gini"
|
||||||
|
ENTROPY string = "entropy"
|
||||||
|
)
|
||||||
|
|
||||||
// CNode is Node struct for Decision Tree Classifier
|
// CNode is Node struct for Decision Tree Classifier.
|
||||||
|
// It holds the information for each split (which feature to use, what threshold, and which label to assign for each side of the split)
|
||||||
type classifierNode struct {
|
type classifierNode struct {
|
||||||
Left *classifierNode
|
Left *classifierNode
|
||||||
Right *classifierNode
|
Right *classifierNode
|
||||||
@ -25,6 +29,8 @@ type classifierNode struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CARTDecisionTreeClassifier: Tree struct for Decision Tree Classifier
|
// CARTDecisionTreeClassifier: Tree struct for Decision Tree Classifier
|
||||||
|
// It contains the rootNode, as well as all of the hyperparameters chosen by the user.
|
||||||
|
// It also keeps track of all splits done at the tree level.
|
||||||
type CARTDecisionTreeClassifier struct {
|
type CARTDecisionTreeClassifier struct {
|
||||||
RootNode *classifierNode
|
RootNode *classifierNode
|
||||||
criterion string
|
criterion string
|
||||||
@ -84,7 +90,7 @@ func entropy(y []int64, labels []int64) (float64, int64) {
|
|||||||
return entropy, maxLabel
|
return entropy, maxLabel
|
||||||
}
|
}
|
||||||
|
|
||||||
// Split the data into left node and right node based on feature and threshold - only needed for fresh nodes
|
// Split the data into left node and right node based on feature and threshold
|
||||||
func classifierCreateSplit(data [][]float64, feature int64, y []int64, threshold float64) ([][]float64, [][]float64, []int64, []int64) {
|
func classifierCreateSplit(data [][]float64, feature int64, y []int64, threshold float64) ([][]float64, [][]float64, []int64, []int64) {
|
||||||
var left [][]float64
|
var left [][]float64
|
||||||
var right [][]float64
|
var right [][]float64
|
||||||
@ -105,7 +111,8 @@ func classifierCreateSplit(data [][]float64, feature int64, y []int64, threshold
|
|||||||
return left, right, lefty, righty
|
return left, right, lefty, righty
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper Function to check if data point is unique or not
|
// Helper Function to check if data point is unique or not.
|
||||||
|
// We will use this to isolate unique values of a feature
|
||||||
func classifierStringInSlice(a float64, list []float64) bool {
|
func classifierStringInSlice(a float64, list []float64) bool {
|
||||||
for _, b := range list {
|
for _, b := range list {
|
||||||
if b == a {
|
if b == a {
|
||||||
@ -115,7 +122,7 @@ func classifierStringInSlice(a float64, list []float64) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Isolate only unique values. Needed for splitting data.
|
// Isolate only unique values. This way, we can try only unique splits and not redundant ones.
|
||||||
func classifierFindUnique(data []float64) []float64 {
|
func classifierFindUnique(data []float64) []float64 {
|
||||||
var unique []float64
|
var unique []float64
|
||||||
for i := range data {
|
for i := range data {
|
||||||
@ -126,7 +133,7 @@ func classifierFindUnique(data []float64) []float64 {
|
|||||||
return unique
|
return unique
|
||||||
}
|
}
|
||||||
|
|
||||||
// Isolate only the feature being considered for splitting
|
// Isolate only the feature being considered for splitting. Reduces the complexity in managing splits.
|
||||||
func classifierGetFeature(data [][]float64, feature int64) []float64 {
|
func classifierGetFeature(data [][]float64, feature int64) []float64 {
|
||||||
var featureVals []float64
|
var featureVals []float64
|
||||||
for i := range data {
|
for i := range data {
|
||||||
@ -135,7 +142,8 @@ func classifierGetFeature(data [][]float64, feature int64) []float64 {
|
|||||||
return featureVals
|
return featureVals
|
||||||
}
|
}
|
||||||
|
|
||||||
// Function to Create New Decision Tree Classifier
|
// Function to Create New Decision Tree Classifier.
|
||||||
|
// It assigns all of the hyperparameters by user into the tree attributes.
|
||||||
func NewDecisionTreeClassifier(criterion string, maxDepth int64, labels []int64) *CARTDecisionTreeClassifier {
|
func NewDecisionTreeClassifier(criterion string, maxDepth int64, labels []int64) *CARTDecisionTreeClassifier {
|
||||||
var tree CARTDecisionTreeClassifier
|
var tree CARTDecisionTreeClassifier
|
||||||
tree.criterion = strings.ToLower(criterion)
|
tree.criterion = strings.ToLower(criterion)
|
||||||
@ -145,7 +153,8 @@ func NewDecisionTreeClassifier(criterion string, maxDepth int64, labels []int64)
|
|||||||
return &tree
|
return &tree
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure that split being considered has not been done before
|
// Make sure that split being considered has not been done before.
|
||||||
|
// Else we will unnecessarily try splits that won't improve Impurity.
|
||||||
func classifierValidate(triedSplits [][]float64, feature int64, threshold float64) bool {
|
func classifierValidate(triedSplits [][]float64, feature int64, threshold float64) bool {
|
||||||
for i := range triedSplits {
|
for i := range triedSplits {
|
||||||
split := triedSplits[i]
|
split := triedSplits[i]
|
||||||
@ -175,7 +184,7 @@ func classifierReOrderData(featureVal []float64, data [][]float64, y []int64) ([
|
|||||||
return dataSorted, ySorted
|
return dataSorted, ySorted
|
||||||
}
|
}
|
||||||
|
|
||||||
// Change data in Left Node and Right Node based on change in threshold
|
// Update the left and right side of the split based on the threshold.
|
||||||
func classifierUpdateSplit(left [][]float64, lefty []int64, right [][]float64, righty []int64, feature int64, threshold float64) ([][]float64, []int64, [][]float64, []int64) {
|
func classifierUpdateSplit(left [][]float64, lefty []int64, right [][]float64, righty []int64, feature int64, threshold float64) ([][]float64, []int64, [][]float64, []int64) {
|
||||||
|
|
||||||
for right[0][feature] < threshold {
|
for right[0][feature] < threshold {
|
||||||
@ -188,7 +197,8 @@ func classifierUpdateSplit(left [][]float64, lefty []int64, right [][]float64, r
|
|||||||
return left, lefty, right, righty
|
return left, lefty, right, righty
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fit - Method visible to user to train tree
|
// Fit - Creates an Empty Root Node
|
||||||
|
// Trains the tree by calling recursive function classifierBestSplit
|
||||||
func (tree *CARTDecisionTreeClassifier) Fit(X base.FixedDataGrid) {
|
func (tree *CARTDecisionTreeClassifier) Fit(X base.FixedDataGrid) {
|
||||||
var emptyNode classifierNode
|
var emptyNode classifierNode
|
||||||
|
|
||||||
@ -199,7 +209,8 @@ func (tree *CARTDecisionTreeClassifier) Fit(X base.FixedDataGrid) {
|
|||||||
tree.RootNode = &emptyNode
|
tree.RootNode = &emptyNode
|
||||||
}
|
}
|
||||||
|
|
||||||
// Iteratively find and record the best split - recursive function
|
// Iteratively find and record the best split
|
||||||
|
// Stop If depth reaches maxDepth or nodes are pure
|
||||||
func classifierBestSplit(tree CARTDecisionTreeClassifier, data [][]float64, y []int64, labels []int64, upperNode classifierNode, criterion string, maxDepth int64, depth int64) classifierNode {
|
func classifierBestSplit(tree CARTDecisionTreeClassifier, data [][]float64, y []int64, labels []int64, upperNode classifierNode, criterion string, maxDepth int64, depth int64) classifierNode {
|
||||||
|
|
||||||
// Ensure that we have not reached maxDepth. maxDepth =-1 means split until nodes are pure
|
// Ensure that we have not reached maxDepth. maxDepth =-1 means split until nodes are pure
|
||||||
@ -214,9 +225,9 @@ func classifierBestSplit(tree CARTDecisionTreeClassifier, data [][]float64, y []
|
|||||||
var origGini float64
|
var origGini float64
|
||||||
|
|
||||||
// Calculate loss based on Criterion Specified by user
|
// Calculate loss based on Criterion Specified by user
|
||||||
if criterion == "gini" {
|
if criterion == GINI {
|
||||||
origGini, upperNode.LeftLabel = giniImpurity(y, labels)
|
origGini, upperNode.LeftLabel = giniImpurity(y, labels)
|
||||||
} else if criterion == "entropy" {
|
} else if criterion == ENTROPY {
|
||||||
origGini, upperNode.LeftLabel = entropy(y, labels)
|
origGini, upperNode.LeftLabel = entropy(y, labels)
|
||||||
} else {
|
} else {
|
||||||
panic("Invalid impurity function, choose from GINI or ENTROPY")
|
panic("Invalid impurity function, choose from GINI or ENTROPY")
|
||||||
@ -271,10 +282,10 @@ func classifierBestSplit(tree CARTDecisionTreeClassifier, data [][]float64, y []
|
|||||||
var leftLabels int64
|
var leftLabels int64
|
||||||
var rightLabels int64
|
var rightLabels int64
|
||||||
|
|
||||||
if criterion == "gini" {
|
if criterion == GINI {
|
||||||
leftGini, leftLabels = giniImpurity(lefty, labels)
|
leftGini, leftLabels = giniImpurity(lefty, labels)
|
||||||
rightGini, rightLabels = giniImpurity(righty, labels)
|
rightGini, rightLabels = giniImpurity(righty, labels)
|
||||||
} else if criterion == "entropy" {
|
} else if criterion == ENTROPY {
|
||||||
leftGini, leftLabels = entropy(lefty, labels)
|
leftGini, leftLabels = entropy(lefty, labels)
|
||||||
rightGini, rightLabels = entropy(righty, labels)
|
rightGini, rightLabels = entropy(righty, labels)
|
||||||
}
|
}
|
||||||
@ -336,7 +347,8 @@ func classifierBestSplit(tree CARTDecisionTreeClassifier, data [][]float64, y []
|
|||||||
return upperNode
|
return upperNode
|
||||||
}
|
}
|
||||||
|
|
||||||
// PrintTree : this function prints out entire tree for visualization - visible to user
|
// String : this function prints out entire tree for visualization.
|
||||||
|
// Calls a recursive function to print the tree - classifierPrintTreeFromNode
|
||||||
func (tree *CARTDecisionTreeClassifier) String() string {
|
func (tree *CARTDecisionTreeClassifier) String() string {
|
||||||
rootNode := *tree.RootNode
|
rootNode := *tree.RootNode
|
||||||
return classifierPrintTreeFromNode(rootNode, "")
|
return classifierPrintTreeFromNode(rootNode, "")
|
||||||
@ -377,6 +389,7 @@ func classifierPrintTreeFromNode(tree classifierNode, spacing string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Predict a single data point by traversing the entire tree
|
// Predict a single data point by traversing the entire tree
|
||||||
|
// Uses recursive logic to navigate the tree.
|
||||||
func classifierPredictSingle(tree classifierNode, instance []float64) int64 {
|
func classifierPredictSingle(tree classifierNode, instance []float64) int64 {
|
||||||
if instance[tree.Feature] < tree.Threshold {
|
if instance[tree.Feature] < tree.Threshold {
|
||||||
if tree.Left == nil {
|
if tree.Left == nil {
|
||||||
@ -393,14 +406,15 @@ func classifierPredictSingle(tree classifierNode, instance []float64) int64 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Predict is visible to user. Given test data, they receive predictions for every datapoint.
|
// Given test data, return predictions for every datapoint. calls classifierPredictFromNode
|
||||||
func (tree *CARTDecisionTreeClassifier) Predict(X_test base.FixedDataGrid) []int64 {
|
func (tree *CARTDecisionTreeClassifier) Predict(X_test base.FixedDataGrid) []int64 {
|
||||||
root := *tree.RootNode
|
root := *tree.RootNode
|
||||||
test := classifierConvertInstancesToProblemVec(X_test)
|
test := classifierConvertInstancesToProblemVec(X_test)
|
||||||
return classifierPredictFromNode(root, test)
|
return classifierPredictFromNode(root, test)
|
||||||
}
|
}
|
||||||
|
|
||||||
// This function uses the rootnode from Predict. It is invisible to user, but called from predict method.
|
// This function uses the rootnode from Predict.
|
||||||
|
// It iterates through every data point and calls the recursive function to give predictions and then summarizes them.
|
||||||
func classifierPredictFromNode(tree classifierNode, test [][]float64) []int64 {
|
func classifierPredictFromNode(tree classifierNode, test [][]float64) []int64 {
|
||||||
var preds []int64
|
var preds []int64
|
||||||
for i := range test {
|
for i := range test {
|
||||||
@ -411,6 +425,8 @@ func classifierPredictFromNode(tree classifierNode, test [][]float64) []int64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Given Test data and label, return the accuracy of the classifier.
|
// Given Test data and label, return the accuracy of the classifier.
|
||||||
|
// First it retrieves predictions from the data, then compares for accuracy.
|
||||||
|
// Calls classifierEvaluateFromNode
|
||||||
func (tree *CARTDecisionTreeClassifier) Evaluate(test base.FixedDataGrid) float64 {
|
func (tree *CARTDecisionTreeClassifier) Evaluate(test base.FixedDataGrid) float64 {
|
||||||
rootNode := *tree.RootNode
|
rootNode := *tree.RootNode
|
||||||
xTest := classifierConvertInstancesToProblemVec(test)
|
xTest := classifierConvertInstancesToProblemVec(test)
|
||||||
@ -418,6 +434,7 @@ func (tree *CARTDecisionTreeClassifier) Evaluate(test base.FixedDataGrid) float6
|
|||||||
return classifierEvaluateFromNode(rootNode, xTest, yTest)
|
return classifierEvaluateFromNode(rootNode, xTest, yTest)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Retrieve predictions and then calculate accuracy.
|
||||||
func classifierEvaluateFromNode(tree classifierNode, xTest [][]float64, yTest []int64) float64 {
|
func classifierEvaluateFromNode(tree classifierNode, xTest [][]float64, yTest []int64) float64 {
|
||||||
preds := classifierPredictFromNode(tree, xTest)
|
preds := classifierPredictFromNode(tree, xTest)
|
||||||
accuracy := 0.0
|
accuracy := 0.0
|
||||||
|
@ -10,9 +10,14 @@ import (
|
|||||||
"github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
)
|
)
|
||||||
|
|
||||||
// The "r" prefix to all function names indicates that they were tailored to support regression.
|
const (
|
||||||
|
MAE string = "mae"
|
||||||
|
MSE string = "mse"
|
||||||
|
)
|
||||||
|
|
||||||
// RNode - Node struct for Decision Tree Regressor
|
// RNode - Node struct for Decision Tree Regressor
|
||||||
|
// It holds the information for each split
|
||||||
|
// Which feature to use, threshold, left prediction and right prediction
|
||||||
type regressorNode struct {
|
type regressorNode struct {
|
||||||
Left *regressorNode
|
Left *regressorNode
|
||||||
Right *regressorNode
|
Right *regressorNode
|
||||||
@ -24,6 +29,8 @@ type regressorNode struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CARTDecisionTreeRegressor - Tree struct for Decision Tree Regressor
|
// CARTDecisionTreeRegressor - Tree struct for Decision Tree Regressor
|
||||||
|
// It contains the rootNode, as well as the hyperparameters chosen by user.
|
||||||
|
// Also keeps track of splits used at tree level.
|
||||||
type CARTDecisionTreeRegressor struct {
|
type CARTDecisionTreeRegressor struct {
|
||||||
RootNode *regressorNode
|
RootNode *regressorNode
|
||||||
criterion string
|
criterion string
|
||||||
@ -74,7 +81,7 @@ func mseImpurity(y []float64) (float64, float64) {
|
|||||||
return meanSquaredError(y, yHat), yHat
|
return meanSquaredError(y, yHat), yHat
|
||||||
}
|
}
|
||||||
|
|
||||||
// Split the data based on threshold and feature for testing information gain
|
// Split the data into left and right based on threshold and feature.
|
||||||
func regressorCreateSplit(data [][]float64, feature int64, y []float64, threshold float64) ([][]float64, [][]float64, []float64, []float64) {
|
func regressorCreateSplit(data [][]float64, feature int64, y []float64, threshold float64) ([][]float64, [][]float64, []float64, []float64) {
|
||||||
var left [][]float64
|
var left [][]float64
|
||||||
var lefty []float64
|
var lefty []float64
|
||||||
@ -95,7 +102,8 @@ func regressorCreateSplit(data [][]float64, feature int64, y []float64, threshol
|
|||||||
return left, right, lefty, righty
|
return left, right, lefty, righty
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function for finding unique values
|
// Helper function for finding unique values.
|
||||||
|
// Used for isolating unique values in a feature.
|
||||||
func regressorStringInSlice(a float64, list []float64) bool {
|
func regressorStringInSlice(a float64, list []float64) bool {
|
||||||
for _, b := range list {
|
for _, b := range list {
|
||||||
if b == a {
|
if b == a {
|
||||||
@ -105,7 +113,8 @@ func regressorStringInSlice(a float64, list []float64) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return only unique values of a feature
|
// Isolate only unique values.
|
||||||
|
// This way we can only try unique splits.
|
||||||
func regressorFindUnique(data []float64) []float64 {
|
func regressorFindUnique(data []float64) []float64 {
|
||||||
var unique []float64
|
var unique []float64
|
||||||
for i := range data {
|
for i := range data {
|
||||||
@ -116,7 +125,8 @@ func regressorFindUnique(data []float64) []float64 {
|
|||||||
return unique
|
return unique
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract out a single feature from data
|
// Extract out a single feature from data.
|
||||||
|
// Reduces complexity in managing splits and sorting
|
||||||
func regressorGetFeature(data [][]float64, feature int64) []float64 {
|
func regressorGetFeature(data [][]float64, feature int64) []float64 {
|
||||||
var featureVals []float64
|
var featureVals []float64
|
||||||
for i := range data {
|
for i := range data {
|
||||||
@ -125,7 +135,7 @@ func regressorGetFeature(data [][]float64, feature int64) []float64 {
|
|||||||
return featureVals
|
return featureVals
|
||||||
}
|
}
|
||||||
|
|
||||||
// Interface for creating new Decision Tree Regressor - cals rbestSplit()
|
// Interface for creating new Decision Tree Regressor
|
||||||
func NewDecisionTreeRegressor(criterion string, maxDepth int64) *CARTDecisionTreeRegressor {
|
func NewDecisionTreeRegressor(criterion string, maxDepth int64) *CARTDecisionTreeRegressor {
|
||||||
var tree CARTDecisionTreeRegressor
|
var tree CARTDecisionTreeRegressor
|
||||||
tree.maxDepth = maxDepth
|
tree.maxDepth = maxDepth
|
||||||
@ -134,6 +144,7 @@ func NewDecisionTreeRegressor(criterion string, maxDepth int64) *CARTDecisionTre
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Validate that the split being tested has not been done before.
|
// Validate that the split being tested has not been done before.
|
||||||
|
// This prevents redundant splits from happening.
|
||||||
func regressorValidate(triedSplits [][]float64, feature int64, threshold float64) bool {
|
func regressorValidate(triedSplits [][]float64, feature int64, threshold float64) bool {
|
||||||
for i := range triedSplits {
|
for i := range triedSplits {
|
||||||
split := triedSplits[i]
|
split := triedSplits[i]
|
||||||
@ -146,6 +157,7 @@ func regressorValidate(triedSplits [][]float64, feature int64, threshold float64
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Re order data based on a feature for optimizing code
|
// Re order data based on a feature for optimizing code
|
||||||
|
// Helps in updating splits without reiterating entire dataset
|
||||||
func regressorReOrderData(featureVal []float64, data [][]float64, y []float64) ([][]float64, []float64) {
|
func regressorReOrderData(featureVal []float64, data [][]float64, y []float64) ([][]float64, []float64) {
|
||||||
s := NewSlice(featureVal)
|
s := NewSlice(featureVal)
|
||||||
sort.Sort(s)
|
sort.Sort(s)
|
||||||
@ -176,7 +188,8 @@ func regressorUpdateSplit(left [][]float64, lefty []float64, right [][]float64,
|
|||||||
return left, lefty, right, righty
|
return left, lefty, right, righty
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extra Method for creating simple to use interface. Many params are either redundant for user but are needed only for recursive logic.
|
// Fit - Build the tree using the data
|
||||||
|
// Creates empty root node and builds tree by calling regressorBestSplit
|
||||||
func (tree *CARTDecisionTreeRegressor) Fit(X base.FixedDataGrid) {
|
func (tree *CARTDecisionTreeRegressor) Fit(X base.FixedDataGrid) {
|
||||||
var emptyNode regressorNode
|
var emptyNode regressorNode
|
||||||
data := regressorConvertInstancesToProblemVec(X)
|
data := regressorConvertInstancesToProblemVec(X)
|
||||||
@ -187,7 +200,8 @@ func (tree *CARTDecisionTreeRegressor) Fit(X base.FixedDataGrid) {
|
|||||||
tree.RootNode = &emptyNode
|
tree.RootNode = &emptyNode
|
||||||
}
|
}
|
||||||
|
|
||||||
// Essentially the Fit Method - Implements recursive logic
|
// Builds the tree by iteratively finding the best split.
|
||||||
|
// Recursive function - stops if maxDepth is reached or nodes are pure
|
||||||
func regressorBestSplit(tree CARTDecisionTreeRegressor, data [][]float64, y []float64, upperNode regressorNode, criterion string, maxDepth int64, depth int64) regressorNode {
|
func regressorBestSplit(tree CARTDecisionTreeRegressor, data [][]float64, y []float64, upperNode regressorNode, criterion string, maxDepth int64, depth int64) regressorNode {
|
||||||
|
|
||||||
depth++
|
depth++
|
||||||
@ -200,10 +214,12 @@ func regressorBestSplit(tree CARTDecisionTreeRegressor, data [][]float64, y []fl
|
|||||||
var bestLoss float64
|
var bestLoss float64
|
||||||
var origLoss float64
|
var origLoss float64
|
||||||
|
|
||||||
if criterion == "mae" {
|
if criterion == MAE {
|
||||||
origLoss, upperNode.LeftPred = maeImpurity(y)
|
origLoss, upperNode.LeftPred = maeImpurity(y)
|
||||||
} else {
|
} else if criterion == MSE {
|
||||||
origLoss, upperNode.LeftPred = mseImpurity(y)
|
origLoss, upperNode.LeftPred = mseImpurity(y)
|
||||||
|
} else {
|
||||||
|
panic("Invalid impurity function, choose from MAE or MSE")
|
||||||
}
|
}
|
||||||
|
|
||||||
bestLoss = origLoss
|
bestLoss = origLoss
|
||||||
@ -252,10 +268,10 @@ func regressorBestSplit(tree CARTDecisionTreeRegressor, data [][]float64, y []fl
|
|||||||
var leftPred float64
|
var leftPred float64
|
||||||
var rightPred float64
|
var rightPred float64
|
||||||
|
|
||||||
if criterion == "mae" {
|
if criterion == MAE {
|
||||||
leftLoss, leftPred = maeImpurity(lefty)
|
leftLoss, leftPred = maeImpurity(lefty)
|
||||||
rightLoss, rightPred = maeImpurity(righty)
|
rightLoss, rightPred = maeImpurity(righty)
|
||||||
} else {
|
} else if criterion == MSE {
|
||||||
leftLoss, leftPred = mseImpurity(lefty)
|
leftLoss, leftPred = mseImpurity(lefty)
|
||||||
rightLoss, rightPred = mseImpurity(righty)
|
rightLoss, rightPred = mseImpurity(righty)
|
||||||
}
|
}
|
||||||
@ -312,12 +328,13 @@ func regressorBestSplit(tree CARTDecisionTreeRegressor, data [][]float64, y []fl
|
|||||||
return upperNode
|
return upperNode
|
||||||
}
|
}
|
||||||
|
|
||||||
// Print Tree for Visualization - calls printTreeFromNode()
|
// Print Tree for Visualization - calls regressorPrintTreeFromNode()
|
||||||
func (tree *CARTDecisionTreeRegressor) String() string {
|
func (tree *CARTDecisionTreeRegressor) String() string {
|
||||||
rootNode := *tree.RootNode
|
rootNode := *tree.RootNode
|
||||||
return regressorPrintTreeFromNode(rootNode, "")
|
return regressorPrintTreeFromNode(rootNode, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Recursively explore the entire tree and print out all details such as threshold, feature, prediction
|
||||||
func regressorPrintTreeFromNode(tree regressorNode, spacing string) string {
|
func regressorPrintTreeFromNode(tree regressorNode, spacing string) string {
|
||||||
returnString := ""
|
returnString := ""
|
||||||
returnString += spacing + "Feature "
|
returnString += spacing + "Feature "
|
||||||
@ -353,7 +370,8 @@ func regressorPrintTreeFromNode(tree regressorNode, spacing string) string {
|
|||||||
return returnString
|
return returnString
|
||||||
}
|
}
|
||||||
|
|
||||||
// Predict a single data point
|
// Predict a single data point by navigating to rootNodes.
|
||||||
|
// Uses a recursive logic
|
||||||
func regressorPredictSingle(tree regressorNode, instance []float64) float64 {
|
func regressorPredictSingle(tree regressorNode, instance []float64) float64 {
|
||||||
if instance[tree.Feature] < tree.Threshold {
|
if instance[tree.Feature] < tree.Threshold {
|
||||||
if tree.Left == nil {
|
if tree.Left == nil {
|
||||||
@ -370,14 +388,16 @@ func regressorPredictSingle(tree regressorNode, instance []float64) float64 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Predict method for multiple data points. Calls predictFromNode()
|
// Predict method for multiple data points.
|
||||||
|
// First converts input data into usable format, and then calls regressorPredictFromNode
|
||||||
func (tree *CARTDecisionTreeRegressor) Predict(X_test base.FixedDataGrid) []float64 {
|
func (tree *CARTDecisionTreeRegressor) Predict(X_test base.FixedDataGrid) []float64 {
|
||||||
root := *tree.RootNode
|
root := *tree.RootNode
|
||||||
test := regressorConvertInstancesToProblemVec(X_test)
|
test := regressorConvertInstancesToProblemVec(X_test)
|
||||||
return regressorPredictFromNode(root, test)
|
return regressorPredictFromNode(root, test)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use tree's root node to print out entire tree
|
// Use tree's root node to print out entire tree.
|
||||||
|
// Iterates over all data points and calls regressorPredictSingle to predict individual datapoints.
|
||||||
func regressorPredictFromNode(tree regressorNode, test [][]float64) []float64 {
|
func regressorPredictFromNode(tree regressorNode, test [][]float64) []float64 {
|
||||||
var preds []float64
|
var preds []float64
|
||||||
for i := range test {
|
for i := range test {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user