mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-28 13:48:56 +08:00
add weighted knn
This commit is contained in:
parent
3a2782ffec
commit
b33ef1f117
58
knn/knn.go
58
knn/knn.go
@ -20,6 +20,7 @@ import (
|
||||
// The accepted searching algorithm here are 'linear', and 'kdtree'.
|
||||
// Optimisations only occur when things are identically group into identical
|
||||
// AttributeGroups, which don't include the class variable, in the same order.
|
||||
// Using weighted KNN when Weighted set to be true (default: false).
|
||||
type KNNClassifier struct {
|
||||
base.BaseEstimator
|
||||
TrainingData base.FixedDataGrid
|
||||
@ -27,6 +28,7 @@ type KNNClassifier struct {
|
||||
Algorithm string
|
||||
NearestNeighbours int
|
||||
AllowOptimisations bool
|
||||
Weighted bool
|
||||
}
|
||||
|
||||
// NewKnnClassifier returns a new classifier
|
||||
@ -35,6 +37,7 @@ func NewKnnClassifier(distfunc, algorithm string, neighbours int) *KNNClassifier
|
||||
KNN.DistanceFunc = distfunc
|
||||
KNN.Algorithm = algorithm
|
||||
KNN.NearestNeighbours = neighbours
|
||||
KNN.Weighted = false
|
||||
KNN.AllowOptimisations = true
|
||||
return &KNN
|
||||
}
|
||||
@ -157,7 +160,8 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
|
||||
distances := make(map[int]float64)
|
||||
|
||||
// Reserve storage for voting map
|
||||
maxmap := make(map[string]int)
|
||||
maxmapInt := make(map[string]int)
|
||||
maxmapFloat := make(map[string]float64)
|
||||
|
||||
// Reserve storage for row computations
|
||||
trainRowBuf := make([]float64, len(allNumericAttrs))
|
||||
@ -185,6 +189,7 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Iterate over all outer rows
|
||||
what.MapOverRows(whatAttrSpecs, func(predRow [][]byte, predRowNo int) (bool, error) {
|
||||
|
||||
@ -217,15 +222,32 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
|
||||
|
||||
sorted := utilities.SortIntMap(distances)
|
||||
values := sorted[:KNN.NearestNeighbours]
|
||||
maxClass := KNN.vote(maxmap, values)
|
||||
|
||||
length := make([]float64, KNN.NearestNeighbours)
|
||||
for k, v := range values {
|
||||
length[k] = distances[v]
|
||||
}
|
||||
|
||||
var maxClass string
|
||||
if KNN.Weighted {
|
||||
maxClass = KNN.weightedVote(maxmapFloat, values, length)
|
||||
} else {
|
||||
maxClass = KNN.vote(maxmapInt, values)
|
||||
}
|
||||
base.SetClass(ret, predRowNo, maxClass)
|
||||
|
||||
case "kdtree":
|
||||
values, err := kd.Search(KNN.NearestNeighbours, distanceFunc, predRowBuf)
|
||||
values, length, err := kd.Search(KNN.NearestNeighbours, distanceFunc, predRowBuf)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
maxClass := KNN.vote(maxmap, values)
|
||||
|
||||
var maxClass string
|
||||
if KNN.Weighted {
|
||||
maxClass = KNN.weightedVote(maxmapFloat, values, length)
|
||||
} else {
|
||||
maxClass = KNN.vote(maxmapInt, values)
|
||||
}
|
||||
base.SetClass(ret, predRowNo, maxClass)
|
||||
}
|
||||
|
||||
@ -268,6 +290,34 @@ func (KNN *KNNClassifier) vote(maxmap map[string]int, values []int) string {
|
||||
return maxClass
|
||||
}
|
||||
|
||||
func (KNN *KNNClassifier) weightedVote(maxmap map[string]float64, values []int, length []float64) string {
|
||||
// Reset maxMap
|
||||
for a := range maxmap {
|
||||
maxmap[a] = 0
|
||||
}
|
||||
|
||||
// Refresh maxMap
|
||||
for k, elem := range values {
|
||||
label := base.GetClass(KNN.TrainingData, elem)
|
||||
if _, ok := maxmap[label]; ok {
|
||||
maxmap[label] += (1 / length[k])
|
||||
} else {
|
||||
maxmap[label] = (1 / length[k])
|
||||
}
|
||||
}
|
||||
|
||||
// Sort the maxMap
|
||||
var maxClass string
|
||||
maxVal := -1.0
|
||||
for a := range maxmap {
|
||||
if maxmap[a] > maxVal {
|
||||
maxVal = maxmap[a]
|
||||
maxClass = a
|
||||
}
|
||||
}
|
||||
return maxClass
|
||||
}
|
||||
|
||||
// A KNNRegressor consists of a data matrix, associated result variables in the same order as the matrix, and a name.
|
||||
type KNNRegressor struct {
|
||||
base.BaseEstimator
|
||||
|
Loading…
x
Reference in New Issue
Block a user