mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
Cross-fold validation
This commit is contained in:
parent
fcb96f1fad
commit
b2f5b2840d
82
evaluation/cross_fold.go
Normal file
82
evaluation/cross_fold.go
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
package evaluation
|
||||||
|
|
||||||
|
import (
	"fmt"
	"math/rand"

	"github.com/sjwhitworth/golearn/base"
)
|
||||||
|
|
||||||
|
// GetCrossValidatedMetric returns the mean and variance of the confusion-matrix-derived
|
||||||
|
// metric across all folds.
|
||||||
|
func GetCrossValidatedMetric(in []ConfusionMatrix, metric func(ConfusionMatrix) float64) (mean, variance float64) {
|
||||||
|
scores := make([]float64, len(in))
|
||||||
|
for i, c := range in {
|
||||||
|
scores[i] = metric(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute mean, variance
|
||||||
|
sum := 0.0
|
||||||
|
for _, s := range scores {
|
||||||
|
sum += s
|
||||||
|
}
|
||||||
|
sum /= float64(len(scores))
|
||||||
|
mean = sum
|
||||||
|
sum = 0.0
|
||||||
|
for _, s := range scores {
|
||||||
|
sum += (s - mean) * (s - mean)
|
||||||
|
}
|
||||||
|
sum /= float64(len(scores))
|
||||||
|
variance = sum
|
||||||
|
return mean, variance
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateCrossFoldValidationConfusionMatrices divides the data into a number of folds
|
||||||
|
// then trains and evaluates the classifier on each fold, producing a new ConfusionMatrix.
|
||||||
|
func GenerateCrossFoldValidationConfusionMatrices(data base.FixedDataGrid, cls base.Classifier, folds int) ([]ConfusionMatrix, error) {
|
||||||
|
_, rows := data.Size()
|
||||||
|
|
||||||
|
// Assign each row to a fold
|
||||||
|
foldMap := make([]int, rows)
|
||||||
|
inverseFoldMap := make(map[int][]int)
|
||||||
|
for i := 0; i < rows; i++ {
|
||||||
|
fold := rand.Intn(folds)
|
||||||
|
foldMap[i] = fold
|
||||||
|
if _, ok := inverseFoldMap[fold]; !ok {
|
||||||
|
inverseFoldMap[fold] = make([]int, 0)
|
||||||
|
}
|
||||||
|
inverseFoldMap[fold] = append(inverseFoldMap[fold], i)
|
||||||
|
}
|
||||||
|
|
||||||
|
ret := make([]ConfusionMatrix, folds)
|
||||||
|
|
||||||
|
// Create training/test views for each fold
|
||||||
|
for i := 0; i < folds; i++ {
|
||||||
|
// Fold i is for testing
|
||||||
|
testData := base.NewInstancesViewFromVisible(data, inverseFoldMap[i], data.AllAttributes())
|
||||||
|
otherRows := make([]int, 0)
|
||||||
|
for j := 0; j < folds; j++ {
|
||||||
|
if i == j {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
otherRows = append(otherRows, inverseFoldMap[j]...)
|
||||||
|
}
|
||||||
|
trainData := base.NewInstancesViewFromVisible(data, otherRows, data.AllAttributes())
|
||||||
|
// Train
|
||||||
|
err := cls.Fit(trainData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Predict
|
||||||
|
pred, err := cls.Predict(testData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Evaluate
|
||||||
|
cf, err := GetConfusionMatrix(testData, pred)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ret[i] = cf
|
||||||
|
}
|
||||||
|
return ret, nil
|
||||||
|
|
||||||
|
}
|
40
examples/crossfold/rf.go
Normal file
40
examples/crossfold/rf.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
// Demonstrates cross-fold validation of a random forest classifier
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/sjwhitworth/golearn/base"
|
||||||
|
"github.com/sjwhitworth/golearn/ensemble"
|
||||||
|
"github.com/sjwhitworth/golearn/evaluation"
|
||||||
|
"math"
|
||||||
|
"math/rand"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
|
||||||
|
var tree base.Classifier
|
||||||
|
|
||||||
|
// Load in the iris dataset
|
||||||
|
iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 1; i < 60; i += 2 {
|
||||||
|
// Demonstrate the effect of adding more trees to the forest
|
||||||
|
// and also how much better it is without discretisation.
|
||||||
|
rand.Seed(44111342)
|
||||||
|
|
||||||
|
tree = ensemble.NewRandomForest(i, 4)
|
||||||
|
cfs, err := evaluation.GenerateCrossFoldValidationConfusionMatrices(iris, tree, 5)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mean, variance := evaluation.GetCrossValidatedMetric(cfs, evaluation.GetAccuracy)
|
||||||
|
stdev := math.Sqrt(variance)
|
||||||
|
|
||||||
|
fmt.Printf("%d\t%.2f\t(+/- %.2f)\n", i, mean, stdev*2)
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user