mirror of https://github.com/sjwhitworth/golearn.git

Optimizing Loss Calculation

Ayush 2020-08-01 11:25:53 +05:30
parent ae2338c2c1
commit 9d1ac82a40
4 changed files with 43 additions and 49 deletions
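
The diff below replaces repeated per-label scans of the target slice with a single frequency-map pass (the new convertToMap helper). A rough, standalone sketch of that before/after idea follows; the names countNested and countWithMap are made up for illustration and are not part of golearn.

package main

import "fmt"

// Old approach (simplified): scan all of y once per label.
// Work grows as len(labels) * len(y).
func countNested(y, labels []int64) map[int64]int {
	counts := make(map[int64]int)
	for _, label := range labels {
		n := 0
		for _, v := range y {
			if v == label {
				n++
			}
		}
		counts[label] = n
	}
	return counts
}

// New approach (the idea behind convertToMap): one pass over y into a frequency map.
// Work grows as len(y) + len(labels).
func countWithMap(y, labels []int64) map[int64]int {
	counts := make(map[int64]int)
	for _, label := range labels {
		counts[label] = 0
	}
	for _, v := range y {
		counts[v]++
	}
	return counts
}

func main() {
	y := []int64{0, 1, 1, 0, 1}
	labels := []int64{0, 1}
	fmt.Println(countNested(y, labels))  // map[0:2 1:3]
	fmt.Println(countWithMap(y, labels)) // map[0:2 1:3]
}

The map version touches each element of y only once, so the counting cost no longer multiplies by the number of distinct labels.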

View File

@@ -35,10 +35,13 @@ func main() {
 	// Create New Classification Tree
 	// Hyperparameters - loss function, max Depth (-1 will split until pure), list of unique labels
-	decTree = NewDecisionTreeClassifier("entropy", -1, []int64{0, 1})
+	decTree := NewDecisionTreeClassifier("entropy", -1, []int64{0, 1})
 	// Train Tree
-	decTree.Fit(trainData)
+	err = decTree.Fit(trainData)
+	if err != nil {
+		panic(err)
+	}
 	// Print out tree for visualization - shows splits and feature and predictions
 	fmt.Println(decTree.String())
@@ -62,7 +65,10 @@ func main() {
 	regTree := NewDecisionTreeRegressor("mse", -1)
 	// Train Tree
-	regTree.Fit(trainRegData)
+	err = regTree.Fit(trainRegData)
+	if err != nil {
+		panic(err)
+	}
 	// Print out tree for visualization
 	fmt.Println(regTree.String())

View File

@@ -39,25 +39,31 @@ type CARTDecisionTreeClassifier struct {
 	triedSplits [][]float64
 }
+
+// Convert a series of labels to frequency map for efficient impurity calculation
+func convertToMap(y []int64, labels []int64) map[int64]int {
+	labelCount := make(map[int64]int)
+	for _, label := range labels {
+		labelCount[label] = 0
+	}
+	for _, value := range y {
+		labelCount[value]++
+	}
+	return labelCount
+}
+
 // Calculate Gini Impurity of Target Labels
 func computeGiniImpurityAndModeLabel(y []int64, labels []int64) (float64, int64) {
 	nInstances := len(y)
 	gini := 0.0
-	maxLabelCount := 0
 	var maxLabel int64 = 0
-	for label := range labels {
-		numLabel := 0
-		for target := range y {
-			if y[target] == labels[label] {
-				numLabel++
-			}
+	labelCount := convertToMap(y, labels)
+	for _, label := range labels {
+		if labelCount[label] > labelCount[maxLabel] {
+			maxLabel = label
 		}
-		p := float64(numLabel) / float64(nInstances)
+		p := float64(labelCount[label]) / float64(nInstances)
 		gini += p * (1 - p)
-		if numLabel > maxLabelCount {
-			maxLabel = labels[label]
-			maxLabelCount = numLabel
-		}
 	}
 	return gini, maxLabel
 }
@@ -66,26 +72,19 @@ func computeGiniImpurityAndModeLabel(y []int64, labels []int64) (float64, int64)
 func computeEntropyAndModeLabel(y []int64, labels []int64) (float64, int64) {
 	nInstances := len(y)
 	entropy := 0.0
-	maxLabelCount := 0
 	var maxLabel int64 = 0
-	for label := range labels {
-		numLabel := 0
-		for target := range y {
-			if y[target] == labels[label] {
-				numLabel++
-			}
-		}
-		p := float64(numLabel) / float64(nInstances)
+	labelCount := convertToMap(y, labels)
+	for _, label := range labels {
+		if labelCount[label] > labelCount[maxLabel] {
+			maxLabel = label
+		}
+		p := float64(labelCount[label]) / float64(nInstances)
 		logP := math.Log2(p)
 		if p == 0 {
 			logP = 0
 		}
-		entropy += -p * logP
-		if numLabel > maxLabelCount {
-			maxLabel = labels[label]
-			maxLabelCount = numLabel
-		}
+		entropy += (-p * logP)
 	}
 	return entropy, maxLabel
 }
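
To make the arithmetic concrete, here is a small, self-contained sketch (hypothetical names, not the package's exported API) of how Gini impurity and entropy fall out of the same frequency map that computeGiniImpurityAndModeLabel and computeEntropyAndModeLabel now share:

package main

import (
	"fmt"
	"math"
)

// countLabels mirrors the convertToMap helper: one pass over y builds the frequency map.
func countLabels(y, labels []int64) map[int64]int {
	counts := make(map[int64]int)
	for _, l := range labels {
		counts[l] = 0
	}
	for _, v := range y {
		counts[v]++
	}
	return counts
}

func main() {
	y := []int64{0, 0, 1, 1, 1, 0}
	labels := []int64{0, 1}
	counts := countLabels(y, labels)
	n := float64(len(y))

	gini, entropy := 0.0, 0.0
	for _, l := range labels {
		p := float64(counts[l]) / n
		gini += p * (1 - p)
		if p > 0 { // the diff instead zeroes log2(p) when p == 0; same effect
			entropy += -p * math.Log2(p)
		}
	}
	fmt.Printf("gini=%.3f entropy=%.3f\n", gini, entropy) // gini=0.500 entropy=1.000
}

With three 0s and three 1s each p is 0.5, giving Gini = 0.5 and entropy = 1.0 bit, which is exactly what the rewritten loops compute from labelCount.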

View File

@@ -1,7 +1,6 @@
 package trees
 import (
-	"fmt"
 	"testing"
 	. "github.com/smartystreets/goconvey/convey"
@@ -42,8 +41,7 @@ func TestRegressor(t *testing.T) {
 		// is data reordered correctly
 		orderedData, orderedY := classifierReOrderData(getFeature(classifierData, 1), classifierData, classifiery)
-		fmt.Println(orderedData)
-		fmt.Println(orderedY)
+
 		So(orderedData[1][1], ShouldEqual, 3.0)
 		So(orderedY[0], ShouldEqual, 1)
@@ -81,9 +79,9 @@ func TestRegressor(t *testing.T) {
 		leftData, rightData, leftY, rightY := regressorCreateSplit(data, 1, y, 5.0)
 		So(len(leftData), ShouldEqual, 2)
-		So(len(lefty), ShouldEqual, 2)
+		So(len(leftY), ShouldEqual, 2)
 		So(len(rightData), ShouldEqual, 2)
-		So(len(righty), ShouldEqual, 2)
+		So(len(rightY), ShouldEqual, 2)
 		// is data reordered correctly
 		regressorOrderedData, regressorOrderedY := regressorReOrderData(getFeature(data, 1), data, y)

View File

@@ -4,23 +4,14 @@ import (
 	"github.com/sjwhitworth/golearn/base"
 )
-// Helper Function to check if data point is unique or not.
-// We will use this to isolate unique values of a feature
-func stringInSlice(a float64, list []float64) bool {
-	for _, b := range list {
-		if b == a {
-			return true
-		}
-	}
-	return false
-}
 // Isolate only unique values. This way, we can try only unique splits and not redundant ones.
 func findUnique(data []float64) []float64 {
-	var unique []float64
-	for i := range data {
-		if !stringInSlice(data[i], unique) {
-			unique = append(unique, data[i])
+	keys := make(map[float64]bool)
+	unique := []float64{}
+	for _, entry := range data {
+		if _, value := keys[entry]; !value {
+			keys[entry] = true
+			unique = append(unique, entry)
 		}
 	}
 	return unique
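
For reference, the rewritten findUnique can be exercised on its own. The sketch below reproduces its map-as-set logic (using the equivalent !keys[entry] test in place of the comma-ok form) together with a tiny usage example:

package main

import "fmt"

// findUnique, as rewritten above: a map used as a set replaces the
// per-element linear scan through the slice of values seen so far.
func findUnique(data []float64) []float64 {
	keys := make(map[float64]bool)
	unique := []float64{}
	for _, entry := range data {
		if !keys[entry] {
			keys[entry] = true
			unique = append(unique, entry)
		}
	}
	return unique
}

func main() {
	fmt.Println(findUnique([]float64{2.5, 3.0, 2.5, 4.1, 3.0})) // [2.5 3 4.1]
}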