mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-28 13:48:56 +08:00

This patch also:
* Completes removal of the edf/ package.
* Corrects an erroneous print statement.
* Introduces two new CSV functions:
  * ParseCSVToInstancesTemplated ensures that reading a second CSV file maintains strict Attribute compatibility with an existing DenseInstances.
  * ParseCSVToInstancesWithAttributeGroups gives more control over where Attributes end up in memory, which is important for gaining predictable control over the KNN optimisation.
* Decouples BinaryAttributeGroup from FixedAttributeGroup for better casting support.
68 lines
1.9 KiB
Go
68 lines
1.9 KiB
Go
package knn
|
|
|
|
import (
|
|
"github.com/sjwhitworth/golearn/base"
|
|
. "github.com/smartystreets/goconvey/convey"
|
|
"testing"
|
|
)
|
|
|
|
func TestKnnClassifierWithoutOptimisations(t *testing.T) {
|
|
Convey("Given labels, a classifier and data", t, func() {
|
|
trainingData, err := base.ParseCSVToInstances("knn_train.csv", false)
|
|
So(err, ShouldBeNil)
|
|
|
|
testingData, err := base.ParseCSVToInstances("knn_test.csv", false)
|
|
So(err, ShouldBeNil)
|
|
|
|
cls := NewKnnClassifier("euclidean", 2)
|
|
cls.AllowOptimisations = false
|
|
cls.Fit(trainingData)
|
|
predictions := cls.Predict(testingData)
|
|
So(predictions, ShouldNotEqual, nil)
|
|
|
|
Convey("When predicting the label for our first vector", func() {
|
|
result := base.GetClass(predictions, 0)
|
|
Convey("The result should be 'blue", func() {
|
|
So(result, ShouldEqual, "blue")
|
|
})
|
|
})
|
|
|
|
Convey("When predicting the label for our second vector", func() {
|
|
result2 := base.GetClass(predictions, 1)
|
|
Convey("The result should be 'red", func() {
|
|
So(result2, ShouldEqual, "red")
|
|
})
|
|
})
|
|
})
|
|
}
|
|
|
|
func TestKnnClassifierWithOptimisations(t *testing.T) {
|
|
Convey("Given labels, a classifier and data", t, func() {
|
|
trainingData, err := base.ParseCSVToInstances("knn_train.csv", false)
|
|
So(err, ShouldBeNil)
|
|
|
|
testingData, err := base.ParseCSVToInstances("knn_test.csv", false)
|
|
So(err, ShouldBeNil)
|
|
|
|
cls := NewKnnClassifier("euclidean", 2)
|
|
cls.AllowOptimisations = true
|
|
cls.Fit(trainingData)
|
|
predictions := cls.Predict(testingData)
|
|
So(predictions, ShouldNotEqual, nil)
|
|
|
|
Convey("When predicting the label for our first vector", func() {
|
|
result := base.GetClass(predictions, 0)
|
|
Convey("The result should be 'blue", func() {
|
|
So(result, ShouldEqual, "blue")
|
|
})
|
|
})
|
|
|
|
Convey("When predicting the label for our second vector", func() {
|
|
result2 := base.GetClass(predictions, 1)
|
|
Convey("The result should be 'red", func() {
|
|
So(result2, ShouldEqual, "red")
|
|
})
|
|
})
|
|
})
|
|
}
|