1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00
golearn/knn/knnregressor.go

73 lines
1.8 KiB
Go
Raw Normal View History

2014-05-01 21:20:44 +01:00
//@todo: A lot of code duplication here.
2014-01-05 00:23:31 +00:00
package knn
import (
2014-05-03 23:08:43 +01:00
"github.com/gonum/matrix/mat64"
"github.com/sjwhitworth/golearn/base"
pairwiseMetrics "github.com/sjwhitworth/golearn/metrics/pairwise"
2014-04-30 22:13:07 +08:00
util "github.com/sjwhitworth/golearn/utilities"
2014-04-30 08:57:13 +01:00
)
2014-01-05 00:23:31 +00:00
//A KNN Regressor. Consists of a data matrix, associated result variables in the same order as the matrix, and a name.
type KNNRegressor struct {
2014-05-03 23:08:43 +01:00
base.BaseEstimator
Values []float64
2014-05-01 21:20:44 +01:00
DistanceFunc string
2014-01-05 00:23:31 +00:00
}
2014-05-03 23:08:43 +01:00
// Mints a new classifier.
func NewKnnRegressor(values []float64, numbers []float64, x int, y int, distfunc string) *KNNRegressor {
2014-05-01 21:20:44 +01:00
KNN := KNNRegressor{}
2014-05-03 23:08:43 +01:00
KNN.Data = mat64.NewDense(x, y, numbers)
KNN.Values = values
KNN.DistanceFunc = distfunc
2014-05-01 21:20:44 +01:00
return &KNN
2014-01-05 00:23:31 +00:00
}
//Returns an average of the K nearest labels/variables, based on a vector input.
2014-05-03 23:08:43 +01:00
func (KNN *KNNRegressor) Predict(vector *mat64.Dense, K int) float64 {
2014-01-05 00:23:31 +00:00
2014-05-03 23:08:43 +01:00
// Get the number of rows
rows, _ := KNN.Data.Dims()
2014-01-05 00:23:31 +00:00
rownumbers := make(map[int]float64)
2014-05-03 23:08:43 +01:00
labels := make([]float64, 0)
2014-01-05 00:23:31 +00:00
2014-05-03 23:08:43 +01:00
// Check what distance function we are using
switch KNN.DistanceFunc {
case "euclidean":
{
euclidean := pairwiseMetrics.NewEuclidean()
for i := 0; i < rows; i++ {
row := KNN.Data.RowView(i)
rowMat := util.FloatsToMatrix(row)
distance := euclidean.Distance(rowMat, vector)
rownumbers[i] = distance
}
}
case "manhattan":
{
manhattan := pairwiseMetrics.NewEuclidean()
for i := 0; i < rows; i++ {
row := KNN.Data.RowView(i)
rowMat := util.FloatsToMatrix(row)
distance := manhattan.Distance(rowMat, vector)
rownumbers[i] = distance
}
}
2014-01-05 00:23:31 +00:00
}
sorted := util.SortIntMap(rownumbers)
values := sorted[:K]
2014-05-03 23:08:43 +01:00
var sum float64
2014-01-05 00:23:31 +00:00
for _, elem := range values {
2014-05-03 23:08:43 +01:00
value := KNN.Values[elem]
2014-01-05 00:23:31 +00:00
labels = append(labels, value)
sum += value
}
average := sum / float64(K)
2014-05-03 23:08:43 +01:00
return average
2014-04-30 08:57:13 +01:00
}