mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
commit
e6c28efe2d
60
knn/knn.go
60
knn/knn.go
@ -44,27 +44,21 @@ func (KNN *KNNClassifier) PredictOne(vector []float64) string {
|
||||
convertedVector := util.FloatsToMatrix(vector)
|
||||
|
||||
// Check what distance function we are using
|
||||
var distanceFunc pairwiseMetrics.PairwiseDistanceFunc
|
||||
switch KNN.DistanceFunc {
|
||||
case "euclidean":
|
||||
{
|
||||
euclidean := pairwiseMetrics.NewEuclidean()
|
||||
for i := 0; i < rows; i++ {
|
||||
row := KNN.TrainingData.GetRowVectorWithoutClass(i)
|
||||
rowMat := util.FloatsToMatrix(row)
|
||||
distance := euclidean.Distance(rowMat, convertedVector)
|
||||
rownumbers[i] = distance
|
||||
}
|
||||
}
|
||||
distanceFunc = pairwiseMetrics.NewEuclidean()
|
||||
case "manhattan":
|
||||
{
|
||||
manhattan := pairwiseMetrics.NewEuclidean()
|
||||
for i := 0; i < rows; i++ {
|
||||
row := KNN.TrainingData.GetRowVectorWithoutClass(i)
|
||||
rowMat := util.FloatsToMatrix(row)
|
||||
distance := manhattan.Distance(rowMat, convertedVector)
|
||||
rownumbers[i] = distance
|
||||
}
|
||||
}
|
||||
distanceFunc = pairwiseMetrics.NewManhattan()
|
||||
default:
|
||||
panic("unsupported distance function")
|
||||
}
|
||||
|
||||
for i := 0; i < rows; i++ {
|
||||
row := KNN.TrainingData.GetRowVectorWithoutClass(i)
|
||||
rowMat := util.FloatsToMatrix(row)
|
||||
distance := distanceFunc.Distance(rowMat, convertedVector)
|
||||
rownumbers[i] = distance
|
||||
}
|
||||
|
||||
sorted := util.SortIntMap(rownumbers)
|
||||
@ -125,27 +119,21 @@ func (KNN *KNNRegressor) Predict(vector *mat64.Dense, K int) float64 {
|
||||
labels := make([]float64, 0)
|
||||
|
||||
// Check what distance function we are using
|
||||
var distanceFunc pairwiseMetrics.PairwiseDistanceFunc
|
||||
switch KNN.DistanceFunc {
|
||||
case "euclidean":
|
||||
{
|
||||
euclidean := pairwiseMetrics.NewEuclidean()
|
||||
for i := 0; i < rows; i++ {
|
||||
row := KNN.Data.RowView(i)
|
||||
rowMat := util.FloatsToMatrix(row)
|
||||
distance := euclidean.Distance(rowMat, vector)
|
||||
rownumbers[i] = distance
|
||||
}
|
||||
}
|
||||
distanceFunc = pairwiseMetrics.NewEuclidean()
|
||||
case "manhattan":
|
||||
{
|
||||
manhattan := pairwiseMetrics.NewEuclidean()
|
||||
for i := 0; i < rows; i++ {
|
||||
row := KNN.Data.RowView(i)
|
||||
rowMat := util.FloatsToMatrix(row)
|
||||
distance := manhattan.Distance(rowMat, vector)
|
||||
rownumbers[i] = distance
|
||||
}
|
||||
}
|
||||
distanceFunc = pairwiseMetrics.NewManhattan()
|
||||
default:
|
||||
panic("unsupported distance function")
|
||||
}
|
||||
|
||||
for i := 0; i < rows; i++ {
|
||||
row := KNN.Data.RowView(i)
|
||||
rowMat := util.FloatsToMatrix(row)
|
||||
distance := distanceFunc.Distance(rowMat, vector)
|
||||
rownumbers[i] = distance
|
||||
}
|
||||
|
||||
sorted := util.SortIntMap(rownumbers)
|
||||
|
@ -1,2 +1,10 @@
|
||||
// Package pairwise implements utilities to evaluate pairwise distances or inner product (via kernel).
|
||||
package pairwise
|
||||
|
||||
import (
|
||||
"github.com/gonum/matrix/mat64"
|
||||
)
|
||||
|
||||
type PairwiseDistanceFunc interface {
|
||||
Distance(vectorX *mat64.Dense, vectorY *mat64.Dense) float64
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user