1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00
golearn/knn/knn_test.go

109 lines
3.2 KiB
Go
Raw Normal View History

2014-05-03 23:08:43 +01:00
package knn
import (
"testing"
"github.com/sjwhitworth/golearn/base"
2014-05-03 23:08:43 +01:00
. "github.com/smartystreets/goconvey/convey"
)
func TestKnnClassifierWithoutOptimisations(t *testing.T) {
2014-05-03 23:08:43 +01:00
Convey("Given labels, a classifier and data", t, func() {
trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false)
So(err, ShouldBeNil)
testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
So(err, ShouldBeNil)
2017-04-17 15:20:31 +08:00
cls := NewKnnClassifier("euclidean", "linear", 2)
cls.AllowOptimisations = false
cls.Fit(trainingData)
predictions, err := cls.Predict(testingData)
So(err, ShouldBeNil)
So(predictions, ShouldNotEqual, nil)
Convey("When predicting the label for our first vector", func() {
result := base.GetClass(predictions, 0)
Convey("The result should be 'blue", func() {
So(result, ShouldEqual, "blue")
})
})
Convey("When predicting the label for our second vector", func() {
result2 := base.GetClass(predictions, 1)
Convey("The result should be 'red", func() {
So(result2, ShouldEqual, "red")
})
})
})
}
func TestKnnClassifierWithOptimisations(t *testing.T) {
Convey("Given labels, a classifier and data", t, func() {
trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false)
So(err, ShouldBeNil)
testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
So(err, ShouldBeNil)
2017-04-17 15:20:31 +08:00
cls := NewKnnClassifier("euclidean", "linear", 2)
cls.AllowOptimisations = true
cls.Fit(trainingData)
predictions, err := cls.Predict(testingData)
So(err, ShouldBeNil)
2014-08-02 16:22:14 +01:00
So(predictions, ShouldNotEqual, nil)
2014-05-03 23:08:43 +01:00
Convey("When predicting the label for our first vector", func() {
2014-08-02 16:22:14 +01:00
result := base.GetClass(predictions, 0)
2014-05-03 23:08:43 +01:00
Convey("The result should be 'blue", func() {
So(result, ShouldEqual, "blue")
})
})
Convey("When predicting the label for our second vector", func() {
2014-08-02 16:22:14 +01:00
result2 := base.GetClass(predictions, 1)
2014-05-03 23:08:43 +01:00
Convey("The result should be 'red", func() {
So(result2, ShouldEqual, "red")
})
})
})
}
2016-07-11 23:27:04 +01:00
func TestKnnClassifierWithTemplatedInstances1(t *testing.T) {
Convey("Given two basically identical files...", t, func() {
trainingData, err := base.ParseCSVToInstances("knn_train_2.csv", true)
So(err, ShouldBeNil)
testingData, err := base.ParseCSVToTemplatedInstances("knn_test_2.csv", true, trainingData)
So(err, ShouldBeNil)
2017-04-17 15:20:31 +08:00
cls := NewKnnClassifier("euclidean", "linear", 2)
cls.Fit(trainingData)
predictions, err := cls.Predict(testingData)
So(err, ShouldBeNil)
So(predictions, ShouldNotBeNil)
})
}
2016-07-11 23:27:04 +01:00
func TestKnnClassifierWithTemplatedInstances1Subset(t *testing.T) {
Convey("Given two basically identical files...", t, func() {
trainingData, err := base.ParseCSVToInstances("knn_train_2.csv", true)
So(err, ShouldBeNil)
testingData, err := base.ParseCSVToTemplatedInstances("knn_test_2_subset.csv", true, trainingData)
So(err, ShouldBeNil)
2017-04-17 15:20:31 +08:00
cls := NewKnnClassifier("euclidean", "linear", 2)
2016-07-11 23:27:04 +01:00
cls.Fit(trainingData)
predictions, err := cls.Predict(testingData)
So(err, ShouldBeNil)
2016-07-11 23:27:04 +01:00
So(predictions, ShouldNotBeNil)
})
}
func TestKnnClassifierImplementsClassifier(t *testing.T) {
2017-04-17 15:20:31 +08:00
cls := NewKnnClassifier("euclidean", "linear", 2)
var c base.Classifier = cls
if len(c.String()) < 1 {
t.Fail()
}
}