mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-28 13:48:56 +08:00
Add test cases for cross_fold
This commit is contained in:
parent
e92e615e79
commit
7255c67138
24
evaluation/cross_fold_test.go
Normal file
24
evaluation/cross_fold_test.go
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
package evaluation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/sjwhitworth/golearn/base"
|
||||||
|
"github.com/sjwhitworth/golearn/trees"
|
||||||
|
. "github.com/smartystreets/goconvey/convey"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCrossFold(t *testing.T) {
|
||||||
|
Convey("Cross Fold Evaluation", t, func() {
|
||||||
|
iris, _ := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
||||||
|
tree := trees.NewID3DecisionTree(0.6)
|
||||||
|
cfs, _ := GenerateCrossFoldValidationConfusionMatrices(iris, tree, 5)
|
||||||
|
Convey("Cross Fold Validation Confusion Matrices", func() {
|
||||||
|
So(cfs, ShouldNotBeEmpty)
|
||||||
|
})
|
||||||
|
Convey("Cross Validated Metric", func() {
|
||||||
|
mean, variance := GetCrossValidatedMetric(cfs, GetAccuracy)
|
||||||
|
So(mean, ShouldBeBetween, 0.8, 1)
|
||||||
|
So(variance, ShouldBeBetween, 0, 0.03)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user