1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00

Merge pull request #54 from njern/knn_refactor

Knn refactor
This commit is contained in:
Stephen Whitworth 2014-07-26 20:18:28 -04:00
commit e6c28efe2d
2 changed files with 32 additions and 36 deletions

View File

@ -44,27 +44,21 @@ func (KNN *KNNClassifier) PredictOne(vector []float64) string {
convertedVector := util.FloatsToMatrix(vector)
// Check what distance function we are using
var distanceFunc pairwiseMetrics.PairwiseDistanceFunc
switch KNN.DistanceFunc {
case "euclidean":
{
euclidean := pairwiseMetrics.NewEuclidean()
for i := 0; i < rows; i++ {
row := KNN.TrainingData.GetRowVectorWithoutClass(i)
rowMat := util.FloatsToMatrix(row)
distance := euclidean.Distance(rowMat, convertedVector)
rownumbers[i] = distance
}
}
distanceFunc = pairwiseMetrics.NewEuclidean()
case "manhattan":
{
manhattan := pairwiseMetrics.NewEuclidean()
for i := 0; i < rows; i++ {
row := KNN.TrainingData.GetRowVectorWithoutClass(i)
rowMat := util.FloatsToMatrix(row)
distance := manhattan.Distance(rowMat, convertedVector)
rownumbers[i] = distance
}
}
distanceFunc = pairwiseMetrics.NewManhattan()
default:
panic("unsupported distance function")
}
for i := 0; i < rows; i++ {
row := KNN.TrainingData.GetRowVectorWithoutClass(i)
rowMat := util.FloatsToMatrix(row)
distance := distanceFunc.Distance(rowMat, convertedVector)
rownumbers[i] = distance
}
sorted := util.SortIntMap(rownumbers)
@ -125,27 +119,21 @@ func (KNN *KNNRegressor) Predict(vector *mat64.Dense, K int) float64 {
labels := make([]float64, 0)
// Check what distance function we are using
var distanceFunc pairwiseMetrics.PairwiseDistanceFunc
switch KNN.DistanceFunc {
case "euclidean":
{
euclidean := pairwiseMetrics.NewEuclidean()
for i := 0; i < rows; i++ {
row := KNN.Data.RowView(i)
rowMat := util.FloatsToMatrix(row)
distance := euclidean.Distance(rowMat, vector)
rownumbers[i] = distance
}
}
distanceFunc = pairwiseMetrics.NewEuclidean()
case "manhattan":
{
manhattan := pairwiseMetrics.NewEuclidean()
for i := 0; i < rows; i++ {
row := KNN.Data.RowView(i)
rowMat := util.FloatsToMatrix(row)
distance := manhattan.Distance(rowMat, vector)
rownumbers[i] = distance
}
}
distanceFunc = pairwiseMetrics.NewManhattan()
default:
panic("unsupported distance function")
}
for i := 0; i < rows; i++ {
row := KNN.Data.RowView(i)
rowMat := util.FloatsToMatrix(row)
distance := distanceFunc.Distance(rowMat, vector)
rownumbers[i] = distance
}
sorted := util.SortIntMap(rownumbers)

View File

@ -1,2 +1,10 @@
// Package pairwise implements utilities to evaluate pairwise distances or inner product (via kernel).
package pairwise
import (
"github.com/gonum/matrix/mat64"
)
type PairwiseDistanceFunc interface {
Distance(vectorX *mat64.Dense, vectorY *mat64.Dense) float64
}