1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00

add kdtree to knn

This commit is contained in:
FrozenKP 2017-04-17 15:20:31 +08:00
parent 041b0b2590
commit 7b765a2f18
6 changed files with 146 additions and 30 deletions

View File

@ -48,7 +48,7 @@ func main() {
fmt.Println(rawData) fmt.Println(rawData)
//Initialises a new KNN classifier //Initialises a new KNN classifier
cls := knn.NewKnnClassifier("euclidean", 2) cls := knn.NewKnnClassifier("euclidean", "linear", 2)
//Do a training-test split //Do a training-test split
trainData, testData := base.InstancesTrainTestSplit(rawData, 0.50) trainData, testData := base.InstancesTrainTestSplit(rawData, 0.50)

View File

@ -15,7 +15,7 @@ func main() {
} }
//Initialises a new KNN classifier //Initialises a new KNN classifier
cls := knn.NewKnnClassifier("euclidean", 2) cls := knn.NewKnnClassifier("euclidean", "linear", 2)
//Do a training-test split //Do a training-test split
trainData, testData := base.InstancesTrainTestSplit(rawData, 0.50) trainData, testData := base.InstancesTrainTestSplit(rawData, 0.50)

View File

@ -10,26 +10,30 @@ import (
"github.com/gonum/matrix" "github.com/gonum/matrix"
"github.com/gonum/matrix/mat64" "github.com/gonum/matrix/mat64"
"github.com/sjwhitworth/golearn/base" "github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/kdtree"
"github.com/sjwhitworth/golearn/metrics/pairwise" "github.com/sjwhitworth/golearn/metrics/pairwise"
"github.com/sjwhitworth/golearn/utilities" "github.com/sjwhitworth/golearn/utilities"
) )
// A KNNClassifier consists of a data matrix, associated labels in the same order as the matrix, and a distance function. // A KNNClassifier consists of a data matrix, associated labels in the same order as the matrix, searching algorithm, and a distance function.
// The accepted distance functions at this time are 'euclidean' and 'manhattan'. // The accepted distance functions at this time are 'euclidean', 'manhattan', and 'cosine'.
// The accepted searching algorithms here are 'linear' and 'kdtree'.
// Optimisations only occur when things are identically group into identical // Optimisations only occur when things are identically group into identical
// AttributeGroups, which don't include the class variable, in the same order. // AttributeGroups, which don't include the class variable, in the same order.
type KNNClassifier struct { type KNNClassifier struct {
base.BaseEstimator base.BaseEstimator
TrainingData base.FixedDataGrid TrainingData base.FixedDataGrid
DistanceFunc string DistanceFunc string
Algorithm string
NearestNeighbours int NearestNeighbours int
AllowOptimisations bool AllowOptimisations bool
} }
// NewKnnClassifier returns a new classifier // NewKnnClassifier returns a new classifier
func NewKnnClassifier(distfunc string, neighbours int) *KNNClassifier { func NewKnnClassifier(distfunc, algorithm string, neighbours int) *KNNClassifier {
KNN := KNNClassifier{} KNN := KNNClassifier{}
KNN.DistanceFunc = distfunc KNN.DistanceFunc = distfunc
KNN.Algorithm = algorithm
KNN.NearestNeighbours = neighbours KNN.NearestNeighbours = neighbours
KNN.AllowOptimisations = true KNN.AllowOptimisations = true
return &KNN return &KNN
@ -105,6 +109,12 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
default: default:
return nil, errors.New("unsupported distance function") return nil, errors.New("unsupported distance function")
} }
// Check which searching algorithm we are using
if KNN.Algorithm != "linear" && KNN.Algorithm != "kdtree" {
return nil, errors.New("unsupported searching algorithm")
}
// Check Compatibility // Check Compatibility
allAttrs := base.CheckCompatible(what, KNN.TrainingData) allAttrs := base.CheckCompatible(what, KNN.TrainingData)
if allAttrs == nil { if allAttrs == nil {
@ -113,7 +123,7 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
} }
// Use optimised version if permitted // Use optimised version if permitted
if KNN.AllowOptimisations { if KNN.Algorithm == "linear" && KNN.AllowOptimisations {
if KNN.DistanceFunc == "euclidean" { if KNN.DistanceFunc == "euclidean" {
if KNN.canUseOptimisations(what) { if KNN.canUseOptimisations(what) {
return KNN.optimisedEuclideanPredict(what.(*base.DenseInstances)), nil return KNN.optimisedEuclideanPredict(what.(*base.DenseInstances)), nil
@ -156,6 +166,25 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
_, maxRow := what.Size() _, maxRow := what.Size()
curRow := 0 curRow := 0
// build kdtree if algorithm is 'kdtree'
kd := kdtree.New()
if KNN.Algorithm == "kdtree" {
buildData := make([][]float64, 0)
KNN.TrainingData.MapOverRows(trainAttrSpecs, func(trainRow [][]byte, srcRowNo int) (bool, error) {
oneData := make([]float64, len(allNumericAttrs))
// Read the float values out
for i, _ := range allNumericAttrs {
oneData[i] = base.UnpackBytesToFloat(trainRow[i])
}
buildData = append(buildData, oneData)
return true, nil
})
err := kd.Build(buildData)
if err != nil {
return nil, err
}
}
// Iterate over all outer rows // Iterate over all outer rows
what.MapOverRows(whatAttrSpecs, func(predRow [][]byte, predRowNo int) (bool, error) { what.MapOverRows(whatAttrSpecs, func(predRow [][]byte, predRowNo int) (bool, error) {
@ -171,6 +200,8 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
predMat := utilities.FloatsToMatrix(predRowBuf) predMat := utilities.FloatsToMatrix(predRowBuf)
switch KNN.Algorithm {
case "linear":
// Find the closest match in the training data // Find the closest match in the training data
KNN.TrainingData.MapOverRows(trainAttrSpecs, func(trainRow [][]byte, srcRowNo int) (bool, error) { KNN.TrainingData.MapOverRows(trainAttrSpecs, func(trainRow [][]byte, srcRowNo int) (bool, error) {
// Read the float values out // Read the float values out
@ -186,10 +217,18 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
sorted := utilities.SortIntMap(distances) sorted := utilities.SortIntMap(distances)
values := sorted[:KNN.NearestNeighbours] values := sorted[:KNN.NearestNeighbours]
maxClass := KNN.vote(maxmap, values) maxClass := KNN.vote(maxmap, values)
base.SetClass(ret, predRowNo, maxClass) base.SetClass(ret, predRowNo, maxClass)
case "kdtree":
values, err := kd.Search(KNN.NearestNeighbours, distanceFunc, predRowBuf)
if err != nil {
return false, err
}
maxClass := KNN.vote(maxmap, values)
base.SetClass(ret, predRowNo, maxClass)
}
return true, nil return true, nil
}) })

View File

@ -43,7 +43,7 @@ func readMnist() (*base.DenseInstances, *base.DenseInstances) {
func BenchmarkKNNWithOpts(b *testing.B) { func BenchmarkKNNWithOpts(b *testing.B) {
// Load // Load
train, test := readMnist() train, test := readMnist()
cls := NewKnnClassifier("euclidean", 1) cls := NewKnnClassifier("euclidean", "linear", 1)
cls.AllowOptimisations = true cls.AllowOptimisations = true
cls.Fit(train) cls.Fit(train)
predictions, err := cls.Predict(test) predictions, err := cls.Predict(test)
@ -61,7 +61,7 @@ func BenchmarkKNNWithOpts(b *testing.B) {
func BenchmarkKNNWithNoOpts(b *testing.B) { func BenchmarkKNNWithNoOpts(b *testing.B) {
// Load // Load
train, test := readMnist() train, test := readMnist()
cls := NewKnnClassifier("euclidean", 1) cls := NewKnnClassifier("euclidean", "linear", 1)
cls.AllowOptimisations = false cls.AllowOptimisations = false
cls.Fit(train) cls.Fit(train)
predictions, err := cls.Predict(test) predictions, err := cls.Predict(test)

77
knn/knn_kdtree_test.go Normal file
View File

@ -0,0 +1,77 @@
package knn
import (
"testing"
"github.com/sjwhitworth/golearn/base"
. "github.com/smartystreets/goconvey/convey"
)
// TestKnnClassifierWithoutOptimisationsWithKdtree verifies kd-tree backed KNN
// classification on the small CSV fixtures with the optimised code path
// explicitly disabled, checking the predicted labels of the first two rows.
func TestKnnClassifierWithoutOptimisationsWithKdtree(t *testing.T) {
	Convey("Given labels, a classifier and data", t, func() {
		// Load the training and testing fixtures (no header row).
		train, err := base.ParseCSVToInstances("knn_train_1.csv", false)
		So(err, ShouldBeNil)
		test, err := base.ParseCSVToInstances("knn_test_1.csv", false)
		So(err, ShouldBeNil)

		// Force the generic (non-optimised) prediction path.
		classifier := NewKnnClassifier("euclidean", "kdtree", 2)
		classifier.AllowOptimisations = false
		classifier.Fit(train)

		preds, err := classifier.Predict(test)
		So(err, ShouldBeNil)
		So(preds, ShouldNotEqual, nil)

		Convey("When predicting the label for our first vector", func() {
			first := base.GetClass(preds, 0)
			Convey("The result should be 'blue", func() {
				So(first, ShouldEqual, "blue")
			})
		})

		Convey("When predicting the label for our second vector", func() {
			second := base.GetClass(preds, 1)
			Convey("The result should be 'red", func() {
				So(second, ShouldEqual, "red")
			})
		})
	})
}
// TestKnnClassifierWithTemplatedInstances1WithKdtree checks that kd-tree KNN
// can predict over instances templated from the training set's structure and
// returns a non-nil prediction grid without error.
func TestKnnClassifierWithTemplatedInstances1WithKdtree(t *testing.T) {
	Convey("Given two basically identical files...", t, func() {
		// Both files carry a header row; the test set borrows the
		// training set's attribute template.
		train, err := base.ParseCSVToInstances("knn_train_2.csv", true)
		So(err, ShouldBeNil)
		test, err := base.ParseCSVToTemplatedInstances("knn_test_2.csv", true, train)
		So(err, ShouldBeNil)

		classifier := NewKnnClassifier("euclidean", "kdtree", 2)
		classifier.Fit(train)

		preds, err := classifier.Predict(test)
		So(err, ShouldBeNil)
		So(preds, ShouldNotBeNil)
	})
}
// TestKnnClassifierWithTemplatedInstances1SubsetWithKdtree is the subset
// variant of the templated-instances test: the test CSV holds only a subset
// of rows, and kd-tree KNN must still predict without error.
func TestKnnClassifierWithTemplatedInstances1SubsetWithKdtree(t *testing.T) {
	Convey("Given two basically identical files...", t, func() {
		train, err := base.ParseCSVToInstances("knn_train_2.csv", true)
		So(err, ShouldBeNil)
		test, err := base.ParseCSVToTemplatedInstances("knn_test_2_subset.csv", true, train)
		So(err, ShouldBeNil)

		classifier := NewKnnClassifier("euclidean", "kdtree", 2)
		classifier.Fit(train)

		preds, err := classifier.Predict(test)
		So(err, ShouldBeNil)
		So(preds, ShouldNotBeNil)
	})
}
// TestKnnClassifierImplementsClassifierWithKdtree asserts that a kd-tree
// configured KNNClassifier satisfies the base.Classifier interface and that
// its String() description is non-empty.
func TestKnnClassifierImplementsClassifierWithKdtree(t *testing.T) {
	knnCls := NewKnnClassifier("euclidean", "kdtree", 2)
	// Assignment fails to compile if the interface is not implemented.
	var classifier base.Classifier = knnCls
	if s := classifier.String(); len(s) < 1 {
		t.Fail()
	}
}

View File

@ -15,7 +15,7 @@ func TestKnnClassifierWithoutOptimisations(t *testing.T) {
testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false) testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
So(err, ShouldBeNil) So(err, ShouldBeNil)
cls := NewKnnClassifier("euclidean", 2) cls := NewKnnClassifier("euclidean", "linear", 2)
cls.AllowOptimisations = false cls.AllowOptimisations = false
cls.Fit(trainingData) cls.Fit(trainingData)
predictions, err := cls.Predict(testingData) predictions, err := cls.Predict(testingData)
@ -46,7 +46,7 @@ func TestKnnClassifierWithOptimisations(t *testing.T) {
testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false) testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
So(err, ShouldBeNil) So(err, ShouldBeNil)
cls := NewKnnClassifier("euclidean", 2) cls := NewKnnClassifier("euclidean", "linear", 2)
cls.AllowOptimisations = true cls.AllowOptimisations = true
cls.Fit(trainingData) cls.Fit(trainingData)
predictions, err := cls.Predict(testingData) predictions, err := cls.Predict(testingData)
@ -76,7 +76,7 @@ func TestKnnClassifierWithTemplatedInstances1(t *testing.T) {
testingData, err := base.ParseCSVToTemplatedInstances("knn_test_2.csv", true, trainingData) testingData, err := base.ParseCSVToTemplatedInstances("knn_test_2.csv", true, trainingData)
So(err, ShouldBeNil) So(err, ShouldBeNil)
cls := NewKnnClassifier("euclidean", 2) cls := NewKnnClassifier("euclidean", "linear", 2)
cls.Fit(trainingData) cls.Fit(trainingData)
predictions, err := cls.Predict(testingData) predictions, err := cls.Predict(testingData)
So(err, ShouldBeNil) So(err, ShouldBeNil)
@ -91,7 +91,7 @@ func TestKnnClassifierWithTemplatedInstances1Subset(t *testing.T) {
testingData, err := base.ParseCSVToTemplatedInstances("knn_test_2_subset.csv", true, trainingData) testingData, err := base.ParseCSVToTemplatedInstances("knn_test_2_subset.csv", true, trainingData)
So(err, ShouldBeNil) So(err, ShouldBeNil)
cls := NewKnnClassifier("euclidean", 2) cls := NewKnnClassifier("euclidean", "linear", 2)
cls.Fit(trainingData) cls.Fit(trainingData)
predictions, err := cls.Predict(testingData) predictions, err := cls.Predict(testingData)
So(err, ShouldBeNil) So(err, ShouldBeNil)
@ -100,7 +100,7 @@ func TestKnnClassifierWithTemplatedInstances1Subset(t *testing.T) {
} }
func TestKnnClassifierImplementsClassifier(t *testing.T) { func TestKnnClassifierImplementsClassifier(t *testing.T) {
cls := NewKnnClassifier("euclidean", 2) cls := NewKnnClassifier("euclidean", "linear", 2)
var c base.Classifier = cls var c base.Classifier = cls
if len(c.String()) < 1 { if len(c.String()) < 1 {
t.Fail() t.Fail()