Optimizing Loss Calculation

commit 9d1ac82a40 (parent ae2338c2c1)
@@ -35,10 +35,13 @@ func main() {
 
     // Create New Classification Tree
     // Hyperparameters - loss function, max Depth (-1 will split until pure), list of unique labels
-    decTree = NewDecisionTreeClassifier("entropy", -1, []int64{0, 1})
+    decTree := NewDecisionTreeClassifier("entropy", -1, []int64{0, 1})
 
     // Train Tree
-    decTree.Fit(trainData)
+    err = decTree.Fit(trainData)
+    if err != nil {
+        panic(err)
+    }
     // Print out tree for visualization - shows splits and feature and predictions
     fmt.Println(decTree.String())
 
@@ -62,7 +65,10 @@ func main() {
     regTree := NewDecisionTreeRegressor("mse", -1)
 
     // Train Tree
-    regTree.Fit(trainRegData)
+    err = regTree.Fit(trainRegData)
+    if err != nil {
+        panic(err)
+    }
 
     // Print out tree for visualization
     fmt.Println(regTree.String())

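The example now propagates the error returned by Fit instead of discarding it, and the classifier is declared with := at its first use. A minimal sketch of the call pattern the example follows after this change, assuming the trees package is imported and trainData / trainRegData were loaded earlier in main (that setup sits outside these hunks):

    // Classifier: hyperparameters are the loss function ("entropy" here), max depth, and label set.
    decTree := NewDecisionTreeClassifier("entropy", -1, []int64{0, 1})
    if err := decTree.Fit(trainData); err != nil {
        panic(err) // training failures now surface instead of being dropped
    }
    fmt.Println(decTree.String())

    // Regressor: same pattern with an "mse" loss and no label set.
    regTree := NewDecisionTreeRegressor("mse", -1)
    if err := regTree.Fit(trainRegData); err != nil {
        panic(err)
    }
    fmt.Println(regTree.String())
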
@@ -39,25 +39,31 @@ type CARTDecisionTreeClassifier struct {
     triedSplits [][]float64
 }
 
+// Convert a series of labels to frequency map for efficient impurity calculation
+func convertToMap(y []int64, labels []int64) map[int64]int {
+    labelCount := make(map[int64]int)
+    for _, label := range labels {
+        labelCount[label] = 0
+    }
+    for _, value := range y {
+        labelCount[value]++
+    }
+    return labelCount
+}
+
 // Calculate Gini Impurity of Target Labels
 func computeGiniImpurityAndModeLabel(y []int64, labels []int64) (float64, int64) {
     nInstances := len(y)
     gini := 0.0
-    maxLabelCount := 0
     var maxLabel int64 = 0
-    for label := range labels {
-        numLabel := 0
-        for target := range y {
-            if y[target] == labels[label] {
-                numLabel++
-            }
-        }
-        p := float64(numLabel) / float64(nInstances)
+
+    labelCount := convertToMap(y, labels)
+    for _, label := range labels {
+        if labelCount[label] > labelCount[maxLabel] {
+            maxLabel = label
+        }
+        p := float64(labelCount[label]) / float64(nInstances)
         gini += p * (1 - p)
-        if numLabel > maxLabelCount {
-            maxLabel = labels[label]
-            maxLabelCount = numLabel
-        }
     }
     return gini, maxLabel
 }

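This is the core of the optimization: the old code rescanned y once per label (roughly O(n·k) comparisons for n samples and k labels), while the new code builds one frequency map in a single pass and reads both the impurity and the mode label off the counts (O(n + k)). A standalone sketch of the new calculation, lifted out of the package with a tiny worked example so it can run on its own:

package main

import "fmt"

// convertToMap counts how often each label occurs in y, seeding every known
// label with 0 so absent classes still appear in the map.
func convertToMap(y []int64, labels []int64) map[int64]int {
    labelCount := make(map[int64]int)
    for _, label := range labels {
        labelCount[label] = 0
    }
    for _, value := range y {
        labelCount[value]++
    }
    return labelCount
}

// computeGiniImpurityAndModeLabel mirrors the updated function above: one pass
// over the counts yields both the Gini impurity and the most frequent label.
func computeGiniImpurityAndModeLabel(y []int64, labels []int64) (float64, int64) {
    nInstances := len(y)
    gini := 0.0
    var maxLabel int64 = 0
    labelCount := convertToMap(y, labels)
    for _, label := range labels {
        if labelCount[label] > labelCount[maxLabel] {
            maxLabel = label
        }
        p := float64(labelCount[label]) / float64(nInstances)
        gini += p * (1 - p)
    }
    return gini, maxLabel
}

func main() {
    y := []int64{0, 0, 1, 1, 1, 1}
    // p(0) = 1/3 and p(1) = 2/3, so gini = 2 * (1/3)(2/3) ≈ 0.444 and the mode label is 1.
    fmt.Println(computeGiniImpurityAndModeLabel(y, []int64{0, 1}))
}
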
@@ -66,26 +72,19 @@ func computeGiniImpurityAndModeLabel(y []int64, labels []int64) (float64, int64)
 func computeEntropyAndModeLabel(y []int64, labels []int64) (float64, int64) {
     nInstances := len(y)
     entropy := 0.0
-    maxLabelCount := 0
     var maxLabel int64 = 0
-    for label := range labels {
-        numLabel := 0
-        for target := range y {
-            if y[target] == labels[label] {
-                numLabel++
-            }
-        }
-        p := float64(numLabel) / float64(nInstances)
+
+    labelCount := convertToMap(y, labels)
+    for _, label := range labels {
+        if labelCount[label] > labelCount[maxLabel] {
+            maxLabel = label
+        }
+        p := float64(labelCount[label]) / float64(nInstances)
         logP := math.Log2(p)
         if p == 0 {
             logP = 0
         }
-        entropy += -p * logP
-        if numLabel > maxLabelCount {
-            maxLabel = labels[label]
-            maxLabelCount = numLabel
-        }
+        entropy += (-p * logP)
     }
     return entropy, maxLabel
 }

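The entropy path gets the same map-based rewrite and keeps the guard for empty classes: since math.Log2(0) is -Inf, the term -p * log2(p) would otherwise evaluate to NaN at p = 0 instead of contributing 0. A self-contained sketch of just that term with a worked value (entropyTerm is only an illustration helper, not a function in the package):

package main

import (
    "fmt"
    "math"
)

// entropyTerm computes -p * log2(p) with the same guard the commit keeps,
// so an absent class contributes 0 rather than NaN.
func entropyTerm(p float64) float64 {
    logP := math.Log2(p)
    if p == 0 {
        logP = 0
    }
    return -p * logP
}

func main() {
    // Two classes with probabilities 1/3 and 2/3, plus an absent third class.
    fmt.Println(entropyTerm(1.0/3) + entropyTerm(2.0/3) + entropyTerm(0)) // ≈ 0.918
}
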
@@ -1,7 +1,6 @@
 package trees
 
 import (
-    "fmt"
     "testing"
 
     . "github.com/smartystreets/goconvey/convey"
@@ -42,8 +41,7 @@ func TestRegressor(t *testing.T) {
 
     // is data reordered correctly
     orderedData, orderedY := classifierReOrderData(getFeature(classifierData, 1), classifierData, classifiery)
-    fmt.Println(orderedData)
-    fmt.Println(orderedY)
     So(orderedData[1][1], ShouldEqual, 3.0)
     So(orderedY[0], ShouldEqual, 1)
 
@@ -81,9 +79,9 @@ func TestRegressor(t *testing.T) {
     leftData, rightData, leftY, rightY := regressorCreateSplit(data, 1, y, 5.0)
 
     So(len(leftData), ShouldEqual, 2)
-    So(len(lefty), ShouldEqual, 2)
+    So(len(leftY), ShouldEqual, 2)
     So(len(rightData), ShouldEqual, 2)
-    So(len(righty), ShouldEqual, 2)
+    So(len(rightY), ShouldEqual, 2)
 
     // is data reordered correctly
     regressorOrderedData, regressorOrderedY := regressorReOrderData(getFeature(data, 1), data, y)

@@ -4,23 +4,14 @@ import (
     "github.com/sjwhitworth/golearn/base"
 )
 
-// Helper Function to check if data point is unique or not.
-// We will use this to isolate unique values of a feature
-func stringInSlice(a float64, list []float64) bool {
-    for _, b := range list {
-        if b == a {
-            return true
-        }
-    }
-    return false
-}
-
 // Isolate only unique values. This way, we can try only unique splits and not redundant ones.
 func findUnique(data []float64) []float64 {
-    var unique []float64
-    for i := range data {
-        if !stringInSlice(data[i], unique) {
-            unique = append(unique, data[i])
+    keys := make(map[float64]bool)
+    unique := []float64{}
+    for _, entry := range data {
+        if _, value := keys[entry]; !value {
+            keys[entry] = true
+            unique = append(unique, entry)
         }
     }
     return unique

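The same idea removes the quadratic stringInSlice scan: a map now serves as a seen-set, so deduplicating candidate split values is a single pass instead of rescanning the already-collected slice for every element. The new findUnique, copied out as a runnable sketch:

package main

import "fmt"

// findUnique returns the distinct values of data in first-seen order.
// The map lookup replaces a linear search over the collected slice, so the
// whole pass is O(n) expected rather than O(n^2).
func findUnique(data []float64) []float64 {
    keys := make(map[float64]bool)
    unique := []float64{}
    for _, entry := range data {
        if _, value := keys[entry]; !value {
            keys[entry] = true
            unique = append(unique, entry)
        }
    }
    return unique
}

func main() {
    fmt.Println(findUnique([]float64{2.5, 1.0, 2.5, 3.0, 1.0})) // [2.5 1 3]
}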