package evaluation

import (
	"bytes"
	"fmt"
	"text/tabwriter"

	"github.com/sjwhitworth/golearn/base"
)

// ConfusionMatrix is a nested map of actual and predicted class counts.
type ConfusionMatrix map[string]map[string]int

// GetConfusionMatrix builds a ConfusionMatrix from a set of reference (`ref')
// and generated (`gen') Instances.
func GetConfusionMatrix(ref base.FixedDataGrid, gen base.FixedDataGrid) (map[string]map[string]int, error) {
	_, refRows := ref.Size()
	_, genRows := gen.Size()

	if refRows != genRows {
		return nil, fmt.Errorf("Row count mismatch: ref has %d rows, gen has %d rows", refRows, genRows)
	}

	ret := make(map[string]map[string]int)

	for i := 0; i < int(refRows); i++ {
		referenceClass := base.GetClass(ref, i)
		predictedClass := base.GetClass(gen, i)
		if _, ok := ret[referenceClass]; ok {
			ret[referenceClass][predictedClass]++
		} else {
			ret[referenceClass] = make(map[string]int)
			ret[referenceClass][predictedClass] = 1
		}
	}
	return ret, nil
}
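// Usage sketch (illustrative, not part of the original source): given a
// base.FixedDataGrid of reference instances (here called testData) and a
// FixedDataGrid of predicted labels for the same rows (here called
// predictions), the matrix is built like this; the variable names are
// assumptions for the example only.
//
//	cm, err := GetConfusionMatrix(testData, predictions)
//	if err != nil {
//		panic(err) // e.g. a row-count mismatch between the two grids
//	}
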
// GetTruePositives returns the number of times an entry is
// predicted successfully in a given ConfusionMatrix.
func GetTruePositives(class string, c ConfusionMatrix) float64 {
	return float64(c[class][class])
}

// GetFalsePositives returns the number of times an entry is
// incorrectly predicted as having a given class.
func GetFalsePositives(class string, c ConfusionMatrix) float64 {
	ret := 0.0
	for k := range c {
		if k == class {
			continue
		}
		ret += float64(c[k][class])
	}
	return ret
}

// GetFalseNegatives returns the number of times an entry is
// incorrectly predicted as something other than the given class.
func GetFalseNegatives(class string, c ConfusionMatrix) float64 {
	ret := 0.0
	for k := range c[class] {
		if k == class {
			continue
		}
		ret += float64(c[class][k])
	}
	return ret
}

// GetTrueNegatives returns the number of times an entry is
// correctly predicted as something other than the given class.
func GetTrueNegatives(class string, c ConfusionMatrix) float64 {
	ret := 0.0
	for k := range c {
		if k == class {
			continue
		}
		for l := range c[k] {
			if l == class {
				continue
			}
			ret += float64(c[k][l])
		}
	}
	return ret
}
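// Worked example (illustrative, not from the original source): for a
// two-class matrix with reference classes as outer keys and predicted
// classes as inner keys,
//
//	c := ConfusionMatrix{
//		"cat": {"cat": 8, "dog": 2},
//		"dog": {"cat": 1, "dog": 9},
//	}
//
// the counts for class "cat" are TP = c["cat"]["cat"] = 8,
// FP = c["dog"]["cat"] = 1, FN = c["cat"]["dog"] = 2 and
// TN = c["dog"]["dog"] = 9, matching the four functions above.
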
// GetPrecision returns the fraction of the total predictions
// for a given class which were correct.
func GetPrecision(class string, c ConfusionMatrix) float64 {
	// Fraction of retrieved instances that are relevant
	truePositives := GetTruePositives(class, c)
	falsePositives := GetFalsePositives(class, c)
	return truePositives / (truePositives + falsePositives)
}

// GetRecall returns the fraction of the total occurrences of a
// given class which were predicted.
func GetRecall(class string, c ConfusionMatrix) float64 {
	// Fraction of relevant instances that are retrieved
	truePositives := GetTruePositives(class, c)
	falseNegatives := GetFalseNegatives(class, c)
	return truePositives / (truePositives + falseNegatives)
}

// GetF1Score computes the harmonic mean of precision and recall
// (equivalently called F-measure).
func GetF1Score(class string, c ConfusionMatrix) float64 {
	precision := GetPrecision(class, c)
	recall := GetRecall(class, c)
	return 2 * (precision * recall) / (precision + recall)
}
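// Continuing the illustrative two-class example above: precision for
// "cat" is 8/(8+1) ≈ 0.889, recall is 8/(8+2) = 0.8, and the F1 score is
// 2*(0.889*0.8)/(0.889+0.8) ≈ 0.842.
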
// GetAccuracy computes the overall classification accuracy.
// That is (number of correctly classified instances) / (total instances).
func GetAccuracy(c ConfusionMatrix) float64 {
	correct := 0
	total := 0
	for i := range c {
		for j := range c[i] {
			if i == j {
				correct += c[i][j]
			}
			total += c[i][j]
		}
	}
	return float64(correct) / float64(total)
}
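// For the same illustrative matrix, accuracy is the diagonal sum over the
// grand total: (8+9)/(8+2+1+9) = 17/20 = 0.85.
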
// GetMicroPrecision assesses Classifier performance across
// all classes using the total true positives and false positives.
func GetMicroPrecision(c ConfusionMatrix) float64 {
	truePositives := 0.0
	falsePositives := 0.0
	for k := range c {
		truePositives += GetTruePositives(k, c)
		falsePositives += GetFalsePositives(k, c)
	}
	return truePositives / (truePositives + falsePositives)
}

// GetMacroPrecision assesses Classifier performance across all
// classes by averaging the precision measures achieved for each class.
func GetMacroPrecision(c ConfusionMatrix) float64 {
	precisionVals := 0.0
	for k := range c {
		precisionVals += GetPrecision(k, c)
	}
	return precisionVals / float64(len(c))
}

// GetMicroRecall assesses Classifier performance across all
// classes using the total true positives and false negatives.
func GetMicroRecall(c ConfusionMatrix) float64 {
	truePositives := 0.0
	falseNegatives := 0.0
	for k := range c {
		truePositives += GetTruePositives(k, c)
		falseNegatives += GetFalseNegatives(k, c)
	}
	return truePositives / (truePositives + falseNegatives)
}

// GetMacroRecall assesses Classifier performance across all classes
// by averaging the recall measures achieved for each class.
func GetMacroRecall(c ConfusionMatrix) float64 {
	recallVals := 0.0
	for k := range c {
		recallVals += GetRecall(k, c)
	}
	return recallVals / float64(len(c))
}
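// Added commentary on the averaged measures: micro averaging pools the raw
// true/false positive and negative counts across classes before dividing,
// so frequent classes dominate, while macro averaging computes the
// per-class measure first and then takes an unweighted mean, so every
// class counts equally. For the illustrative matrix above, micro precision
// is (8+9)/((8+9)+(1+2)) = 0.85 (coinciding with accuracy, as it always
// does when every instance receives exactly one predicted label), while
// macro precision is (8/9 + 9/11)/2 ≈ 0.854.
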
// GetSummary returns a table of precision, recall, true positives,
// false positives, and true negatives for each class for a given
// ConfusionMatrix.
func GetSummary(c ConfusionMatrix) string {
	var buffer bytes.Buffer
	w := new(tabwriter.Writer)
	w.Init(&buffer, 0, 8, 0, '\t', 0)

	fmt.Fprintln(w, "Reference Class\tTrue Positives\tFalse Positives\tTrue Negatives\tPrecision\tRecall\tF1 Score")
	fmt.Fprintln(w, "---------------\t--------------\t---------------\t--------------\t---------\t------\t--------")
	for k := range c {
		tp := GetTruePositives(k, c)
		fp := GetFalsePositives(k, c)
		tn := GetTrueNegatives(k, c)
		prec := GetPrecision(k, c)
		rec := GetRecall(k, c)
		f1 := GetF1Score(k, c)

		fmt.Fprintf(w, "%s\t%.0f\t%.0f\t%.0f\t%.4f\t%.4f\t%.4f\n", k, tp, fp, tn, prec, rec, f1)
	}
	w.Flush()
	buffer.WriteString(fmt.Sprintf("Overall accuracy: %.4f\n", GetAccuracy(c)))

	return buffer.String()
}
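// For the illustrative two-class matrix above, the table produced by
// GetSummary would look roughly like this (the tab writer pads the
// columns, and rows appear in Go's random map-iteration order):
//
//	Reference Class  True Positives  False Positives  True Negatives  Precision  Recall  F1 Score
//	---------------  --------------  ---------------  --------------  ---------  ------  --------
//	cat              8               1                9               0.8889     0.8000  0.8421
//	dog              9               2                8               0.8182     0.9000  0.8571
//	Overall accuracy: 0.8500
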
// ShowConfusionMatrix returns a human-readable version of a given
// ConfusionMatrix.
func ShowConfusionMatrix(c ConfusionMatrix) string {
	var buffer bytes.Buffer
	w := new(tabwriter.Writer)
	w.Init(&buffer, 0, 8, 0, '\t', 0)

	ref := make([]string, 0)
	fmt.Fprintf(w, "Reference Class\t")
	for k := range c {
		fmt.Fprintf(w, "%s\t", k)
		ref = append(ref, k)
	}
	fmt.Fprintf(w, "\n")

	fmt.Fprintf(w, "---------------\t")
	for _, v := range ref {
		for t := 0; t < len(v); t++ {
			fmt.Fprintf(w, "-")
		}
		fmt.Fprintf(w, "\t")
	}
	fmt.Fprintf(w, "\n")

	for _, v := range ref {
		fmt.Fprintf(w, "%s\t", v)
		for _, v2 := range ref {
			fmt.Fprintf(w, "%d\t", c[v][v2])
		}
		fmt.Fprintf(w, "\n")
	}
	w.Flush()

	return buffer.String()
}
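// End-to-end usage sketch (illustrative only): given reference instances
// and predictions as in the GetConfusionMatrix example above, the matrix
// can be rendered and queried like this; the class name "cat" is an
// assumption for the example.
//
//	cm, err := GetConfusionMatrix(testData, predictions)
//	if err != nil {
//		panic(err)
//	}
//	fmt.Println(ShowConfusionMatrix(cm))
//	fmt.Println(GetSummary(cm))
//	fmt.Printf("F1 for %q: %.3f\n", "cat", GetF1Score("cat", cm))
//
// ShowConfusionMatrix prints the raw counts with reference classes as rows
// and predicted classes as columns, while GetSummary reports the per-class
// measures defined above.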