From b2f5b2840d37bb2f64a4c35219759fffb17e8f89 Mon Sep 17 00:00:00 2001
From: Richard Townsend
Date: Sun, 26 Oct 2014 17:37:48 +0000
Subject: [PATCH] Cross-fold validation

---
 evaluation/cross_fold.go | 98 ++++++++++++++++++++++++++++++++++++++++
 examples/crossfold/rf.go | 39 ++++++++++++++++
 2 files changed, 137 insertions(+)
 create mode 100644 evaluation/cross_fold.go
 create mode 100644 examples/crossfold/rf.go

diff --git a/evaluation/cross_fold.go b/evaluation/cross_fold.go
new file mode 100644
index 0000000..c1fab5f
--- /dev/null
+++ b/evaluation/cross_fold.go
@@ -0,0 +1,98 @@
+package evaluation
+
+import (
+	"fmt"
+	"github.com/sjwhitworth/golearn/base"
+	"math/rand"
+)
+
+// GetCrossValidatedMetric returns the mean and variance of the confusion-matrix-derived
+// metric across all folds.
+//
+// The variance is the population variance: the squared deviations are divided
+// by the number of folds, not by folds-1. When in is empty there is nothing to
+// average, so both results are zero instead of the NaN that 0/0 would produce.
+func GetCrossValidatedMetric(in []ConfusionMatrix, metric func(ConfusionMatrix) float64) (mean, variance float64) {
+	if len(in) == 0 {
+		return 0, 0
+	}
+	n := float64(len(in))
+
+	// Evaluate the metric once per fold, accumulating the sum for the mean.
+	scores := make([]float64, len(in))
+	for i, c := range in {
+		scores[i] = metric(c)
+		mean += scores[i]
+	}
+	mean /= n
+
+	// Second pass: average squared deviation from the mean.
+	for _, s := range scores {
+		d := s - mean
+		variance += d * d
+	}
+	variance /= n
+	return mean, variance
+}
+
+// GenerateCrossFoldValidationConfusionMatrices divides the data into a number of folds
+// then trains and evaluates the classifier on each fold, producing one ConfusionMatrix
+// per fold.
+//
+// Rows are shuffled with math/rand and dealt out to the folds in turn, so fold
+// sizes differ by at most one row and no fold is empty. Seed math/rand first
+// for a reproducible split.
+//
+// An error is returned if folds is less than 2, if data has fewer rows than
+// folds, or if fitting, predicting or evaluating fails on any fold.
+func GenerateCrossFoldValidationConfusionMatrices(data base.FixedDataGrid, cls base.Classifier, folds int) ([]ConfusionMatrix, error) {
+	_, rows := data.Size()
+
+	// A single fold leaves nothing to train on, and more folds than rows
+	// leaves some folds with nothing to test on, so reject both up front.
+	if folds < 2 {
+		return nil, fmt.Errorf("cross-fold validation needs at least 2 folds, got %d", folds)
+	}
+	if rows < folds {
+		return nil, fmt.Errorf("cannot split %d rows into %d folds", rows, folds)
+	}
+
+	// Deal the shuffled row indices out round-robin so the folds are balanced.
+	foldRows := make([][]int, folds)
+	for pos, row := range rand.Perm(rows) {
+		foldRows[pos%folds] = append(foldRows[pos%folds], row)
+	}
+
+	ret := make([]ConfusionMatrix, folds)
+
+	// Create training/test views for each fold
+	for i := 0; i < folds; i++ {
+		// Fold i is held out for testing; all other folds form the training set.
+		testData := base.NewInstancesViewFromVisible(data, foldRows[i], data.AllAttributes())
+		trainRows := make([]int, 0, rows-len(foldRows[i]))
+		for j, fold := range foldRows {
+			if j != i {
+				trainRows = append(trainRows, fold...)
+			}
+		}
+		trainData := base.NewInstancesViewFromVisible(data, trainRows, data.AllAttributes())
+
+		// Train. The same classifier is refitted for every fold.
+		// NOTE(review): assumes Fit discards state left by earlier folds - confirm.
+		if err := cls.Fit(trainData); err != nil {
+			return nil, err
+		}
+		// Predict
+		pred, err := cls.Predict(testData)
+		if err != nil {
+			return nil, err
+		}
+		// Evaluate
+		cf, err := GetConfusionMatrix(testData, pred)
+		if err != nil {
+			return nil, err
+		}
+		ret[i] = cf
+	}
+	return ret, nil
+}
diff --git a/examples/crossfold/rf.go b/examples/crossfold/rf.go
new file mode 100644
index 0000000..5ba8942
--- /dev/null
+++ b/examples/crossfold/rf.go
@@ -0,0 +1,39 @@
+// Demonstrates cross-fold validation of a random forest classifier
+
+package main
+
+import (
+	"fmt"
+	"github.com/sjwhitworth/golearn/base"
+	"github.com/sjwhitworth/golearn/ensemble"
+	"github.com/sjwhitworth/golearn/evaluation"
+	"math"
+	"math/rand"
+)
+
+func main() {
+	// Load in the iris dataset
+	iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
+	if err != nil {
+		panic(err)
+	}
+
+	// Demonstrate the effect of adding more trees to the forest.
+	for trees := 1; trees < 60; trees += 2 {
+		// Re-seed before every run so that each forest size starts from the
+		// same random state and the results are repeatable.
+		rand.Seed(44111342)
+
+		forest := ensemble.NewRandomForest(trees, 4)
+		cfs, err := evaluation.GenerateCrossFoldValidationConfusionMatrices(iris, forest, 5)
+		if err != nil {
+			panic(err)
+		}
+
+		mean, variance := evaluation.GetCrossValidatedMetric(cfs, evaluation.GetAccuracy)
+		stdev := math.Sqrt(variance)
+
+		// Columns: number of trees, mean accuracy, two standard deviations.
+		fmt.Printf("%d\t%.2f\t(+/- %.2f)\n", trees, mean, stdev*2)
+	}
+}