1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00
golearn/knn/knn_test.go

44 lines
1.0 KiB
Go
Raw Normal View History

2014-05-03 23:08:43 +01:00
package knn
import (
"github.com/sjwhitworth/golearn/base"
2014-05-03 23:08:43 +01:00
. "github.com/smartystreets/goconvey/convey"
"testing"
2014-05-03 23:08:43 +01:00
)
func TestKnnClassifier(t *testing.T) {
Convey("Given labels, a classifier and data", t, func() {
trainingData, err1 := base.ParseCSVToInstances("knn_train.csv", false)
testingData, err2 := base.ParseCSVToInstances("knn_test.csv", false)
if err1 != nil {
t.Error(err1)
return
}
if err2 != nil {
t.Error(err2)
return
}
cls := NewKnnClassifier("euclidean", 2)
cls.Fit(trainingData)
predictions := cls.Predict(testingData)
2014-08-02 16:22:14 +01:00
So(predictions, ShouldNotEqual, nil)
2014-05-03 23:08:43 +01:00
Convey("When predicting the label for our first vector", func() {
2014-08-02 16:22:14 +01:00
result := base.GetClass(predictions, 0)
2014-05-03 23:08:43 +01:00
Convey("The result should be 'blue", func() {
So(result, ShouldEqual, "blue")
})
})
Convey("When predicting the label for our second vector", func() {
2014-08-02 16:22:14 +01:00
result2 := base.GetClass(predictions, 1)
2014-05-03 23:08:43 +01:00
Convey("The result should be 'red", func() {
So(result2, ShouldEqual, "red")
})
})
})
}