1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-05-01 22:18:10 +08:00
golearn/knn.go

60 lines
1.3 KiB
Go
Raw Normal View History

2013-12-26 13:05:16 +00:00
package main
import (
mat "github.com/skelterjohn/go.matrix"
2013-12-27 00:59:06 +00:00
rand "math/rand"
2013-12-26 13:05:16 +00:00
"fmt"
)
type KNNClassifier struct {
Data mat.DenseMatrix
Name string
2013-12-26 19:52:36 +00:00
Labels []string
2013-12-26 13:05:16 +00:00
}
2013-12-27 00:59:06 +00:00
// RandomArray allocates a slice of length n and fills every element with a
// value from rand.Float64, so each entry lies in the half-open interval
// [0.0, 1.0).
func RandomArray(n int) []float64 {
	out := make([]float64, n)
	for i := range out {
		out[i] = rand.Float64()
	}
	return out
}
//Mints a new classifier
2013-12-26 19:52:36 +00:00
func (KNN *KNNClassifier) New(name string, labels []string, numbers []float64, x int, y int){
KNN.Data = *mat.MakeDenseMatrix(numbers, x, y)
KNN.Name = name
KNN.Labels = labels
2013-12-26 13:05:16 +00:00
}
2013-12-27 00:59:06 +00:00
//Computes a variety of distance metrics between two vectors
func (KNN *KNNClassifier) ComputeDistance(vector *mat.DenseMatrix) *mat.DenseMatrix {
2013-12-26 19:52:36 +00:00
//Add switches for different distance metrics
2013-12-27 00:59:06 +00:00
result, err := KNN.Data.TimesDense(vector)
if err != nil {
fmt.Println(err)
}
fmt.Println(result)
return result
2013-12-26 19:52:36 +00:00
}
2013-12-27 00:59:06 +00:00
//Returns a classification based on a vector input
func (KNN *KNNClassifier) Predict(vector mat.DenseMatrix) *mat.DenseMatrix {
return KNN.ComputeDistance(&vector)
2013-12-26 19:52:36 +00:00
}
2013-12-27 00:59:06 +00:00
//Returns a label, given an index
2013-12-26 19:52:36 +00:00
func (KNN *KNNClassifier) GetLabel(index int) string {
return KNN.Labels[index]
}
2013-12-26 13:05:16 +00:00
func main(){
2013-12-27 00:59:06 +00:00
for {
values := RandomArray(4)
knn := KNNClassifier{}
knn.New("Testing", []string{"this sucks", "hiya"}, values,2,2)
knn.Predict(knn.Data)
}
2013-12-26 13:05:16 +00:00
}