|
|
|
package main
|
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
|
|
|
|
|
|
|
"github.com/sjwhitworth/golearn/base"
|
|
|
|
"github.com/sjwhitworth/golearn/evaluation"
|
|
|
|
"github.com/sjwhitworth/golearn/knn"
|
|
|
|
)
|
|
|
|
|
|
|
|
func main() {
|
|
|
|
rawData, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
|
|
|
|
if err != nil {
|
|
|
|
panic(err)
|
|
|
|
}
|
2014-08-02 16:22:14 +01:00
|
|
|
|
2014-05-17 21:33:48 +01:00
|
|
|
//Initialises a new KNN classifier
|
2017-04-19 00:23:57 +08:00
|
|
|
cls := knn.NewKnnClassifier("euclidean", "linear", 2)
|
2014-05-17 21:33:48 +01:00
|
|
|
|
|
|
|
//Do a training-test split
|
2014-06-06 20:30:24 +02:00
|
|
|
trainData, testData := base.InstancesTrainTestSplit(rawData, 0.50)
|
2014-05-17 21:33:48 +01:00
|
|
|
cls.Fit(trainData)
|
|
|
|
|
|
|
|
//Calculates the Euclidean distance and returns the most popular label
|
2016-10-10 19:45:20 -07:00
|
|
|
predictions, err := cls.Predict(testData)
|
|
|
|
if err != nil {
|
|
|
|
panic(err)
|
|
|
|
}
|
2014-05-17 21:33:48 +01:00
|
|
|
fmt.Println(predictions)
|
|
|
|
|
|
|
|
// Prints precision/recall metrics
|
2014-08-22 08:52:37 +00:00
|
|
|
confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)
|
|
|
|
if err != nil {
|
|
|
|
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
|
|
|
|
}
|
2014-05-17 21:33:48 +01:00
|
|
|
fmt.Println(evaluation.GetSummary(confusionMat))
|
|
|
|
}
|