1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-28 13:48:56 +08:00
golearn/knn/knn_test.go
Richard Townsend 527c6476e1 Optimised version of KNN for Euclidean distances
This patch also:
   * Completes removal of the edf/ package
   * Corrects an erroneous print statement
   * Introduces two new CSV functions
      * ParseCSVToInstancesTemplated makes sure that
        reading a second CSV file maintains strict Attribute
        compatibility with an existing DenseInstances
      * ParseCSVToInstancesWithAttributeGroups gives more control
        over where Attributes end up in memory, important for
        gaining predictable control over the KNN optimisation
      * Decouples BinaryAttributeGroup from FixedAttributeGroup for
        better casting support
2014-09-30 23:10:22 +01:00

68 lines
1.9 KiB
Go

package knn
import (
"github.com/sjwhitworth/golearn/base"
. "github.com/smartystreets/goconvey/convey"
"testing"
)
// TestKnnClassifierWithoutOptimisations checks the KNN classifier's
// predictions on the knn_train.csv / knn_test.csv fixtures with
// AllowOptimisations set to false.
func TestKnnClassifierWithoutOptimisations(t *testing.T) {
	testKnnClassifierPredictions(t, false)
}

// TestKnnClassifierWithOptimisations runs exactly the same checks as
// TestKnnClassifierWithoutOptimisations but with AllowOptimisations set to
// true, so both code paths must produce the same labels.
func TestKnnClassifierWithOptimisations(t *testing.T) {
	testKnnClassifierPredictions(t, true)
}

// testKnnClassifierPredictions builds a classifier with
// NewKnnClassifier("euclidean", 2), sets its AllowOptimisations field to
// allowOptimisations, fits it on knn_train.csv, predicts knn_test.csv and
// asserts that the first two predicted rows are labelled "blue" and "red".
// Sharing this body keeps the optimised and unoptimised tests in lockstep.
func testKnnClassifierPredictions(t *testing.T, allowOptimisations bool) {
	Convey("Given labels, a classifier and data", t, func() {
		trainingData, err := base.ParseCSVToInstances("knn_train.csv", false)
		So(err, ShouldBeNil)
		testingData, err := base.ParseCSVToInstances("knn_test.csv", false)
		So(err, ShouldBeNil)

		cls := NewKnnClassifier("euclidean", 2)
		cls.AllowOptimisations = allowOptimisations
		cls.Fit(trainingData)
		predictions := cls.Predict(testingData)
		// ShouldNotBeNil also catches a typed nil stored in an interface,
		// which a plain equality check against nil would miss.
		So(predictions, ShouldNotBeNil)

		Convey("When predicting the label for our first vector", func() {
			result := base.GetClass(predictions, 0)
			Convey("The result should be 'blue'", func() {
				So(result, ShouldEqual, "blue")
			})
		})

		Convey("When predicting the label for our second vector", func() {
			result2 := base.GetClass(predictions, 1)
			Convey("The result should be 'red'", func() {
				So(result2, ShouldEqual, "red")
			})
		})
	})
}