diff --git a/evaluation/confusion.go b/evaluation/confusion.go index 3b00224..87fe2b1 100644 --- a/evaluation/confusion.go +++ b/evaluation/confusion.go @@ -11,19 +11,22 @@ type ConfusionMatrix map[string]map[string]int // GetConfusionMatrix builds a ConfusionMatrix from a set of reference (`ref') // and generate (`gen') Instances. -func GetConfusionMatrix(ref *base.Instances, gen *base.Instances) map[string]map[string]int { +func GetConfusionMatrix(ref base.FixedDataGrid, gen base.FixedDataGrid) map[string]map[string]int { - if ref.Rows != gen.Rows { + _, refRows := ref.Size() + _, genRows := gen.Size() + + if refRows != genRows { panic("Row counts should match") } ret := make(map[string]map[string]int) - for i := 0; i < ref.Rows; i++ { - referenceClass := ref.GetClass(i) - predictedClass := gen.GetClass(i) + for i := 0; i < int(refRows); i++ { + referenceClass := base.GetClass(ref, i) + predictedClass := base.GetClass(gen, i) if _, ok := ret[referenceClass]; ok { - ret[referenceClass][predictedClass]++ + ret[referenceClass][predictedClass] += 1 } else { ret[referenceClass] = make(map[string]int) ret[referenceClass][predictedClass] = 1