2014-05-01 21:20:44 +01:00
|
|
|
// TODO: There is a lot of code duplication here.
|
|
|
|
|
2014-01-05 00:23:31 +00:00
|
|
|
package knn
|
|
|
|
|
|
|
|
import (
|
2014-05-03 23:08:43 +01:00
|
|
|
"github.com/gonum/matrix/mat64"
|
|
|
|
"github.com/sjwhitworth/golearn/base"
|
|
|
|
pairwiseMetrics "github.com/sjwhitworth/golearn/metrics/pairwise"
|
2014-04-30 22:13:07 +08:00
|
|
|
util "github.com/sjwhitworth/golearn/utilities"
|
2014-04-30 08:57:13 +01:00
|
|
|
)
|
2014-01-05 00:23:31 +00:00
|
|
|
|
|
|
|
//A KNN Regressor. Consists of a data matrix, associated result variables in the same order as the matrix, and a name.
|
|
|
|
type KNNRegressor struct {
|
2014-05-03 23:08:43 +01:00
|
|
|
base.BaseEstimator
|
|
|
|
Values []float64
|
2014-05-01 21:20:44 +01:00
|
|
|
DistanceFunc string
|
2014-01-05 00:23:31 +00:00
|
|
|
}
|
|
|
|
|
2014-05-03 23:08:43 +01:00
|
|
|
// Mints a new classifier.
|
|
|
|
func NewKnnRegressor(values []float64, numbers []float64, x int, y int, distfunc string) *KNNRegressor {
|
2014-05-01 21:20:44 +01:00
|
|
|
KNN := KNNRegressor{}
|
2014-05-03 23:08:43 +01:00
|
|
|
KNN.Data = mat64.NewDense(x, y, numbers)
|
|
|
|
KNN.Values = values
|
|
|
|
KNN.DistanceFunc = distfunc
|
2014-05-01 21:20:44 +01:00
|
|
|
return &KNN
|
2014-01-05 00:23:31 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
//Returns an average of the K nearest labels/variables, based on a vector input.
|
2014-05-03 23:08:43 +01:00
|
|
|
func (KNN *KNNRegressor) Predict(vector *mat64.Dense, K int) float64 {
|
2014-01-05 00:23:31 +00:00
|
|
|
|
2014-05-03 23:08:43 +01:00
|
|
|
// Get the number of rows
|
|
|
|
rows, _ := KNN.Data.Dims()
|
2014-01-05 00:23:31 +00:00
|
|
|
rownumbers := make(map[int]float64)
|
2014-05-03 23:08:43 +01:00
|
|
|
labels := make([]float64, 0)
|
2014-01-05 00:23:31 +00:00
|
|
|
|
2014-05-03 23:08:43 +01:00
|
|
|
// Check what distance function we are using
|
|
|
|
switch KNN.DistanceFunc {
|
|
|
|
case "euclidean":
|
|
|
|
{
|
|
|
|
euclidean := pairwiseMetrics.NewEuclidean()
|
|
|
|
for i := 0; i < rows; i++ {
|
|
|
|
row := KNN.Data.RowView(i)
|
|
|
|
rowMat := util.FloatsToMatrix(row)
|
|
|
|
distance := euclidean.Distance(rowMat, vector)
|
|
|
|
rownumbers[i] = distance
|
|
|
|
}
|
|
|
|
}
|
|
|
|
case "manhattan":
|
|
|
|
{
|
|
|
|
manhattan := pairwiseMetrics.NewEuclidean()
|
|
|
|
for i := 0; i < rows; i++ {
|
|
|
|
row := KNN.Data.RowView(i)
|
|
|
|
rowMat := util.FloatsToMatrix(row)
|
|
|
|
distance := manhattan.Distance(rowMat, vector)
|
|
|
|
rownumbers[i] = distance
|
|
|
|
}
|
|
|
|
}
|
2014-01-05 00:23:31 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
sorted := util.SortIntMap(rownumbers)
|
|
|
|
values := sorted[:K]
|
|
|
|
|
2014-05-03 23:08:43 +01:00
|
|
|
var sum float64
|
2014-01-05 00:23:31 +00:00
|
|
|
for _, elem := range values {
|
2014-05-03 23:08:43 +01:00
|
|
|
value := KNN.Values[elem]
|
2014-01-05 00:23:31 +00:00
|
|
|
labels = append(labels, value)
|
|
|
|
sum += value
|
|
|
|
}
|
|
|
|
|
|
|
|
average := sum / float64(K)
|
2014-05-03 23:08:43 +01:00
|
|
|
return average
|
2014-04-30 08:57:13 +01:00
|
|
|
}
|