1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00

Merge pull request #90 from Sentimentron/cross-fold-staging

Cross-fold validation
This commit is contained in:
Stephen Whitworth 2014-11-21 13:53:29 +00:00
commit 9c7049ba89
2 changed files with 122 additions and 0 deletions

82
evaluation/cross_fold.go Normal file
View File

@ -0,0 +1,82 @@
package evaluation
import (
"github.com/sjwhitworth/golearn/base"
"math/rand"
)
// GetCrossValidatedMetric returns the mean and variance of the confusion-matrix-derived
// metric across all folds.
func GetCrossValidatedMetric(in []ConfusionMatrix, metric func(ConfusionMatrix) float64) (mean, variance float64) {
scores := make([]float64, len(in))
for i, c := range in {
scores[i] = metric(c)
}
// Compute mean, variance
sum := 0.0
for _, s := range scores {
sum += s
}
sum /= float64(len(scores))
mean = sum
sum = 0.0
for _, s := range scores {
sum += (s - mean) * (s - mean)
}
sum /= float64(len(scores))
variance = sum
return mean, variance
}
// GenerateCrossFoldValidationConfusionMatrices divides the data into a number of folds
// then trains and evaluates the classifier on each fold, producing a new ConfusionMatrix.
func GenerateCrossFoldValidationConfusionMatrices(data base.FixedDataGrid, cls base.Classifier, folds int) ([]ConfusionMatrix, error) {
_, rows := data.Size()
// Assign each row to a fold
foldMap := make([]int, rows)
inverseFoldMap := make(map[int][]int)
for i := 0; i < rows; i++ {
fold := rand.Intn(folds)
foldMap[i] = fold
if _, ok := inverseFoldMap[fold]; !ok {
inverseFoldMap[fold] = make([]int, 0)
}
inverseFoldMap[fold] = append(inverseFoldMap[fold], i)
}
ret := make([]ConfusionMatrix, folds)
// Create training/test views for each fold
for i := 0; i < folds; i++ {
// Fold i is for testing
testData := base.NewInstancesViewFromVisible(data, inverseFoldMap[i], data.AllAttributes())
otherRows := make([]int, 0)
for j := 0; j < folds; j++ {
if i == j {
continue
}
otherRows = append(otherRows, inverseFoldMap[j]...)
}
trainData := base.NewInstancesViewFromVisible(data, otherRows, data.AllAttributes())
// Train
err := cls.Fit(trainData)
if err != nil {
return nil, err
}
// Predict
pred, err := cls.Predict(testData)
if err != nil {
return nil, err
}
// Evaluate
cf, err := GetConfusionMatrix(testData, pred)
if err != nil {
return nil, err
}
ret[i] = cf
}
return ret, nil
}

40
examples/crossfold/rf.go Normal file
View File

@ -0,0 +1,40 @@
// Demonstrates decision tree classification
package main
import (
"fmt"
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/ensemble"
"github.com/sjwhitworth/golearn/evaluation"
"math"
"math/rand"
)
func main() {
var tree base.Classifier
// Load in the iris dataset
iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
if err != nil {
panic(err)
}
for i := 1; i < 60; i += 2 {
// Demonstrate the effect of adding more trees to the forest
// and also how much better it is without discretisation.
rand.Seed(44111342)
tree = ensemble.NewRandomForest(i, 4)
cfs, err := evaluation.GenerateCrossFoldValidationConfusionMatrices(iris, tree, 5)
if err != nil {
panic(err)
}
mean, variance := evaluation.GetCrossValidatedMetric(cfs, evaluation.GetAccuracy)
stdev := math.Sqrt(variance)
fmt.Printf("%d\t%.2f\t(+/- %.2f)\n", i, mean, stdev*2)
}
}