From 7b765a2f1889532672d12401cba85397dcb48f42 Mon Sep 17 00:00:00 2001
From: FrozenKP
Date: Mon, 17 Apr 2017 15:20:31 +0800
Subject: [PATCH] add kdtree to knn

---
 README.md                                    |  2 +-
 examples/knnclassifier/knnclassifier_iris.go |  2 +-
 knn/knn.go                                   | 81 +++++++++++++++-----
 knn/knn_bench_test.go                        |  4 +-
 knn/knn_kdtree_test.go                       | 77 +++++++++++++++++++
 knn/knn_test.go                              | 10 +--
 6 files changed, 146 insertions(+), 30 deletions(-)
 create mode 100644 knn/knn_kdtree_test.go

diff --git a/README.md b/README.md
index 7d3de58..f82f622 100644
--- a/README.md
+++ b/README.md
@@ -48,7 +48,7 @@ func main() {
 	fmt.Println(rawData)
 
 	//Initialises a new KNN classifier
-	cls := knn.NewKnnClassifier("euclidean", 2)
+	cls := knn.NewKnnClassifier("euclidean", "linear", 2)
 
 	//Do a training-test split
 	trainData, testData := base.InstancesTrainTestSplit(rawData, 0.50)
diff --git a/examples/knnclassifier/knnclassifier_iris.go b/examples/knnclassifier/knnclassifier_iris.go
index 592e3ee..0661ce9 100644
--- a/examples/knnclassifier/knnclassifier_iris.go
+++ b/examples/knnclassifier/knnclassifier_iris.go
@@ -15,7 +15,7 @@ func main() {
 	}
 
 	//Initialises a new KNN classifier
-	cls := knn.NewKnnClassifier("euclidean", 2)
+	cls := knn.NewKnnClassifier("euclidean", "linear", 2)
 
 	//Do a training-test split
 	trainData, testData := base.InstancesTrainTestSplit(rawData, 0.50)
diff --git a/knn/knn.go b/knn/knn.go
index 308c45a..0c3aab7 100644
--- a/knn/knn.go
+++ b/knn/knn.go
@@ -10,26 +10,30 @@ import (
 	"github.com/gonum/matrix"
 	"github.com/gonum/matrix/mat64"
 	"github.com/sjwhitworth/golearn/base"
+	"github.com/sjwhitworth/golearn/kdtree"
 	"github.com/sjwhitworth/golearn/metrics/pairwise"
 	"github.com/sjwhitworth/golearn/utilities"
 )
 
-// A KNNClassifier consists of a data matrix, associated labels in the same order as the matrix, and a distance function.
-// The accepted distance functions at this time are 'euclidean' and 'manhattan'.
+// A KNNClassifier consists of a data matrix, associated labels in the same order as the matrix, a search algorithm, and a distance function.
+// The accepted distance functions at this time are 'euclidean', 'manhattan', and 'cosine'.
+// The accepted search algorithms are 'linear' and 'kdtree'.
 // Optimisations only occur when things are identically group into identical
 // AttributeGroups, which don't include the class variable, in the same order.
 type KNNClassifier struct {
 	base.BaseEstimator
 	TrainingData       base.FixedDataGrid
 	DistanceFunc       string
+	Algorithm          string
 	NearestNeighbours  int
 	AllowOptimisations bool
 }
 
 // NewKnnClassifier returns a new classifier
-func NewKnnClassifier(distfunc string, neighbours int) *KNNClassifier {
+func NewKnnClassifier(distfunc, algorithm string, neighbours int) *KNNClassifier {
 	KNN := KNNClassifier{}
 	KNN.DistanceFunc = distfunc
+	KNN.Algorithm = algorithm
 	KNN.NearestNeighbours = neighbours
 	KNN.AllowOptimisations = true
 	return &KNN
@@ -105,6 +109,12 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
 	default:
 		return nil, errors.New("unsupported distance function")
 	}
+
+	// Check which search algorithm we are using
+	if KNN.Algorithm != "linear" && KNN.Algorithm != "kdtree" {
+		return nil, errors.New("unsupported searching algorithm")
+	}
+
 	// Check Compatibility
 	allAttrs := base.CheckCompatible(what, KNN.TrainingData)
 	if allAttrs == nil {
@@ -113,7 +123,7 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
 	}
 
 	// Use optimised version if permitted
-	if KNN.AllowOptimisations {
+	if KNN.Algorithm == "linear" && KNN.AllowOptimisations {
 		if KNN.DistanceFunc == "euclidean" {
 			if KNN.canUseOptimisations(what) {
 				return KNN.optimisedEuclideanPredict(what.(*base.DenseInstances)), nil
@@ -156,6 +166,25 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
 	_, maxRow := what.Size()
 	curRow := 0
 
+	// Build the kd-tree if the algorithm is 'kdtree'
+	kd := kdtree.New()
+	if KNN.Algorithm == "kdtree" {
+		buildData := make([][]float64, 0)
+		KNN.TrainingData.MapOverRows(trainAttrSpecs, func(trainRow [][]byte, srcRowNo int) (bool, error) {
+			oneData := make([]float64, len(allNumericAttrs))
+			// Read the float values out
+			for i := range allNumericAttrs {
+				oneData[i] = base.UnpackBytesToFloat(trainRow[i])
+			}
+			buildData = append(buildData, oneData)
+			return true, nil
+		})
+
+		err := kd.Build(buildData)
+		if err != nil {
+			return nil, err
+		}
+	}
 
 	// Iterate over all outer rows
 	what.MapOverRows(whatAttrSpecs, func(predRow [][]byte, predRowNo int) (bool, error) {
@@ -171,25 +200,35 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
 
 		predMat := utilities.FloatsToMatrix(predRowBuf)
 
-		// Find the closest match in the training data
-		KNN.TrainingData.MapOverRows(trainAttrSpecs, func(trainRow [][]byte, srcRowNo int) (bool, error) {
-			// Read the float values out
-			for i, _ := range allNumericAttrs {
-				trainRowBuf[i] = base.UnpackBytesToFloat(trainRow[i])
+		switch KNN.Algorithm {
+		case "linear":
+			// Find the closest match in the training data
+			KNN.TrainingData.MapOverRows(trainAttrSpecs, func(trainRow [][]byte, srcRowNo int) (bool, error) {
+				// Read the float values out
+				for i := range allNumericAttrs {
+					trainRowBuf[i] = base.UnpackBytesToFloat(trainRow[i])
+				}
+
+				// Compute the distance
+				trainMat := utilities.FloatsToMatrix(trainRowBuf)
+				distances[srcRowNo] = distanceFunc.Distance(predMat, trainMat)
+				return true, nil
+			})
+
+			sorted := utilities.SortIntMap(distances)
+			values := sorted[:KNN.NearestNeighbours]
+			maxClass := KNN.vote(maxmap, values)
+			base.SetClass(ret, predRowNo, maxClass)
+
+		case "kdtree":
+			values, err := kd.Search(KNN.NearestNeighbours, distanceFunc, predRowBuf)
+			if err != nil {
+				return false, err
 			}
+			maxClass := KNN.vote(maxmap, values)
+			base.SetClass(ret, predRowNo, maxClass)
+		}
 
-			// Compute the distance
-			trainMat := utilities.FloatsToMatrix(trainRowBuf)
-			distances[srcRowNo] = distanceFunc.Distance(predMat, trainMat)
-			return true, nil
-		})
-
-		sorted := utilities.SortIntMap(distances)
-		values := sorted[:KNN.NearestNeighbours]
-
-		maxClass := KNN.vote(maxmap, values)
-
-		base.SetClass(ret, predRowNo, maxClass)
 		return true, nil
 	})
 
diff --git a/knn/knn_bench_test.go b/knn/knn_bench_test.go
index 4d49631..85d4141 100644
--- a/knn/knn_bench_test.go
+++ b/knn/knn_bench_test.go
@@ -43,7 +43,7 @@ func readMnist() (*base.DenseInstances, *base.DenseInstances) {
 func BenchmarkKNNWithOpts(b *testing.B) {
 	// Load
 	train, test := readMnist()
-	cls := NewKnnClassifier("euclidean", 1)
+	cls := NewKnnClassifier("euclidean", "linear", 1)
 	cls.AllowOptimisations = true
 	cls.Fit(train)
 	predictions, err := cls.Predict(test)
@@ -61,7 +61,7 @@ func BenchmarkKNNWithOpts(b *testing.B) {
 func BenchmarkKNNWithNoOpts(b *testing.B) {
 	// Load
 	train, test := readMnist()
-	cls := NewKnnClassifier("euclidean", 1)
+	cls := NewKnnClassifier("euclidean", "linear", 1)
 	cls.AllowOptimisations = false
 	cls.Fit(train)
 	predictions, err := cls.Predict(test)
diff --git a/knn/knn_kdtree_test.go b/knn/knn_kdtree_test.go
new file mode 100644
index 0000000..4ee79b9
--- /dev/null
+++ b/knn/knn_kdtree_test.go
@@ -0,0 +1,77 @@
+package knn
+
+import (
+	"testing"
+
+	"github.com/sjwhitworth/golearn/base"
+	. "github.com/smartystreets/goconvey/convey"
+)
+
+func TestKnnClassifierWithoutOptimisationsWithKdtree(t *testing.T) {
+	Convey("Given labels, a classifier and data", t, func() {
+		trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false)
+		So(err, ShouldBeNil)
+
+		testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
+		So(err, ShouldBeNil)
+
+		cls := NewKnnClassifier("euclidean", "kdtree", 2)
+		cls.AllowOptimisations = false
+		cls.Fit(trainingData)
+		predictions, err := cls.Predict(testingData)
+		So(err, ShouldBeNil)
+		So(predictions, ShouldNotEqual, nil)
+
+		Convey("When predicting the label for our first vector", func() {
+			result := base.GetClass(predictions, 0)
+			Convey("The result should be 'blue'", func() {
+				So(result, ShouldEqual, "blue")
+			})
+		})
+
+		Convey("When predicting the label for our second vector", func() {
+			result2 := base.GetClass(predictions, 1)
+			Convey("The result should be 'red'", func() {
+				So(result2, ShouldEqual, "red")
+			})
+		})
+	})
+}
+
+func TestKnnClassifierWithTemplatedInstances1WithKdtree(t *testing.T) {
+	Convey("Given two basically identical files...", t, func() {
+		trainingData, err := base.ParseCSVToInstances("knn_train_2.csv", true)
+		So(err, ShouldBeNil)
+		testingData, err := base.ParseCSVToTemplatedInstances("knn_test_2.csv", true, trainingData)
+		So(err, ShouldBeNil)
+
+		cls := NewKnnClassifier("euclidean", "kdtree", 2)
+		cls.Fit(trainingData)
+		predictions, err := cls.Predict(testingData)
+		So(err, ShouldBeNil)
+		So(predictions, ShouldNotBeNil)
+	})
+}
+
+func TestKnnClassifierWithTemplatedInstances1SubsetWithKdtree(t *testing.T) {
+	Convey("Given two basically identical files...", t, func() {
+		trainingData, err := base.ParseCSVToInstances("knn_train_2.csv", true)
+		So(err, ShouldBeNil)
+		testingData, err := base.ParseCSVToTemplatedInstances("knn_test_2_subset.csv", true, trainingData)
+		So(err, ShouldBeNil)
+
+		cls := NewKnnClassifier("euclidean", "kdtree", 2)
+		cls.Fit(trainingData)
+		predictions, err := cls.Predict(testingData)
+		So(err, ShouldBeNil)
+		So(predictions, ShouldNotBeNil)
+	})
+}
+
+func TestKnnClassifierImplementsClassifierWithKdtree(t *testing.T) {
+	cls := NewKnnClassifier("euclidean", "kdtree", 2)
+	var c base.Classifier = cls
+	if len(c.String()) < 1 {
+		t.Fail()
+	}
+}
diff --git a/knn/knn_test.go b/knn/knn_test.go
index 7b72453..6915c85 100644
--- a/knn/knn_test.go
+++ b/knn/knn_test.go
@@ -15,7 +15,7 @@ func TestKnnClassifierWithoutOptimisations(t *testing.T) {
 		testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
 		So(err, ShouldBeNil)
 
-		cls := NewKnnClassifier("euclidean", 2)
+		cls := NewKnnClassifier("euclidean", "linear", 2)
 		cls.AllowOptimisations = false
 		cls.Fit(trainingData)
 		predictions, err := cls.Predict(testingData)
@@ -46,7 +46,7 @@ func TestKnnClassifierWithOptimisations(t *testing.T) {
 		testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
 		So(err, ShouldBeNil)
 
-		cls := NewKnnClassifier("euclidean", 2)
+		cls := NewKnnClassifier("euclidean", "linear", 2)
 		cls.AllowOptimisations = true
 		cls.Fit(trainingData)
 		predictions, err := cls.Predict(testingData)
@@ -76,7 +76,7 @@ func TestKnnClassifierWithTemplatedInstances1(t *testing.T) {
 		testingData, err := base.ParseCSVToTemplatedInstances("knn_test_2.csv", true, trainingData)
 		So(err, ShouldBeNil)
 
-		cls := NewKnnClassifier("euclidean", 2)
+		cls := NewKnnClassifier("euclidean", "linear", 2)
 		cls.Fit(trainingData)
 		predictions, err := cls.Predict(testingData)
 		So(err, ShouldBeNil)
@@ -91,7 +91,7 @@ func TestKnnClassifierWithTemplatedInstances1Subset(t *testing.T) {
 		testingData, err := base.ParseCSVToTemplatedInstances("knn_test_2_subset.csv", true, trainingData)
 		So(err, ShouldBeNil)
 
-		cls := NewKnnClassifier("euclidean", 2)
+		cls := NewKnnClassifier("euclidean", "linear", 2)
 		cls.Fit(trainingData)
 		predictions, err := cls.Predict(testingData)
 		So(err, ShouldBeNil)
@@ -100,7 +100,7 @@ func TestKnnClassifierWithTemplatedInstances1Subset(t *testing.T) {
 }
 
 func TestKnnClassifierImplementsClassifier(t *testing.T) {
-	cls := NewKnnClassifier("euclidean", 2)
+	cls := NewKnnClassifier("euclidean", "linear", 2)
 	var c base.Classifier = cls
 	if len(c.String()) < 1 {
 		t.Fail()
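
Usage note: with this change, NewKnnClassifier takes the search algorithm
("linear" or "kdtree") between the distance function and the neighbour count.
Below is a minimal sketch of the new kd-tree mode; the
"datasets/iris_headers.csv" path is an assumption borrowed from the project's
iris example, not something this patch ships, and any numeric CSV parsed by
ParseCSVToInstances works the same way.

	package main

	import (
		"fmt"

		"github.com/sjwhitworth/golearn/base"
		"github.com/sjwhitworth/golearn/knn"
	)

	func main() {
		// Load a dataset (path assumed; see note above).
		rawData, err := base.ParseCSVToInstances("datasets/iris_headers.csv", true)
		if err != nil {
			panic(err)
		}

		// "kdtree" selects the new search path; "linear" keeps the old
		// behaviour and remains the only mode eligible for the optimised
		// euclidean predictor.
		cls := knn.NewKnnClassifier("euclidean", "kdtree", 2)

		// Train-test split, fit, and predict as before.
		trainData, testData := base.InstancesTrainTestSplit(rawData, 0.50)
		cls.Fit(trainData)
		predictions, err := cls.Predict(testData)
		if err != nil {
			panic(err)
		}
		fmt.Println(predictions)
	}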