2014-05-03 23:08:43 +01:00
|
|
|
package knn
|
|
|
|
|
|
|
|
import (
|
2016-10-10 19:45:20 -07:00
|
|
|
"testing"
|
|
|
|
|
2014-05-09 18:21:31 +01:00
|
|
|
"github.com/sjwhitworth/golearn/base"
|
2014-05-03 23:08:43 +01:00
|
|
|
. "github.com/smartystreets/goconvey/convey"
|
|
|
|
)
|
|
|
|
|
2014-09-18 21:03:04 +01:00
|
|
|
func TestKnnClassifierWithoutOptimisations(t *testing.T) {
|
2014-05-03 23:08:43 +01:00
|
|
|
Convey("Given labels, a classifier and data", t, func() {
|
2016-07-11 23:01:16 +01:00
|
|
|
trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false)
|
2014-08-22 13:16:11 +00:00
|
|
|
So(err, ShouldBeNil)
|
2014-05-09 18:21:31 +01:00
|
|
|
|
2016-07-11 23:01:16 +01:00
|
|
|
testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
|
2014-08-22 13:16:11 +00:00
|
|
|
So(err, ShouldBeNil)
|
2014-05-09 18:21:31 +01:00
|
|
|
|
2017-04-17 15:20:31 +08:00
|
|
|
cls := NewKnnClassifier("euclidean", "linear", 2)
|
2014-09-18 21:03:04 +01:00
|
|
|
cls.AllowOptimisations = false
|
|
|
|
cls.Fit(trainingData)
|
2016-10-10 19:45:20 -07:00
|
|
|
predictions, err := cls.Predict(testingData)
|
|
|
|
So(err, ShouldBeNil)
|
2014-09-18 21:03:04 +01:00
|
|
|
So(predictions, ShouldNotEqual, nil)
|
|
|
|
|
|
|
|
Convey("When predicting the label for our first vector", func() {
|
|
|
|
result := base.GetClass(predictions, 0)
|
|
|
|
Convey("The result should be 'blue", func() {
|
|
|
|
So(result, ShouldEqual, "blue")
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
|
|
|
Convey("When predicting the label for our second vector", func() {
|
|
|
|
result2 := base.GetClass(predictions, 1)
|
|
|
|
Convey("The result should be 'red", func() {
|
|
|
|
So(result2, ShouldEqual, "red")
|
|
|
|
})
|
|
|
|
})
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestKnnClassifierWithOptimisations(t *testing.T) {
|
|
|
|
Convey("Given labels, a classifier and data", t, func() {
|
2016-07-11 23:01:16 +01:00
|
|
|
trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false)
|
2014-09-18 21:03:04 +01:00
|
|
|
So(err, ShouldBeNil)
|
|
|
|
|
2016-07-11 23:01:16 +01:00
|
|
|
testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
|
2014-09-18 21:03:04 +01:00
|
|
|
So(err, ShouldBeNil)
|
|
|
|
|
2017-04-17 15:20:31 +08:00
|
|
|
cls := NewKnnClassifier("euclidean", "linear", 2)
|
2014-09-18 21:03:04 +01:00
|
|
|
cls.AllowOptimisations = true
|
2014-05-09 18:21:31 +01:00
|
|
|
cls.Fit(trainingData)
|
2016-10-10 19:45:20 -07:00
|
|
|
predictions, err := cls.Predict(testingData)
|
|
|
|
So(err, ShouldBeNil)
|
2014-08-02 16:22:14 +01:00
|
|
|
So(predictions, ShouldNotEqual, nil)
|
2014-05-03 23:08:43 +01:00
|
|
|
|
|
|
|
Convey("When predicting the label for our first vector", func() {
|
2014-08-02 16:22:14 +01:00
|
|
|
result := base.GetClass(predictions, 0)
|
2014-05-03 23:08:43 +01:00
|
|
|
Convey("The result should be 'blue", func() {
|
|
|
|
So(result, ShouldEqual, "blue")
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
2014-08-19 06:22:10 +00:00
|
|
|
Convey("When predicting the label for our second vector", func() {
|
2014-08-02 16:22:14 +01:00
|
|
|
result2 := base.GetClass(predictions, 1)
|
2014-05-03 23:08:43 +01:00
|
|
|
Convey("The result should be 'red", func() {
|
|
|
|
So(result2, ShouldEqual, "red")
|
|
|
|
})
|
|
|
|
})
|
|
|
|
})
|
|
|
|
}
|
2016-07-11 23:05:03 +01:00
|
|
|
|
2016-07-11 23:27:04 +01:00
|
|
|
func TestKnnClassifierWithTemplatedInstances1(t *testing.T) {
|
2016-07-11 23:05:03 +01:00
|
|
|
Convey("Given two basically identical files...", t, func() {
|
|
|
|
trainingData, err := base.ParseCSVToInstances("knn_train_2.csv", true)
|
|
|
|
So(err, ShouldBeNil)
|
2016-07-11 23:16:18 +01:00
|
|
|
testingData, err := base.ParseCSVToTemplatedInstances("knn_test_2.csv", true, trainingData)
|
2016-07-11 23:05:03 +01:00
|
|
|
So(err, ShouldBeNil)
|
|
|
|
|
2017-04-17 15:20:31 +08:00
|
|
|
cls := NewKnnClassifier("euclidean", "linear", 2)
|
2016-07-11 23:05:03 +01:00
|
|
|
cls.Fit(trainingData)
|
2016-10-10 19:45:20 -07:00
|
|
|
predictions, err := cls.Predict(testingData)
|
|
|
|
So(err, ShouldBeNil)
|
2016-07-11 23:05:03 +01:00
|
|
|
So(predictions, ShouldNotBeNil)
|
|
|
|
})
|
|
|
|
}
|
2016-07-11 23:27:04 +01:00
|
|
|
|
|
|
|
func TestKnnClassifierWithTemplatedInstances1Subset(t *testing.T) {
|
|
|
|
Convey("Given two basically identical files...", t, func() {
|
|
|
|
trainingData, err := base.ParseCSVToInstances("knn_train_2.csv", true)
|
|
|
|
So(err, ShouldBeNil)
|
|
|
|
testingData, err := base.ParseCSVToTemplatedInstances("knn_test_2_subset.csv", true, trainingData)
|
|
|
|
So(err, ShouldBeNil)
|
|
|
|
|
2017-04-17 15:20:31 +08:00
|
|
|
cls := NewKnnClassifier("euclidean", "linear", 2)
|
2016-07-11 23:27:04 +01:00
|
|
|
cls.Fit(trainingData)
|
2016-10-10 19:45:20 -07:00
|
|
|
predictions, err := cls.Predict(testingData)
|
|
|
|
So(err, ShouldBeNil)
|
2016-07-11 23:27:04 +01:00
|
|
|
So(predictions, ShouldNotBeNil)
|
|
|
|
})
|
|
|
|
}
|
2016-10-10 19:45:20 -07:00
|
|
|
|
|
|
|
func TestKnnClassifierImplementsClassifier(t *testing.T) {
|
2017-04-17 15:20:31 +08:00
|
|
|
cls := NewKnnClassifier("euclidean", "linear", 2)
|
2016-10-10 19:45:20 -07:00
|
|
|
var c base.Classifier = cls
|
|
|
|
if len(c.String()) < 1 {
|
|
|
|
t.Fail()
|
|
|
|
}
|
|
|
|
}
|