1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-30 13:48:57 +08:00
golearn/knn.go

44 lines
1.1 KiB
Go
Raw Normal View History

2013-12-26 13:05:16 +00:00
package main
import (
mat "github.com/skelterjohn/go.matrix"
"fmt"
)
type KNNClassifier struct {
Data mat.DenseMatrix
Name string
2013-12-26 19:52:36 +00:00
Labels []string
2013-12-26 13:05:16 +00:00
}
2013-12-26 19:52:36 +00:00
//Initialises a new classifier
func (KNN *KNNClassifier) New(name string, labels []string, numbers []float64, x int, y int){
KNN.Data = *mat.MakeDenseMatrix(numbers, x, y)
KNN.Name = name
KNN.Labels = labels
2013-12-26 13:05:16 +00:00
}
2013-12-26 19:52:36 +00:00
// ComputeDistance returns the product KNN.Data.TimesDense(&vector).
//
// NOTE(review): despite its name this does not compute a distance metric
// yet; it is a matrix product of the training data and vector.
//
// ComputeDistance panics with a descriptive message if TimesDense reports an
// error (for example incompatible dimensions). The original discarded that
// error and dereferenced the result regardless, so a failure was either
// hidden or surfaced as an opaque crash. The signature has no error result,
// so a panic is the only way to report it without breaking callers.
func (KNN *KNNClassifier) ComputeDistance(vector mat.DenseMatrix) mat.DenseMatrix {
	// TODO: add switches for different distance metrics.
	result, err := KNN.Data.TimesDense(&vector)
	if err != nil {
		panic("knn: ComputeDistance: " + err.Error())
	}
	return *result
}
// Predict currently returns the raw output of ComputeDistance for vector,
// unchanged.
//
// NOTE(review): no neighbour selection or label lookup happens yet. An
// earlier draft returned *mat.Difference(&KNN.Data, &vector) instead.
func (KNN *KNNClassifier) Predict(vector mat.DenseMatrix) mat.DenseMatrix {
	return KNN.ComputeDistance(vector)
}
// GetLabel returns the label at position index in KNN.Labels. As with any
// Go slice access, it panics if index is out of range.
func (KNN *KNNClassifier) GetLabel(index int) string {
	labels := KNN.Labels
	return labels[index]
}
2013-12-26 13:05:16 +00:00
func main(){
knn := KNNClassifier{}
2013-12-26 19:52:36 +00:00
dense := *mat.MakeDenseMatrix([]float64{4,5,1,3,4,2},2,3)
knn.New("Testing", []string{"this sucks", "hiya"}, []float64{1,2,3,4,5,6},2,3)
//hey := knn.ComputeDistance(dense)
blof := knn.Predict(dense)
fmt.Println(blof)
2013-12-26 13:05:16 +00:00
}