package knn

import (
	"testing"

	"github.com/sjwhitworth/golearn/base"
	. "github.com/smartystreets/goconvey/convey"
)

// TestKnnClassifierCov covers KNN prediction with each supported distance
// function, plus the error paths for an unknown distance function, an unknown
// searching algorithm, and test data whose attributes do not match the
// training data.
func TestKnnClassifierCov(t *testing.T) {
	Convey("Test predict", t, func() {
		Convey("distance function", func() {
			trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false)
			So(err, ShouldBeNil)
			testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
			So(err, ShouldBeNil)

			// Every supported metric is expected to classify the first test
			// row identically; Convey descriptions stay unique per metric.
			for _, distance := range []string{"euclidean", "manhattan", "cosine"} {
				distance := distance // pin per iteration (pre-Go 1.22 closure capture)
				Convey("use "+distance, func() {
					cls := NewKnnClassifier(distance, "kdtree", 2)
					cls.AllowOptimisations = false
					cls.Fit(trainingData)
					predictions, err := cls.Predict(testingData)
					So(err, ShouldBeNil)
					So(predictions, ShouldNotBeNil)
					So(base.GetClass(predictions, 0), ShouldEqual, "blue")
				})
			}

			Convey("use undefined distance function", func() {
				cls := NewKnnClassifier("abcd", "kdtree", 2)
				cls.AllowOptimisations = false
				cls.Fit(trainingData)
				predictions, err := cls.Predict(testingData)
				So(predictions, ShouldBeNil)
				// Assert non-nil first so a missing error fails cleanly
				// instead of panicking on err.Error().
				So(err, ShouldNotBeNil)
				So(err.Error(), ShouldEqual, "unsupported distance function")
			})
		})

		Convey("searching algorithm", func() {
			trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false)
			So(err, ShouldBeNil)
			testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
			So(err, ShouldBeNil)

			Convey("use undefined searching algorithm", func() {
				cls := NewKnnClassifier("cosine", "abcd", 2)
				cls.AllowOptimisations = false
				cls.Fit(trainingData)
				predictions, err := cls.Predict(testingData)
				So(predictions, ShouldBeNil)
				So(err, ShouldNotBeNil)
				So(err.Error(), ShouldEqual, "unsupported searching algorithm")
			})
		})

		Convey("check features", func() {
			Convey("use different dataset", func() {
				// knn_test_2.csv is paired with knn_train_1.csv on purpose:
				// the test expects their attributes to be incompatible.
				trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false)
				So(err, ShouldBeNil)
				testingData, err := base.ParseCSVToInstances("knn_test_2.csv", false)
				So(err, ShouldBeNil)

				cls := NewKnnClassifier("cosine", "linear", 2)
				cls.AllowOptimisations = false
				cls.Fit(trainingData)
				predictions, err := cls.Predict(testingData)
				So(predictions, ShouldBeNil)
				So(err, ShouldNotBeNil)
				So(err.Error(), ShouldEqual, "attributes not compatible")
			})
		})
	})
}