mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
Cross-fold validation
This commit is contained in:
parent
fcb96f1fad
commit
b2f5b2840d
82
evaluation/cross_fold.go
Normal file
82
evaluation/cross_fold.go
Normal file
@ -0,0 +1,82 @@
|
||||
package evaluation
|
||||
|
||||
import (
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"math/rand"
|
||||
)
|
||||
|
||||
// GetCrossValidatedMetric returns the mean and variance of the confusion-matrix-derived
|
||||
// metric across all folds.
|
||||
func GetCrossValidatedMetric(in []ConfusionMatrix, metric func(ConfusionMatrix) float64) (mean, variance float64) {
|
||||
scores := make([]float64, len(in))
|
||||
for i, c := range in {
|
||||
scores[i] = metric(c)
|
||||
}
|
||||
|
||||
// Compute mean, variance
|
||||
sum := 0.0
|
||||
for _, s := range scores {
|
||||
sum += s
|
||||
}
|
||||
sum /= float64(len(scores))
|
||||
mean = sum
|
||||
sum = 0.0
|
||||
for _, s := range scores {
|
||||
sum += (s - mean) * (s - mean)
|
||||
}
|
||||
sum /= float64(len(scores))
|
||||
variance = sum
|
||||
return mean, variance
|
||||
}
|
||||
|
||||
// GenerateCrossFoldValidationConfusionMatrices divides the data into a number of folds
|
||||
// then trains and evaluates the classifier on each fold, producing a new ConfusionMatrix.
|
||||
func GenerateCrossFoldValidationConfusionMatrices(data base.FixedDataGrid, cls base.Classifier, folds int) ([]ConfusionMatrix, error) {
|
||||
_, rows := data.Size()
|
||||
|
||||
// Assign each row to a fold
|
||||
foldMap := make([]int, rows)
|
||||
inverseFoldMap := make(map[int][]int)
|
||||
for i := 0; i < rows; i++ {
|
||||
fold := rand.Intn(folds)
|
||||
foldMap[i] = fold
|
||||
if _, ok := inverseFoldMap[fold]; !ok {
|
||||
inverseFoldMap[fold] = make([]int, 0)
|
||||
}
|
||||
inverseFoldMap[fold] = append(inverseFoldMap[fold], i)
|
||||
}
|
||||
|
||||
ret := make([]ConfusionMatrix, folds)
|
||||
|
||||
// Create training/test views for each fold
|
||||
for i := 0; i < folds; i++ {
|
||||
// Fold i is for testing
|
||||
testData := base.NewInstancesViewFromVisible(data, inverseFoldMap[i], data.AllAttributes())
|
||||
otherRows := make([]int, 0)
|
||||
for j := 0; j < folds; j++ {
|
||||
if i == j {
|
||||
continue
|
||||
}
|
||||
otherRows = append(otherRows, inverseFoldMap[j]...)
|
||||
}
|
||||
trainData := base.NewInstancesViewFromVisible(data, otherRows, data.AllAttributes())
|
||||
// Train
|
||||
err := cls.Fit(trainData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Predict
|
||||
pred, err := cls.Predict(testData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Evaluate
|
||||
cf, err := GetConfusionMatrix(testData, pred)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret[i] = cf
|
||||
}
|
||||
return ret, nil
|
||||
|
||||
}
|
40
examples/crossfold/rf.go
Normal file
40
examples/crossfold/rf.go
Normal file
@ -0,0 +1,40 @@
|
||||
// Demonstrates decision tree classification
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"github.com/sjwhitworth/golearn/ensemble"
|
||||
"github.com/sjwhitworth/golearn/evaluation"
|
||||
"math"
|
||||
"math/rand"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
var tree base.Classifier
|
||||
|
||||
// Load in the iris dataset
|
||||
iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for i := 1; i < 60; i += 2 {
|
||||
// Demonstrate the effect of adding more trees to the forest
|
||||
// and also how much better it is without discretisation.
|
||||
rand.Seed(44111342)
|
||||
|
||||
tree = ensemble.NewRandomForest(i, 4)
|
||||
cfs, err := evaluation.GenerateCrossFoldValidationConfusionMatrices(iris, tree, 5)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
mean, variance := evaluation.GetCrossValidatedMetric(cfs, evaluation.GetAccuracy)
|
||||
stdev := math.Sqrt(variance)
|
||||
|
||||
fmt.Printf("%d\t%.2f\t(+/- %.2f)\n", i, mean, stdev*2)
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user