1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-28 13:48:56 +08:00
golearn/knn/knn.go

94 lines
2.2 KiB
Go
Raw Normal View History

package main
2013-12-28 18:41:13 +00:00
import (
mat "github.com/skelterjohn/go.matrix"
"math"
"fmt"
util "golearn/utilities"
base "golearn/base"
2013-12-28 23:48:12 +00:00
// "errors""
2013-12-28 18:41:13 +00:00
)
2013-12-28 23:48:12 +00:00
//A KNN Classifier. Consists of a data matrix, associated labels in the same order as the matrix, and a name.
2013-12-28 18:41:13 +00:00
type KNNClassifier struct {
base.BaseClassifier
2013-12-28 18:41:13 +00:00
}
2013-12-28 23:48:12 +00:00
//Mints a new classifier.
2013-12-28 18:41:13 +00:00
func (KNN *KNNClassifier) New(name string, labels []string, numbers []float64, x int, y int) {
2014-01-04 11:12:06 +00:00
//Write in some error handling here
2013-12-28 18:41:13 +00:00
// if x != len(KNN.Labels) {
// return errors.New("KNN: There must be a label for each row")
// }
KNN.Data = *mat.MakeDenseMatrix(numbers, x, y)
KNN.Name = name
KNN.Labels = labels
}
2013-12-28 23:48:12 +00:00
//Computes the Euclidean distance between two vectors.
2013-12-28 18:41:13 +00:00
func (KNN *KNNClassifier) ComputeDistance(vector *mat.DenseMatrix, testrow *mat.DenseMatrix) float64 {
var sum float64
difference, err := testrow.MinusDense(vector)
flat := difference.Array()
if err != nil {
fmt.Println(err)
}
for _, i := range flat {
squared := math.Pow(i, 2)
sum += squared
}
eucdistance := math.Sqrt(sum)
return eucdistance
}
2013-12-28 23:48:12 +00:00
//Returns a classification for the vector, based on a vector input, using the KNN algorithm.
func (KNN *KNNClassifier) Predict(vector *mat.DenseMatrix, K int) (string, []int) {
2013-12-28 18:41:13 +00:00
rows := KNN.Data.Rows()
rownumbers := make(map[int]float64)
2013-12-28 23:48:12 +00:00
labels := make([]string, 1)
maxmap := make(map[string]int)
2013-12-28 18:41:13 +00:00
for i := 0; i < rows; i++{
row := KNN.Data.GetRowVector(i)
eucdistance := KNN.ComputeDistance(row, vector)
rownumbers[i] = eucdistance
}
sorted := util.SortIntMap(rownumbers)
2013-12-28 18:41:13 +00:00
values := sorted[:K]
for _, elem := range values {
labels = append(labels, KNN.Labels[elem])
if _, ok := maxmap[KNN.Labels[elem]]; ok {
maxmap[KNN.Labels[elem]] += 1
} else {
maxmap[KNN.Labels[elem]] = 1
}
2013-12-28 18:41:13 +00:00
}
sortedlabels := util.SortStringMap(maxmap)
label := sortedlabels[0]
return label, values
2013-12-28 18:41:13 +00:00
}
func main(){
cols, rows, _, labels, data := base.ParseCsv("../datasets/iris.csv", 4, []int{0,1,2})
2013-12-28 23:48:12 +00:00
knn := KNNClassifier{}
knn.New("Testing", labels, data, rows, cols)
2013-12-28 18:41:13 +00:00
for {
randArray := util.RandomArray(3)
random := mat.MakeDenseMatrix(randArray,1,3)
labels, _ := knn.Predict(random, 3)
2013-12-28 23:48:12 +00:00
fmt.Println(labels)
2013-12-28 18:41:13 +00:00
}
}