mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-25 13:48:49 +08:00
108 lines
3.5 KiB
Go
108 lines
3.5 KiB
Go
package evaluation
|
|
|
|
import (
|
|
"github.com/sjwhitworth/golearn/base"
|
|
. "github.com/smartystreets/goconvey/convey"
|
|
"testing"
|
|
)
|
|
|
|
func TestMetrics(t *testing.T) {
|
|
Convey("Quantities derived from a confusion matrix", t, func() {
|
|
confusionMat := make(ConfusionMatrix)
|
|
confusionMat["a"] = make(map[string]int)
|
|
confusionMat["b"] = make(map[string]int)
|
|
confusionMat["a"]["a"] = 75
|
|
confusionMat["a"]["b"] = 5
|
|
confusionMat["b"]["a"] = 10
|
|
confusionMat["b"]["b"] = 10
|
|
|
|
Convey("True Positives", func() {
|
|
So(GetTruePositives("a", confusionMat), ShouldAlmostEqual, 75, 1)
|
|
So(GetTruePositives("b", confusionMat), ShouldAlmostEqual, 10, 1)
|
|
})
|
|
|
|
Convey("True Negatives", func() {
|
|
So(GetTrueNegatives("a", confusionMat), ShouldAlmostEqual, 10, 1)
|
|
So(GetTrueNegatives("b", confusionMat), ShouldAlmostEqual, 75, 1)
|
|
})
|
|
|
|
Convey("False Positives", func() {
|
|
So(GetFalsePositives("a", confusionMat), ShouldAlmostEqual, 10, 1)
|
|
So(GetFalsePositives("b", confusionMat), ShouldAlmostEqual, 5, 1)
|
|
})
|
|
|
|
Convey("False Negatives", func() {
|
|
So(GetFalseNegatives("a", confusionMat), ShouldAlmostEqual, 5, 1)
|
|
So(GetFalseNegatives("b", confusionMat), ShouldAlmostEqual, 10, 1)
|
|
})
|
|
|
|
Convey("Precision", func() {
|
|
So(GetPrecision("a", confusionMat), ShouldAlmostEqual, 0.88, 0.01)
|
|
So(GetPrecision("b", confusionMat), ShouldAlmostEqual, 0.666, 0.01)
|
|
})
|
|
|
|
Convey("Recall", func() {
|
|
So(GetRecall("a", confusionMat), ShouldAlmostEqual, 0.94, 0.01)
|
|
So(GetRecall("b", confusionMat), ShouldAlmostEqual, 0.50, 0.01)
|
|
})
|
|
|
|
Convey("MicroPrecision", func() {
|
|
So(GetMicroPrecision(confusionMat), ShouldAlmostEqual, 0.85, 0.01)
|
|
})
|
|
|
|
Convey("MicroRecall", func() {
|
|
So(GetMicroRecall(confusionMat), ShouldAlmostEqual, 0.85, 0.01)
|
|
})
|
|
|
|
Convey("MacroPrecision", func() {
|
|
So(GetMacroPrecision(confusionMat), ShouldAlmostEqual, 0.775, 0.01)
|
|
})
|
|
|
|
Convey("MacroRecall", func() {
|
|
So(GetMacroRecall(confusionMat), ShouldAlmostEqual, 0.719, 0.01)
|
|
})
|
|
|
|
Convey("F1Score", func() {
|
|
So(GetF1Score("a", confusionMat), ShouldAlmostEqual, 0.91, 0.1)
|
|
So(GetF1Score("b", confusionMat), ShouldAlmostEqual, 0.571, 0.01)
|
|
})
|
|
|
|
Convey("Accuracy", func() {
|
|
So(GetAccuracy(confusionMat), ShouldAlmostEqual, 0.85, 0.1)
|
|
})
|
|
|
|
Convey("Get Summary", func() {
|
|
output := GetSummary(confusionMat)
|
|
So(output, ShouldStartWith, "Reference Class")
|
|
So(output, ShouldContainSubstring, "True Positives")
|
|
So(output, ShouldContainSubstring, "False Positives")
|
|
So(output, ShouldContainSubstring, "True Negatives")
|
|
So(output, ShouldContainSubstring, "Precision")
|
|
So(output, ShouldContainSubstring, "Recall")
|
|
So(output, ShouldContainSubstring, "F1 Score")
|
|
So(output, ShouldContainSubstring, "------")
|
|
So(output, ShouldContainSubstring, "Overall accuracy:")
|
|
})
|
|
|
|
Convey("Show Confusion Matrix", func() {
|
|
output := ShowConfusionMatrix(confusionMat)
|
|
So(output, ShouldStartWith, "Reference Class")
|
|
So(output, ShouldContainSubstring, "---------------")
|
|
})
|
|
|
|
Convey("Get Confusion Matrix", func() {
|
|
X, _ := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
|
Y, _ := base.ParseCSVToInstances("../examples/datasets/exam.csv", true)
|
|
Convey("Nomarl ref and gen matrices", func() {
|
|
out, _ := GetConfusionMatrix(X, X)
|
|
ret := make(map[string]map[string]int)
|
|
So(out, ShouldHaveSameTypeAs, ret)
|
|
})
|
|
Convey("Row count mismatch", func() {
|
|
_, err := GetConfusionMatrix(X, Y)
|
|
So(err.Error(), ShouldStartWith, "Row count mismatch:")
|
|
})
|
|
})
|
|
})
|
|
}
|