1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-28 13:48:56 +08:00

add weighted knn

This commit is contained in:
FrozenKP 2017-04-17 20:37:28 +08:00
parent 3a2782ffec
commit b33ef1f117

View File

@ -20,6 +20,7 @@ import (
// The accepted searching algorithms here are 'linear' and 'kdtree'.
// Optimisations only occur when things are identically grouped into identical
// AttributeGroups, which don't include the class variable, in the same order.
// Weighted KNN is used when Weighted is set to true (default: false).
type KNNClassifier struct {
base.BaseEstimator
TrainingData base.FixedDataGrid
@ -27,6 +28,7 @@ type KNNClassifier struct {
Algorithm string
NearestNeighbours int
AllowOptimisations bool
Weighted bool
}
// NewKnnClassifier returns a new classifier
@ -35,6 +37,7 @@ func NewKnnClassifier(distfunc, algorithm string, neighbours int) *KNNClassifier
KNN.DistanceFunc = distfunc
KNN.Algorithm = algorithm
KNN.NearestNeighbours = neighbours
KNN.Weighted = false
KNN.AllowOptimisations = true
return &KNN
}
@ -157,7 +160,8 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
distances := make(map[int]float64)
// Reserve storage for voting map
maxmap := make(map[string]int)
maxmapInt := make(map[string]int)
maxmapFloat := make(map[string]float64)
// Reserve storage for row computations
trainRowBuf := make([]float64, len(allNumericAttrs))
@ -185,6 +189,7 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
return nil, err
}
}
// Iterate over all outer rows
what.MapOverRows(whatAttrSpecs, func(predRow [][]byte, predRowNo int) (bool, error) {
@ -217,15 +222,32 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
sorted := utilities.SortIntMap(distances)
values := sorted[:KNN.NearestNeighbours]
maxClass := KNN.vote(maxmap, values)
length := make([]float64, KNN.NearestNeighbours)
for k, v := range values {
length[k] = distances[v]
}
var maxClass string
if KNN.Weighted {
maxClass = KNN.weightedVote(maxmapFloat, values, length)
} else {
maxClass = KNN.vote(maxmapInt, values)
}
base.SetClass(ret, predRowNo, maxClass)
case "kdtree":
values, err := kd.Search(KNN.NearestNeighbours, distanceFunc, predRowBuf)
values, length, err := kd.Search(KNN.NearestNeighbours, distanceFunc, predRowBuf)
if err != nil {
return false, err
}
maxClass := KNN.vote(maxmap, values)
var maxClass string
if KNN.Weighted {
maxClass = KNN.weightedVote(maxmapFloat, values, length)
} else {
maxClass = KNN.vote(maxmapInt, values)
}
base.SetClass(ret, predRowNo, maxClass)
}
@ -268,6 +290,34 @@ func (KNN *KNNClassifier) vote(maxmap map[string]int, values []int) string {
return maxClass
}
// weightedVote returns the class label with the largest inverse-distance
// weighted tally among the given neighbours.
//
// maxmap is a caller-owned scratch map that is reset on every call, so it can
// be reused between prediction rows. values holds row indices into
// KNN.TrainingData, and length[k] is the distance associated with values[k];
// length must therefore have at least len(values) elements.
//
// Each neighbour normally contributes 1/distance to its class. A distance of
// zero would contribute +Inf, which makes the comparison between several
// exactly matching classes meaningless, so when at least one neighbour is at
// distance zero only those exact matches vote, each with weight 1. Ties
// between classes are broken in favour of the lexicographically smaller label
// so the result does not depend on Go's randomized map iteration order.
func (KNN *KNNClassifier) weightedVote(maxmap map[string]float64, values []int, length []float64) string {
	// Reset the tallies left over from the previous call.
	for a := range maxmap {
		maxmap[a] = 0
	}

	// Detect exact matches up front: they change how every neighbour is weighted.
	exactMatch := false
	for k := range values {
		if length[k] == 0 {
			exactMatch = true
			break
		}
	}

	// Accumulate the per-class weights.
	for k, elem := range values {
		weight := 1.0
		if exactMatch {
			// Only zero-distance neighbours are allowed to vote.
			if length[k] != 0 {
				continue
			}
		} else {
			weight = 1 / length[k]
		}
		// A missing key reads as 0, so no existence check is needed.
		maxmap[base.GetClass(KNN.TrainingData, elem)] += weight
	}

	// Pick the heaviest class, breaking ties deterministically by label.
	var maxClass string
	maxVal := -1.0
	for label, total := range maxmap {
		if total > maxVal || (total == maxVal && label < maxClass) {
			maxVal = total
			maxClass = label
		}
	}
	return maxClass
}
// A KNNRegressor consists of a data matrix, associated result variables in the same order as the matrix, and a name.
type KNNRegressor struct {
base.BaseEstimator