mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-25 13:48:49 +08:00
95 lines
2.9 KiB
Go
95 lines
2.9 KiB
Go
package knn
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"github.com/sjwhitworth/golearn/base"
|
|
. "github.com/smartystreets/goconvey/convey"
|
|
)
|
|
|
|
func TestKnnClassifierCov(t *testing.T) {
|
|
Convey("Test predict", t, func() {
|
|
Convey("distance function", func() {
|
|
trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false)
|
|
So(err, ShouldBeNil)
|
|
|
|
testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
|
|
So(err, ShouldBeNil)
|
|
|
|
Convey("use euclidean", func() {
|
|
cls := NewKnnClassifier("euclidean", "kdtree", 2)
|
|
cls.AllowOptimisations = false
|
|
cls.Fit(trainingData)
|
|
predictions, err := cls.Predict(testingData)
|
|
So(err, ShouldBeNil)
|
|
So(predictions, ShouldNotEqual, nil)
|
|
result := base.GetClass(predictions, 0)
|
|
So(result, ShouldEqual, "blue")
|
|
})
|
|
|
|
Convey("use manhattan", func() {
|
|
cls := NewKnnClassifier("manhattan", "kdtree", 2)
|
|
cls.AllowOptimisations = false
|
|
cls.Fit(trainingData)
|
|
predictions, err := cls.Predict(testingData)
|
|
So(err, ShouldBeNil)
|
|
So(predictions, ShouldNotEqual, nil)
|
|
result := base.GetClass(predictions, 0)
|
|
So(result, ShouldEqual, "blue")
|
|
})
|
|
|
|
Convey("use cosine", func() {
|
|
cls := NewKnnClassifier("cosine", "kdtree", 2)
|
|
cls.AllowOptimisations = false
|
|
cls.Fit(trainingData)
|
|
predictions, err := cls.Predict(testingData)
|
|
So(err, ShouldBeNil)
|
|
So(predictions, ShouldNotEqual, nil)
|
|
result := base.GetClass(predictions, 0)
|
|
So(result, ShouldEqual, "blue")
|
|
})
|
|
|
|
Convey("use undefined distance function", func() {
|
|
cls := NewKnnClassifier("abcd", "kdtree", 2)
|
|
cls.AllowOptimisations = false
|
|
cls.Fit(trainingData)
|
|
predictions, err := cls.Predict(testingData)
|
|
So(predictions, ShouldBeNil)
|
|
So(err.Error(), ShouldEqual, "unsupported distance function")
|
|
})
|
|
})
|
|
|
|
Convey("searching algorithm", func() {
|
|
trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false)
|
|
So(err, ShouldBeNil)
|
|
|
|
testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
|
|
So(err, ShouldBeNil)
|
|
|
|
Convey("use undefined searching algorithm", func() {
|
|
cls := NewKnnClassifier("cosine", "abcd", 2)
|
|
cls.AllowOptimisations = false
|
|
cls.Fit(trainingData)
|
|
predictions, err := cls.Predict(testingData)
|
|
So(predictions, ShouldBeNil)
|
|
So(err.Error(), ShouldEqual, "unsupported searching algorithm")
|
|
})
|
|
})
|
|
|
|
Convey("check features", func() {
|
|
Convey("use different dataset", func() {
|
|
trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false)
|
|
So(err, ShouldBeNil)
|
|
testingData, err := base.ParseCSVToInstances("knn_test_2.csv", false)
|
|
So(err, ShouldBeNil)
|
|
cls := NewKnnClassifier("cosine", "linear", 2)
|
|
cls.AllowOptimisations = false
|
|
cls.Fit(trainingData)
|
|
predictions, err := cls.Predict(testingData)
|
|
So(predictions, ShouldBeNil)
|
|
So(err.Error(), ShouldEqual, "attributes not compatible")
|
|
})
|
|
})
|
|
})
|
|
}
|