package evaluation

import (
	"bytes"
	"fmt"
	"text/tabwriter"

	"github.com/sjwhitworth/golearn/base"
)

// ConfusionMatrix is a nested map of actual (outer key) and predicted
// (inner key) class counts.
type ConfusionMatrix map[string]map[string]int

// GetConfusionMatrix builds a ConfusionMatrix from a set of reference (`ref')
// and generated (`gen') Instances.
func GetConfusionMatrix(ref base.FixedDataGrid, gen base.FixedDataGrid) (map[string]map[string]int, error) {
	_, refRows := ref.Size()
	_, genRows := gen.Size()

	if refRows != genRows {
		return nil, fmt.Errorf("Row count mismatch: ref has %d rows, gen has %d rows", refRows, genRows)
	}

	ret := make(map[string]map[string]int)

	for i := 0; i < int(refRows); i++ {
		referenceClass := base.GetClass(ref, i)
		predictedClass := base.GetClass(gen, i)
		if _, ok := ret[referenceClass]; ok {
			ret[referenceClass][predictedClass]++
		} else {
			ret[referenceClass] = make(map[string]int)
			ret[referenceClass][predictedClass] = 1
		}
	}
	return ret, nil
}
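
// A minimal usage sketch (hypothetical variable names, assuming `testData'
// holds the reference Instances and `predictions' the classifier output,
// with equal row counts):
//
//	cf, err := evaluation.GetConfusionMatrix(testData, predictions)
//	if err != nil {
//		panic(err)
//	}
//	fmt.Println(evaluation.GetSummary(cf))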

// GetTruePositives returns the number of times an entry is
// predicted successfully in a given ConfusionMatrix.
func GetTruePositives(class string, c ConfusionMatrix) float64 {
	return float64(c[class][class])
}

// GetFalsePositives returns the number of times an entry is
// incorrectly predicted as having a given class.
func GetFalsePositives(class string, c ConfusionMatrix) float64 {
	ret := 0.0
	for k := range c {
		if k == class {
			continue
		}
		ret += float64(c[k][class])
	}
	return ret
}

// GetFalseNegatives returns the number of times an entry is
// incorrectly predicted as something other than the given class.
func GetFalseNegatives(class string, c ConfusionMatrix) float64 {
	ret := 0.0
	for k := range c[class] {
		if k == class {
			continue
		}
		ret += float64(c[class][k])
	}
	return ret
}

// GetTrueNegatives returns the number of times an entry is
// correctly predicted as something other than the given class.
func GetTrueNegatives(class string, c ConfusionMatrix) float64 {
	ret := 0.0
	for k := range c {
		if k == class {
			continue
		}
		for l := range c[k] {
			if l == class {
				continue
			}
			ret += float64(c[k][l])
		}
	}
	return ret
}
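
// As an illustration (a hypothetical two-class matrix, not taken from the
// library's tests), given
//
//	c := ConfusionMatrix{
//		"a": {"a": 5, "b": 1},
//		"b": {"a": 2, "b": 4},
//	}
//
// GetTruePositives("a", c) == 5, GetFalsePositives("a", c) == 2,
// GetFalseNegatives("a", c) == 1 and GetTrueNegatives("a", c) == 4.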

// GetPrecision returns the fraction of the total predictions
// for a given class which were correct.
func GetPrecision(class string, c ConfusionMatrix) float64 {
	// Fraction of retrieved instances that are relevant
	truePositives := GetTruePositives(class, c)
	falsePositives := GetFalsePositives(class, c)
	return truePositives / (truePositives + falsePositives)
}
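
// For class "a" in the hypothetical matrix above, precision is
// 5 / (5 + 2) ≈ 0.714.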

// GetRecall returns the fraction of the total occurrences of a
// given class which were predicted.
func GetRecall(class string, c ConfusionMatrix) float64 {
	// Fraction of relevant instances that are retrieved
	truePositives := GetTruePositives(class, c)
	falseNegatives := GetFalseNegatives(class, c)
	return truePositives / (truePositives + falseNegatives)
}
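
// For class "a" in the same hypothetical matrix, recall is
// 5 / (5 + 1) ≈ 0.833.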

// GetF1Score computes the harmonic mean of precision and recall
// (equivalently called the F-measure).
func GetF1Score(class string, c ConfusionMatrix) float64 {
	precision := GetPrecision(class, c)
	recall := GetRecall(class, c)
	return 2 * (precision * recall) / (precision + recall)
}
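
// Continuing the hypothetical example, the F1 score for class "a" is
// 2 * (0.714 * 0.833) / (0.714 + 0.833) ≈ 0.769.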

// GetAccuracy computes the overall classification accuracy, that is,
// (number of correctly classified instances) / (total instances).
func GetAccuracy(c ConfusionMatrix) float64 {
	correct := 0
	total := 0
	for i := range c {
		for j := range c[i] {
			if i == j {
				correct += c[i][j]
			}
			total += c[i][j]
		}
	}
	return float64(correct) / float64(total)
}
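
// In the hypothetical two-class matrix above, accuracy is
// (5 + 4) / (5 + 1 + 2 + 4) = 0.75.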

// GetMicroPrecision assesses Classifier performance across
// all classes using the total true positives and false positives.
func GetMicroPrecision(c ConfusionMatrix) float64 {
	truePositives := 0.0
	falsePositives := 0.0
	for k := range c {
		truePositives += GetTruePositives(k, c)
		falsePositives += GetFalsePositives(k, c)
	}
	return truePositives / (truePositives + falsePositives)
}

// GetMacroPrecision assesses Classifier performance across all
// classes by averaging the precision measures achieved for each class.
func GetMacroPrecision(c ConfusionMatrix) float64 {
	precisionVals := 0.0
	for k := range c {
		precisionVals += GetPrecision(k, c)
	}
	return precisionVals / float64(len(c))
}

// GetMicroRecall assesses Classifier performance across all
// classes using the total true positives and false negatives.
func GetMicroRecall(c ConfusionMatrix) float64 {
	truePositives := 0.0
	falseNegatives := 0.0
	for k := range c {
		truePositives += GetTruePositives(k, c)
		falseNegatives += GetFalseNegatives(k, c)
	}
	return truePositives / (truePositives + falseNegatives)
}

// GetMacroRecall assesses Classifier performance across all classes
// by averaging the recall measures achieved for each class.
func GetMacroRecall(c ConfusionMatrix) float64 {
	recallVals := 0.0
	for k := range c {
		recallVals += GetRecall(k, c)
	}
	return recallVals / float64(len(c))
}
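
// Micro averages weight every instance equally, while macro averages weight
// every class equally. For the hypothetical matrix above, micro precision is
// (5 + 4) / (9 + 3) = 0.75, whereas macro precision is
// (5/7 + 4/5) / 2 ≈ 0.757; micro and macro recall both happen to equal 0.75.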

// GetSummary returns a table of true positives, false positives,
// true negatives, precision, recall and F1 score for each class
// in a given ConfusionMatrix.
func GetSummary(c ConfusionMatrix) string {
	var buffer bytes.Buffer
	w := new(tabwriter.Writer)
	w.Init(&buffer, 0, 8, 0, '\t', 0)

	fmt.Fprintln(w, "Reference Class\tTrue Positives\tFalse Positives\tTrue Negatives\tPrecision\tRecall\tF1 Score")
	fmt.Fprintln(w, "---------------\t--------------\t---------------\t--------------\t---------\t------\t--------")
	for k := range c {
		tp := GetTruePositives(k, c)
		fp := GetFalsePositives(k, c)
		tn := GetTrueNegatives(k, c)
		prec := GetPrecision(k, c)
		rec := GetRecall(k, c)
		f1 := GetF1Score(k, c)
		fmt.Fprintf(w, "%s\t%.0f\t%.0f\t%.0f\t%.4f\t%.4f\t%.4f\n", k, tp, fp, tn, prec, rec, f1)
	}
	w.Flush()
	buffer.WriteString(fmt.Sprintf("Overall accuracy: %.4f\n", GetAccuracy(c)))

	return buffer.String()
}

// ShowConfusionMatrix returns a human-readable version of a given
// ConfusionMatrix.
func ShowConfusionMatrix(c ConfusionMatrix) string {
	var buffer bytes.Buffer
	w := new(tabwriter.Writer)
	w.Init(&buffer, 0, 8, 0, '\t', 0)

	// Header row: one column per reference class, in map iteration order.
	ref := make([]string, 0)
	fmt.Fprintf(w, "Reference Class\t")
	for k := range c {
		fmt.Fprintf(w, "%s\t", k)
		ref = append(ref, k)
	}
	fmt.Fprintf(w, "\n")

	// Separator row: one run of dashes per class name.
	fmt.Fprintf(w, "---------------\t")
	for _, v := range ref {
		for t := 0; t < len(v); t++ {
			fmt.Fprintf(w, "-")
		}
		fmt.Fprintf(w, "\t")
	}
	fmt.Fprintf(w, "\n")

	// One row per reference class, with predicted-class counts as columns.
	for _, v := range ref {
		fmt.Fprintf(w, "%s\t", v)
		for _, v2 := range ref {
			fmt.Fprintf(w, "%d\t", c[v][v2])
		}
		fmt.Fprintf(w, "\n")
	}
	w.Flush()
	return buffer.String()
}
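
// A minimal display sketch (assuming `cf' was produced by GetConfusionMatrix
// as above). Note that the row and column order follows Go's map iteration
// order, so it can differ between calls:
//
//	fmt.Println(evaluation.ShowConfusionMatrix(cf))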