1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00

Renaming Impurity Functions

This commit is contained in:
Ayush 2020-07-30 11:21:06 +05:30
parent 1954aae7a6
commit d587340e4a
2 changed files with 8 additions and 8 deletions

View File

@ -39,7 +39,7 @@ type CARTDecisionTreeClassifier struct {
}
// Calculate Gini Impurity of Target Labels
func giniImpurity(y []int64, labels []int64) (float64, int64) {
func computeGiniImpurityAndModeLabel(y []int64, labels []int64) (float64, int64) {
nInstances := len(y)
gini := 0.0
maxLabelCount := 0
@ -62,7 +62,7 @@ func giniImpurity(y []int64, labels []int64) (float64, int64) {
}
// Calculate Entropy loss of Target Labels
func entropy(y []int64, labels []int64) (float64, int64) {
func computeEntropyAndModeLabel(y []int64, labels []int64) (float64, int64) {
nInstances := len(y)
entropy := 0.0
maxLabelCount := 0
@ -91,9 +91,9 @@ func entropy(y []int64, labels []int64) (float64, int64) {
func calculateClassificationLoss(y []int64, labels []int64, criterion string) (float64, int64) {
if criterion == GINI {
return giniImpurity(y, labels)
return computeGiniImpurityAndModeLabel(y, labels)
} else if criterion == ENTROPY {
return entropy(y, labels)
return computeEntropyAndModeLabel(y, labels)
} else {
panic("Invalid impurity function, choose from GINI or ENTROPY")
}

View File

@ -59,7 +59,7 @@ func meanAbsoluteError(y []float64, yBar float64) float64 {
}
// Turn Mean Absolute Error into impurity function for decision trees.
func maeImpurity(y []float64) (float64, float64) {
func computeMaeImpurityAndAverage(y []float64) (float64, float64) {
yHat := average(y)
return meanAbsoluteError(y, yHat), yHat
}
@ -76,16 +76,16 @@ func meanSquaredError(y []float64, yBar float64) float64 {
}
// Convert mean squared error into impurity function for decision trees
func mseImpurity(y []float64) (float64, float64) {
func computeMseImpurityAndAverage(y []float64) (float64, float64) {
yHat := average(y)
return meanSquaredError(y, yHat), yHat
}
func calculateRegressionLoss(y []float64, criterion string) (float64, float64) {
if criterion == MAE {
return maeImpurity(y)
return computeMaeImpurityAndAverage(y)
} else if criterion == MSE {
return mseImpurity(y)
return computeMseImpurityAndAverage(y)
} else {
panic("Invalid impurity function, choose from MAE or MSE")
}