1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-28 13:48:56 +08:00
golearn/knn/knn.go

164 lines
4.3 KiB
Go
Raw Normal View History

// Package knn implements a K Nearest Neighbors object, capable of both classification
// and regression. The classifier is trained on base.Instances; the regressor accepts its
// training data as a flat slice of float64s, which is reshaped into a rows-by-cols matrix.
2014-01-04 19:31:33 +00:00
package knn
2013-12-28 18:41:13 +00:00
import (
2014-05-03 23:08:43 +01:00
"github.com/gonum/matrix/mat64"
2014-04-30 22:13:07 +08:00
base "github.com/sjwhitworth/golearn/base"
pairwiseMetrics "github.com/sjwhitworth/golearn/metrics/pairwise"
2014-04-30 22:13:07 +08:00
util "github.com/sjwhitworth/golearn/utilities"
2014-04-30 08:57:13 +01:00
)
2013-12-28 18:41:13 +00:00
// A KNNClassifier consists of a data matrix, associated labels in the same order as the matrix, and a distance function.
2014-05-03 23:08:43 +01:00
// The accepted distance functions at this time are 'euclidean' and 'manhattan'.
2013-12-28 18:41:13 +00:00
type KNNClassifier struct {
2014-05-01 19:56:30 +01:00
base.BaseEstimator
TrainingData *base.Instances
DistanceFunc string
NearestNeighbours int
2013-12-28 18:41:13 +00:00
}
// NewKnnClassifier builds a KNNClassifier that compares rows with the metric
// named by distfunc and lets the `neighbours` closest training rows vote on
// each prediction. Fit must be called before predicting.
func NewKnnClassifier(distfunc string, neighbours int) *KNNClassifier {
	return &KNNClassifier{
		DistanceFunc:      distfunc,
		NearestNeighbours: neighbours,
	}
}
// Fit stores the training data for later
func (KNN *KNNClassifier) Fit(trainingData *base.Instances) {
KNN.TrainingData = trainingData
2013-12-28 18:41:13 +00:00
}
// PredictOne returns a classification for the vector, based on a vector input, using the KNN algorithm.
2014-05-03 23:08:43 +01:00
// See http://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
func (KNN *KNNClassifier) PredictOne(vector []float64) string {
2013-12-28 18:41:13 +00:00
rows := KNN.TrainingData.Rows
2013-12-28 18:41:13 +00:00
rownumbers := make(map[int]float64)
2014-01-05 00:23:31 +00:00
labels := make([]string, 0)
maxmap := make(map[string]int)
2013-12-28 18:41:13 +00:00
convertedVector := util.FloatsToMatrix(vector)
2014-05-03 23:08:43 +01:00
// Check what distance function we are using
switch KNN.DistanceFunc {
case "euclidean":
{
euclidean := pairwiseMetrics.NewEuclidean()
for i := 0; i < rows; i++ {
row := KNN.TrainingData.GetRowVectorWithoutClass(i)
2014-05-03 23:08:43 +01:00
rowMat := util.FloatsToMatrix(row)
distance := euclidean.Distance(rowMat, convertedVector)
rownumbers[i] = distance
}
}
case "manhattan":
{
manhattan := pairwiseMetrics.NewEuclidean()
for i := 0; i < rows; i++ {
row := KNN.TrainingData.GetRowVectorWithoutClass(i)
2014-05-03 23:08:43 +01:00
rowMat := util.FloatsToMatrix(row)
distance := manhattan.Distance(rowMat, convertedVector)
rownumbers[i] = distance
}
}
2013-12-28 18:41:13 +00:00
}
sorted := util.SortIntMap(rownumbers)
values := sorted[:KNN.NearestNeighbours]
2013-12-28 18:41:13 +00:00
for _, elem := range values {
label := KNN.TrainingData.GetClass(elem)
labels = append(labels, label)
if _, ok := maxmap[label]; ok {
2014-07-18 13:25:18 +03:00
maxmap[label]++
} else {
maxmap[label] = 1
}
2013-12-28 18:41:13 +00:00
}
sortedlabels := util.SortStringMap(maxmap)
label := sortedlabels[0]
2014-05-03 23:08:43 +01:00
return label
2014-04-30 08:57:13 +01:00
}
2014-05-04 09:52:13 +01:00
// Predict classifies every row of what with PredictOne and writes each label
// into column 0 of the vector produced by what.GeneratePredictionVector,
// which is returned.
func (KNN *KNNClassifier) Predict(what *base.Instances) *base.Instances {
	predictions := what.GeneratePredictionVector()
	for row := 0; row < what.Rows; row++ {
		features := what.GetRowVectorWithoutClass(row)
		predictions.SetAttrStr(row, 0, KNN.PredictOne(features))
	}
	return predictions
}
// A KNNRegressor consists of a data matrix, associated result variables in the same order as the matrix, and a name.
2014-05-04 09:52:13 +01:00
type KNNRegressor struct {
base.BaseEstimator
Values []float64
DistanceFunc string
}
// NewKnnRegressor mints a new classifier.
func NewKnnRegressor(distfunc string) *KNNRegressor {
2014-05-04 09:52:13 +01:00
KNN := KNNRegressor{}
KNN.DistanceFunc = distfunc
return &KNN
}
// Fit stores the training set. numbers is the flattened rows×cols feature
// matrix, laid out as mat64.NewDense expects. values holds one target per row,
// so len(values) must equal rows; otherwise Fit panics with mat64.ErrShape.
func (KNN *KNNRegressor) Fit(values []float64, numbers []float64, rows int, cols int) {
	if len(values) != rows {
		panic(mat64.ErrShape)
	}
	// Build the matrix before recording Values so a panic from NewDense
	// leaves the regressor untouched.
	KNN.Data = mat64.NewDense(rows, cols, numbers)
	KNN.Values = values
}
2014-05-04 09:52:13 +01:00
// Predict estimates the target for vector as the mean of Values over the K
// training rows closest to it under DistanceFunc.
//
// vector is compared directly against each training row after that row is
// wrapped with util.FloatsToMatrix, so it should have the same shape as those
// wrapped rows (presumably 1×cols — confirm with callers).
// If K exceeds the number of training rows, every row contributes and the mean
// is taken over the rows actually used. It panics if DistanceFunc is neither
// "euclidean" nor "manhattan", or if no neighbours are available (no training
// rows, or K < 1).
func (KNN *KNNRegressor) Predict(vector *mat64.Dense, K int) float64 {
	// Resolve the metric once, outside the per-row loop.
	var distance func(a, b *mat64.Dense) float64
	switch KNN.DistanceFunc {
	case "euclidean":
		metric := pairwiseMetrics.NewEuclidean()
		distance = func(a, b *mat64.Dense) float64 { return metric.Distance(a, b) }
	case "manhattan":
		metric := pairwiseMetrics.NewManhattan()
		distance = func(a, b *mat64.Dense) float64 { return metric.Distance(a, b) }
	default:
		panic("knn: unsupported distance function " + KNN.DistanceFunc)
	}

	rows, _ := KNN.Data.Dims()
	rownumbers := make(map[int]float64, rows)
	for i := 0; i < rows; i++ {
		row := util.FloatsToMatrix(KNN.Data.RowView(i))
		rownumbers[i] = distance(row, vector)
	}

	// NOTE(review): SortIntMap is relied on to order row indices by ascending
	// distance, so the prefix holds the nearest rows.
	sorted := util.SortIntMap(rownumbers)
	k := K
	if k > len(sorted) {
		k = len(sorted)
	}
	if k < 1 {
		panic("knn: no neighbours available to average (empty training data or K < 1)")
	}

	var sum float64
	for _, idx := range sorted[:k] {
		sum += KNN.Values[idx]
	}
	// Divide by the number of neighbours actually used, not the requested K.
	return sum / float64(k)
}