mirror of
https://github.com/mainflux/mainflux.git
synced 2025-04-27 13:48:49 +08:00

Signed-off-by: dusanb94 <dusan.borovcanin@mainflux.com> Signed-off-by: dusanb94 <dusan.borovcanin@mainflux.com>
202 lines
6.5 KiB
Go
202 lines
6.5 KiB
Go
// Copyright ©2016 The Gonum Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package stat
|
|
|
|
import (
|
|
"math"
|
|
"sort"
|
|
)
|
|
|
|
// ROC returns paired false positive rate (FPR) and true positive rate
// (TPR) values corresponding to cutoff points on the receiver operating
// characteristic (ROC) curve obtained when y is treated as a binary
// classifier for classes with weights. The cutoff thresholds used to
// calculate the ROC are returned in thresh, in descending order, such that
// tpr[i] and fpr[i] are the true and false positive rates for y >= thresh[i].
//
// The input y and cutoffs must be sorted ascending, otherwise ROC will
// panic. Values in y must correspond to values in classes and weights.
// SortWeightedLabeled can be used to sort y together with classes and
// weights.
//
// For a given cutoff value, observations corresponding to entries in y
// greater than or equal to the cutoff value are classified as true, while
// those less than the cutoff value are classified as false. These
// assigned class labels are compared with the true values in the classes
// slice and used to calculate the FPR and TPR.
//
// y and classes must have the same length, otherwise ROC will panic.
// If weights is nil, all weights are treated as 1. If weights is not nil
// it must have the same length as y and classes, otherwise ROC will panic.
// If y is empty, ROC returns nil for all three results.
//
// If cutoffs is nil or empty, all possible cutoffs are calculated,
// resulting in fpr and tpr having length one greater than the number of
// unique values in y (the extra cutoff is +Inf, at which every observation
// is classified as false). If cutoffs is empty but has a capacity of at
// least len(y)+1, its backing array is reused to hold the returned thresh.
// Otherwise fpr and tpr will be returned with the same length as cutoffs,
// and the provided cutoffs are not modified. floats.Span can be used to
// generate equally spaced cutoffs.
//
// Rates for cutoffs lying beyond the largest value in y are returned as
// zero. If the total weight of true (respectively false) observations is
// zero, the TPR (FPR) is undefined and affected elements of tpr (fpr) may
// be NaN.
//
// More details about ROC curves are available at
// https://en.wikipedia.org/wiki/Receiver_operating_characteristic
func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr, thresh []float64) {
	if len(y) != len(classes) {
		panic("stat: slice length mismatch")
	}
	if weights != nil && len(y) != len(weights) {
		panic("stat: slice length mismatch")
	}
	if !sort.Float64sAreSorted(y) {
		panic("stat: input must be sorted ascending")
	}
	if !sort.Float64sAreSorted(cutoffs) {
		panic("stat: cutoff values must be sorted ascending")
	}
	if len(y) == 0 {
		return nil, nil, nil
	}
	if len(cutoffs) == 0 {
		if cutoffs == nil || cap(cutoffs) < len(y)+1 {
			cutoffs = make([]float64, len(y)+1)
		} else {
			cutoffs = cutoffs[:len(y)+1]
		}
		// Choose all possible cutoffs for unique values in y, followed
		// by +Inf so the curve includes the all-false classification.
		bin := 0
		cutoffs[bin] = y[0]
		for i, u := range y[1:] {
			if u == y[i] {
				continue
			}
			bin++
			cutoffs[bin] = u
		}
		cutoffs[bin+1] = math.Inf(1)
		cutoffs = cutoffs[:bin+2]
	} else {
		// Don't mutate the provided cutoffs.
		tmp := cutoffs
		cutoffs = make([]float64, len(cutoffs))
		copy(cutoffs, tmp)
	}

	tpr = make([]float64, len(cutoffs))
	fpr = make([]float64, len(cutoffs))
	var bin int
	var nPos, nNeg float64
	for i, u := range classes {
		// Advance the bin to the first cutoff greater than y[i],
		// skipping empty bins and carrying the cumulative counts forward.
		for bin < len(cutoffs)-1 && y[i] >= cutoffs[bin] {
			bin++
			tpr[bin] = tpr[bin-1]
			fpr[bin] = fpr[bin-1]
		}
		posWeight, negWeight := 1.0, 0.0
		if weights != nil {
			posWeight = weights[i]
		}
		if !u {
			posWeight, negWeight = negWeight, posWeight
		}
		nPos += posWeight
		nNeg += negWeight
		// Count false negatives (in tpr) and true negatives (in fpr):
		// observations below the cutoff are classified as false.
		if y[i] < cutoffs[bin] {
			tpr[bin] += posWeight
			fpr[bin] += negWeight
		}
	}

	// Convert cumulative negative counts to TPR and FPR.
	// Bins beyond the maximum value in y are skipped,
	// leaving these fpr and tpr elements as zero.
	//
	// Divide directly rather than multiplying by a precomputed
	// reciprocal: x*(1/x) is not always exactly 1 in floating point
	// (e.g. x = 49), which would leave a spurious non-zero rate at the
	// cutoff where every observation is classified as false.
	for i := range tpr[:bin+1] {
		// Prevent fused float operations by
		// making explicit float64 conversions.
		tpr[i] = 1 - float64(tpr[i]/nPos)
		fpr[i] = 1 - float64(fpr[i]/nNeg)
	}
	// Return the rates and thresholds in descending threshold order.
	for i, j := 0, len(tpr)-1; i < j; i, j = i+1, j-1 {
		tpr[i], tpr[j] = tpr[j], tpr[i]
		fpr[i], fpr[j] = fpr[j], fpr[i]
	}
	for i, j := 0, len(cutoffs)-1; i < j; i, j = i+1, j-1 {
		cutoffs[i], cutoffs[j] = cutoffs[j], cutoffs[i]
	}

	return tpr, fpr, cutoffs
}
|
|
|
|
// TOC returns the Total Operating Characteristic for the classes provided
// and the minimum and maximum bounds for the TOC.
//
// The input y values that correspond to classes and weights must be sorted
// in ascending order. classes[i] is the class of value y[i] and weights[i]
// is the weight of y[i]. SortWeightedLabeled can be used to sort classes
// together with weights by the rank variable, i+1.
//
// The returned ntp values can be interpreted as the number of true positives
// where values above the given rank are assigned class true for each given
// rank from 1 to len(classes).
//
//	ntp_i = sum_{j ≥ len(ntp)-1 - i} [ classes_j ] * weights_j, where [x] = 1 if x else 0.
//
// The values of min and max provide the minimum and maximum possible number
// of true positives for the set of classes: min[i] and max[i] bound ntp[i]
// when the i highest-ranked data points are assigned class true. The first
// elements of ntp, min and max are always zero, as this corresponds to
// assigning all data class false, and the last elements are always the
// weighted count of true classes, as this corresponds to assigning every
// data class true. For len(classes) != 0, the lengths of min, ntp and max
// are len(classes)+1. If classes is empty, TOC returns nil for all three.
//
// If weights is nil, all weights are treated as 1. When weights are not nil,
// the calculation of min and max allows for partial assignment of single data
// points. If weights is not nil it must have the same length as classes,
// otherwise TOC will panic.
//
// More details about TOC curves are available at
// https://en.wikipedia.org/wiki/Total_operating_characteristic
func TOC(classes []bool, weights []float64) (min, ntp, max []float64) {
	if weights != nil && len(classes) != len(weights) {
		panic("stat: slice length mismatch")
	}
	if len(classes) == 0 {
		return nil, nil, nil
	}

	n := len(classes)
	ntp = make([]float64, n+1)
	min = make([]float64, n+1)
	max = make([]float64, n+1)

	// Walk from the highest rank down, accumulating the weight assigned
	// class true (ntp) and the total weight seen so far (cumw). A nil
	// weights slice is treated as unit weights; integer-valued sums are
	// exact in float64, so this matches a dedicated unweighted count.
	cumw := max // Reuse max for cumulative weight. Update its elements last.
	for i := 0; i < n; i++ {
		j := n - 1 - i
		w := 1.0
		if weights != nil {
			w = weights[j]
		}
		cumw[i+1] = cumw[i] + w
		ntp[i+1] = ntp[i]
		if classes[j] {
			ntp[i+1] += w
		}
	}

	totw := cumw[n]
	totalPositive := ntp[n]
	for i := range ntp {
		// cumw[i] aliases max[i]; both reads happen before max[i] is written.
		min[i] = math.Max(0, totalPositive-(totw-cumw[i]))
		max[i] = math.Min(totalPositive, cumw[i])
	}
	return min, ntp, max
}
|