From 9d1ac82a40d6141e1bd0cdd6d1dd68bc430d981b Mon Sep 17 00:00:00 2001
From: Ayush
Date: Sat, 1 Aug 2020 11:25:53 +0530
Subject: [PATCH] Optimizing Loss Calculation

---
 examples/trees/cart.go   | 12 +++++++---
 trees/cart_classifier.go | 51 ++++++++++++++++++++--------------------
 trees/cart_test.go       |  8 +++----
 trees/cart_utils.go      | 21 +++++------------
 4 files changed, 43 insertions(+), 49 deletions(-)

diff --git a/examples/trees/cart.go b/examples/trees/cart.go
index f465d54..a6fc909 100644
--- a/examples/trees/cart.go
+++ b/examples/trees/cart.go
@@ -35,10 +35,13 @@ func main() {
 
 	// Create New Classification Tree
 	// Hyperparameters - loss function, max Depth (-1 will split until pure), list of unique labels
-	decTree = NewDecisionTreeClassifier("entropy", -1, []int64{0, 1})
+	decTree := NewDecisionTreeClassifier("entropy", -1, []int64{0, 1})
 
 	// Train Tree
-	decTree.Fit(trainData)
+	err = decTree.Fit(trainData)
+	if err != nil {
+		panic(err)
+	}
 
 	// Print out tree for visualization - shows splits and feature and predictions
 	fmt.Println(decTree.String())
@@ -62,7 +65,10 @@ func main() {
 	regTree := NewDecisionTreeRegressor("mse", -1)
 
 	// Train Tree
-	regTree.Fit(trainRegData)
+	err = regTree.Fit(trainRegData)
+	if err != nil {
+		panic(err)
+	}
 
 	// Print out tree for visualization
 	fmt.Println(regTree.String())
diff --git a/trees/cart_classifier.go b/trees/cart_classifier.go
index 828f2dc..bb9af51 100644
--- a/trees/cart_classifier.go
+++ b/trees/cart_classifier.go
@@ -39,25 +39,31 @@ type CARTDecisionTreeClassifier struct {
 	triedSplits [][]float64
 }
 
+// Convert a series of labels to frequency map for efficient impurity calculation
+func convertToMap(y []int64, labels []int64) map[int64]int {
+	labelCount := make(map[int64]int)
+	for _, label := range labels {
+		labelCount[label] = 0
+	}
+	for _, value := range y {
+		labelCount[value]++
+	}
+	return labelCount
+}
+
 // Calculate Gini Impurity of Target Labels
 func computeGiniImpurityAndModeLabel(y []int64, labels []int64) (float64, int64) {
 	nInstances := len(y)
 	gini := 0.0
-	maxLabelCount := 0
 	var maxLabel int64 = 0
-	for label := range labels {
-		numLabel := 0
-		for target := range y {
-			if y[target] == labels[label] {
-				numLabel++
-			}
+
+	labelCount := convertToMap(y, labels)
+	for _, label := range labels {
+		if labelCount[label] > labelCount[maxLabel] {
+			maxLabel = label
 		}
-		p := float64(numLabel) / float64(nInstances)
+		p := float64(labelCount[label]) / float64(nInstances)
 		gini += p * (1 - p)
-		if numLabel > maxLabelCount {
-			maxLabel = labels[label]
-			maxLabelCount = numLabel
-		}
 	}
 	return gini, maxLabel
 }
@@ -66,26 +72,19 @@ func computeGiniImpurityAndModeLabel(y []int64, labels []int64) (float64, int64)
 func computeEntropyAndModeLabel(y []int64, labels []int64) (float64, int64) {
 	nInstances := len(y)
 	entropy := 0.0
-	maxLabelCount := 0
 	var maxLabel int64 = 0
 
-	for label := range labels {
-		numLabel := 0
-		for target := range y {
-			if y[target] == labels[label] {
-				numLabel++
-			}
-		}
-		p := float64(numLabel) / float64(nInstances)
+	labelCount := convertToMap(y, labels)
+	for _, label := range labels {
+		if labelCount[label] > labelCount[maxLabel] {
+			maxLabel = label
+		}
+		p := float64(labelCount[label]) / float64(nInstances)
 		logP := math.Log2(p)
 		if p == 0 {
 			logP = 0
 		}
-		entropy += -p * logP
-		if numLabel > maxLabelCount {
-			maxLabel = labels[label]
-			maxLabelCount = numLabel
-		}
+		entropy += (-p * logP)
 	}
 	return entropy, maxLabel
 }
diff --git a/trees/cart_test.go b/trees/cart_test.go
index 99374c8..3edee6d 100644
--- a/trees/cart_test.go
+++ b/trees/cart_test.go
@@ -1,7 +1,6 @@
 package trees
 
 import (
-	"fmt"
 	"testing"
 
 	. "github.com/smartystreets/goconvey/convey"
@@ -42,8 +41,7 @@ func TestRegressor(t *testing.T) {
 
 		// is data reordered correctly
 		orderedData, orderedY := classifierReOrderData(getFeature(classifierData, 1), classifierData, classifiery)
-		fmt.Println(orderedData)
-		fmt.Println(orderedY)
+
 		So(orderedData[1][1], ShouldEqual, 3.0)
 		So(orderedY[0], ShouldEqual, 1)
 
@@ -81,9 +79,9 @@ func TestRegressor(t *testing.T) {
 		leftData, rightData, leftY, rightY := regressorCreateSplit(data, 1, y, 5.0)
 
 		So(len(leftData), ShouldEqual, 2)
-		So(len(lefty), ShouldEqual, 2)
+		So(len(leftY), ShouldEqual, 2)
 		So(len(rightData), ShouldEqual, 2)
-		So(len(righty), ShouldEqual, 2)
+		So(len(rightY), ShouldEqual, 2)
 
 		// is data reordered correctly
 		regressorOrderedData, regressorOrderedY := regressorReOrderData(getFeature(data, 1), data, y)
diff --git a/trees/cart_utils.go b/trees/cart_utils.go
index d3b9b4a..251dee9 100644
--- a/trees/cart_utils.go
+++ b/trees/cart_utils.go
@@ -4,23 +4,14 @@ import (
 	"github.com/sjwhitworth/golearn/base"
 )
 
-// Helper Function to check if data point is unique or not.
-// We will use this to isolate unique values of a feature
-func stringInSlice(a float64, list []float64) bool {
-	for _, b := range list {
-		if b == a {
-			return true
-		}
-	}
-	return false
-}
-
 // Isolate only unique values. This way, we can try only unique splits and not redundant ones.
 func findUnique(data []float64) []float64 {
-	var unique []float64
-	for i := range data {
-		if !stringInSlice(data[i], unique) {
-			unique = append(unique, data[i])
+	keys := make(map[float64]bool)
+	unique := []float64{}
+	for _, entry := range data {
+		if _, value := keys[entry]; !value {
+			keys[entry] = true
+			unique = append(unique, entry)
 		}
 	}
 	return unique
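
For illustration only, outside the patch itself: a minimal, self-contained Go sketch of the optimization the patch applies, counting label frequencies once with convertToMap and then computing entropy and the mode label from those counts, as in the patched computeEntropyAndModeLabel. The main() driver and the toy slices y and labels are assumptions added for this example.

package main

import (
	"fmt"
	"math"
)

// Count how often each label occurs in y (single pass over y).
func convertToMap(y []int64, labels []int64) map[int64]int {
	labelCount := make(map[int64]int)
	for _, label := range labels {
		labelCount[label] = 0
	}
	for _, value := range y {
		labelCount[value]++
	}
	return labelCount
}

// Entropy and mode label computed from the precomputed counts,
// mirroring the patched computeEntropyAndModeLabel.
func computeEntropyAndModeLabel(y []int64, labels []int64) (float64, int64) {
	nInstances := len(y)
	entropy := 0.0
	var maxLabel int64 = 0

	labelCount := convertToMap(y, labels)
	for _, label := range labels {
		if labelCount[label] > labelCount[maxLabel] {
			maxLabel = label
		}
		p := float64(labelCount[label]) / float64(nInstances)
		logP := math.Log2(p)
		if p == 0 {
			logP = 0
		}
		entropy += (-p * logP)
	}
	return entropy, maxLabel
}

func main() {
	y := []int64{0, 1, 1, 1, 0, 1} // toy targets (assumed)
	labels := []int64{0, 1}        // unique class labels (assumed)
	entropy, mode := computeEntropyAndModeLabel(y, labels)
	fmt.Printf("entropy = %.4f, mode label = %d\n", entropy, mode)
}

Compared with the removed nested loops, the counting pass is O(len(y)) and the impurity loop is O(len(labels)), rather than O(len(labels) * len(y)) per call.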