2013-12-26 13:05:16 +00:00
|
|
|
package main
|
|
|
|
|
|
|
|
import (
|
|
|
|
mat "github.com/skelterjohn/go.matrix"
|
2013-12-27 00:59:06 +00:00
|
|
|
rand "math/rand"
|
2013-12-26 13:05:16 +00:00
|
|
|
"fmt"
|
|
|
|
)
|
|
|
|
|
|
|
|
type KNNClassifier struct {
|
|
|
|
Data mat.DenseMatrix
|
|
|
|
Name string
|
2013-12-26 19:52:36 +00:00
|
|
|
Labels []string
|
2013-12-26 13:05:16 +00:00
|
|
|
}
|
|
|
|
|
2013-12-27 00:59:06 +00:00
|
|
|
// RandomArray returns a new slice of length n filled with pseudo-random
// values from rand.Float64, so every element lies in [0.0, 1.0).
// Like make, it panics if n is negative.
func RandomArray(n int) []float64 {
	out := make([]float64, n)
	for i := range out {
		out[i] = rand.Float64()
	}
	return out
}
|
|
|
|
|
|
|
|
//Mints a new classifier
|
2013-12-26 19:52:36 +00:00
|
|
|
func (KNN *KNNClassifier) New(name string, labels []string, numbers []float64, x int, y int){
|
|
|
|
KNN.Data = *mat.MakeDenseMatrix(numbers, x, y)
|
|
|
|
KNN.Name = name
|
|
|
|
KNN.Labels = labels
|
2013-12-26 13:05:16 +00:00
|
|
|
}
|
|
|
|
|
2013-12-27 00:59:06 +00:00
|
|
|
//Computes a variety of distance metrics between two vectors
|
|
|
|
func (KNN *KNNClassifier) ComputeDistance(vector *mat.DenseMatrix) *mat.DenseMatrix {
|
2013-12-26 19:52:36 +00:00
|
|
|
//Add switches for different distance metrics
|
2013-12-27 00:59:06 +00:00
|
|
|
result, err := KNN.Data.TimesDense(vector)
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
fmt.Println(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
fmt.Println(result)
|
|
|
|
return result
|
2013-12-26 19:52:36 +00:00
|
|
|
}
|
|
|
|
|
2013-12-27 00:59:06 +00:00
|
|
|
//Returns a classification based on a vector input
|
|
|
|
func (KNN *KNNClassifier) Predict(vector mat.DenseMatrix) *mat.DenseMatrix {
|
|
|
|
return KNN.ComputeDistance(&vector)
|
2013-12-26 19:52:36 +00:00
|
|
|
}
|
|
|
|
|
2013-12-27 00:59:06 +00:00
|
|
|
//Returns a label, given an index
|
2013-12-26 19:52:36 +00:00
|
|
|
func (KNN *KNNClassifier) GetLabel(index int) string {
|
|
|
|
return KNN.Labels[index]
|
|
|
|
}
|
2013-12-26 13:05:16 +00:00
|
|
|
|
|
|
|
func main(){
|
2013-12-27 00:59:06 +00:00
|
|
|
for {
|
|
|
|
values := RandomArray(4)
|
|
|
|
knn := KNNClassifier{}
|
|
|
|
knn.New("Testing", []string{"this sucks", "hiya"}, values,2,2)
|
|
|
|
knn.Predict(knn.Data)
|
|
|
|
}
|
2013-12-26 13:05:16 +00:00
|
|
|
}
|