1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-30 13:48:57 +08:00
golearn/knn/knn.go

59 lines
1.5 KiB
Go
Raw Normal View History

2014-01-04 19:31:33 +00:00
package knn
2013-12-28 18:41:13 +00:00
import (
2014-04-30 22:13:07 +08:00
base "github.com/sjwhitworth/golearn/base"
util "github.com/sjwhitworth/golearn/utilities"
2014-04-30 08:57:13 +01:00
mat "github.com/skelterjohn/go.matrix"
)
2013-12-28 18:41:13 +00:00
2013-12-28 23:48:12 +00:00
//A KNN Classifier. Consists of a data matrix, associated labels in the same order as the matrix, and a name.
2013-12-28 18:41:13 +00:00
type KNNClassifier struct {
2014-05-01 19:56:30 +01:00
base.BaseEstimator
Labels []string
DistanceFunc string
2013-12-28 18:41:13 +00:00
}
2013-12-28 23:48:12 +00:00
//Mints a new classifier.
2014-05-01 19:56:30 +01:00
func (KNN *KNNClassifier) New(labels []string, numbers []float64, x int, y int, distfunc string) {
2013-12-28 18:41:13 +00:00
2014-05-01 19:56:30 +01:00
KNN.Data = mat.MakeDenseMatrix(numbers, x, y)
2013-12-28 18:41:13 +00:00
KNN.Labels = labels
2014-05-01 19:56:30 +01:00
KNN.DistanceFunc = distfunc
2013-12-28 18:41:13 +00:00
}
2014-05-01 19:56:30 +01:00
// Returns a classification for the vector, based on a vector input, using the KNN algorithm.
// @todo: Lots of room to improve this. V messy.
func (KNN *KNNClassifier) Predict(vector *mat.DenseMatrix, K int) (string, []int) {
2013-12-28 18:41:13 +00:00
rows := KNN.Data.Rows()
rownumbers := make(map[int]float64)
2014-01-05 00:23:31 +00:00
labels := make([]string, 0)
maxmap := make(map[string]int)
2013-12-28 18:41:13 +00:00
2014-04-30 08:57:13 +01:00
for i := 0; i < rows; i++ {
2013-12-28 18:41:13 +00:00
row := KNN.Data.GetRowVector(i)
2014-05-01 19:56:30 +01:00
//Will put code in to check errs later
eucdistance, _ := util.ComputeDistance(KNN.DistanceFunc, row, vector)
2013-12-28 18:41:13 +00:00
rownumbers[i] = eucdistance
}
sorted := util.SortIntMap(rownumbers)
2013-12-28 18:41:13 +00:00
values := sorted[:K]
for _, elem := range values {
labels = append(labels, KNN.Labels[elem])
if _, ok := maxmap[KNN.Labels[elem]]; ok {
maxmap[KNN.Labels[elem]] += 1
} else {
maxmap[KNN.Labels[elem]] = 1
}
2013-12-28 18:41:13 +00:00
}
sortedlabels := util.SortStringMap(maxmap)
label := sortedlabels[0]
return label, values
2014-04-30 08:57:13 +01:00
}