diff --git a/trees/cart_classifier.go b/trees/cart_classifier.go index 17a3ee7..fee9043 100644 --- a/trees/cart_classifier.go +++ b/trees/cart_classifier.go @@ -39,7 +39,7 @@ type CARTDecisionTreeClassifier struct { } // Calculate Gini Impurity of Target Labels -func giniImpurity(y []int64, labels []int64) (float64, int64) { +func computeGiniImpurityAndModeLabel(y []int64, labels []int64) (float64, int64) { nInstances := len(y) gini := 0.0 maxLabelCount := 0 @@ -62,7 +62,7 @@ func giniImpurity(y []int64, labels []int64) (float64, int64) { } // Calculate Entropy loss of Target Labels -func entropy(y []int64, labels []int64) (float64, int64) { +func computeEntropyAndModeLabel(y []int64, labels []int64) (float64, int64) { nInstances := len(y) entropy := 0.0 maxLabelCount := 0 @@ -91,9 +91,9 @@ func entropy(y []int64, labels []int64) (float64, int64) { func calculateClassificationLoss(y []int64, labels []int64, criterion string) (float64, int64) { if criterion == GINI { - return giniImpurity(y, labels) + return computeGiniImpurityAndModeLabel(y, labels) } else if criterion == ENTROPY { - return entropy(y, labels) + return computeEntropyAndModeLabel(y, labels) } else { panic("Invalid impurity function, choose from GINI or ENTROPY") } diff --git a/trees/cart_regressor.go b/trees/cart_regressor.go index b94da1d..3509a15 100644 --- a/trees/cart_regressor.go +++ b/trees/cart_regressor.go @@ -59,7 +59,7 @@ func meanAbsoluteError(y []float64, yBar float64) float64 { } // Turn Mean Absolute Error into impurity function for decision trees. -func maeImpurity(y []float64) (float64, float64) { +func computeMaeImpurityAndAverage(y []float64) (float64, float64) { yHat := average(y) return meanAbsoluteError(y, yHat), yHat } @@ -76,16 +76,16 @@ func meanSquaredError(y []float64, yBar float64) float64 { } // Convert mean squared error into impurity function for decision trees -func mseImpurity(y []float64) (float64, float64) { +func computeMseImpurityAndAverage(y []float64) (float64, float64) { yHat := average(y) return meanSquaredError(y, yHat), yHat } func calculateRegressionLoss(y []float64, criterion string) (float64, float64) { if criterion == MAE { - return maeImpurity(y) + return computeMaeImpurityAndAverage(y) } else if criterion == MSE { - return mseImpurity(y) + return computeMseImpurityAndAverage(y) } else { panic("Invalid impurity function, choose from MAE or MSE") }