1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00
golearn/knn/knn_cov_test.go
2018-06-16 22:14:18 +08:00

95 lines
2.9 KiB
Go

package knn
import (
"testing"
"github.com/sjwhitworth/golearn/base"
. "github.com/smartystreets/goconvey/convey"
)
// TestKnnClassifierCov exercises KNNClassifier.Predict on the CSV fixtures in
// the package directory. It checks:
//   - the supported distance functions ("euclidean", "manhattan", "cosine");
//   - rejection of an unknown distance function;
//   - rejection of an unknown searching algorithm;
//   - rejection of a test set (knn_test_2.csv) whose attributes are not
//     compatible with the training set.
func TestKnnClassifierCov(t *testing.T) {
	Convey("Test predict", t, func() {
		Convey("distance function", func() {
			trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false)
			So(err, ShouldBeNil)
			testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
			So(err, ShouldBeNil)

			// Every supported distance function, combined with the "kdtree"
			// searching algorithm, should predict class "blue" for row 0.
			for _, distance := range []string{"euclidean", "manhattan", "cosine"} {
				distance := distance // per-iteration copy for the closure (pre-Go 1.22 loop semantics)
				Convey("use "+distance, func() {
					cls := NewKnnClassifier(distance, "kdtree", 2)
					cls.AllowOptimisations = false
					cls.Fit(trainingData)
					predictions, err := cls.Predict(testingData)
					So(err, ShouldBeNil)
					So(predictions, ShouldNotBeNil)
					So(base.GetClass(predictions, 0), ShouldEqual, "blue")
				})
			}

			Convey("use undefined distance function", func() {
				cls := NewKnnClassifier("abcd", "kdtree", 2)
				cls.AllowOptimisations = false
				cls.Fit(trainingData)
				predictions, err := cls.Predict(testingData)
				So(predictions, ShouldBeNil)
				// Assert the error exists before dereferencing it, so a missing
				// error is reported as a failed assertion, not a nil-pointer panic.
				So(err, ShouldNotBeNil)
				So(err.Error(), ShouldEqual, "unsupported distance function")
			})
		})

		Convey("searching algorithm", func() {
			trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false)
			So(err, ShouldBeNil)
			testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
			So(err, ShouldBeNil)

			Convey("use undefined searching algorithm", func() {
				cls := NewKnnClassifier("cosine", "abcd", 2)
				cls.AllowOptimisations = false
				cls.Fit(trainingData)
				predictions, err := cls.Predict(testingData)
				So(predictions, ShouldBeNil)
				So(err, ShouldNotBeNil)
				So(err.Error(), ShouldEqual, "unsupported searching algorithm")
			})
		})

		Convey("check features", func() {
			Convey("use different dataset", func() {
				trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false)
				So(err, ShouldBeNil)
				// knn_test_2.csv is expected to have attributes that do not
				// match knn_train_1.csv, so Predict must refuse it.
				testingData, err := base.ParseCSVToInstances("knn_test_2.csv", false)
				So(err, ShouldBeNil)
				cls := NewKnnClassifier("cosine", "linear", 2)
				cls.AllowOptimisations = false
				cls.Fit(trainingData)
				predictions, err := cls.Predict(testingData)
				So(predictions, ShouldBeNil)
				So(err, ShouldNotBeNil)
				So(err.Error(), ShouldEqual, "attributes not compatible")
			})
		})
	})
}