mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
add kdtree to knn
This commit is contained in:
parent
041b0b2590
commit
7b765a2f18
@ -48,7 +48,7 @@ func main() {
|
||||
fmt.Println(rawData)
|
||||
|
||||
//Initialises a new KNN classifier
|
||||
cls := knn.NewKnnClassifier("euclidean", 2)
|
||||
cls := knn.NewKnnClassifier("euclidean", "linear", 2)
|
||||
|
||||
//Do a training-test split
|
||||
trainData, testData := base.InstancesTrainTestSplit(rawData, 0.50)
|
||||
|
@ -15,7 +15,7 @@ func main() {
|
||||
}
|
||||
|
||||
//Initialises a new KNN classifier
|
||||
cls := knn.NewKnnClassifier("euclidean", 2)
|
||||
cls := knn.NewKnnClassifier("euclidean", "linear", 2)
|
||||
|
||||
//Do a training-test split
|
||||
trainData, testData := base.InstancesTrainTestSplit(rawData, 0.50)
|
||||
|
51
knn/knn.go
51
knn/knn.go
@ -10,26 +10,30 @@ import (
|
||||
"github.com/gonum/matrix"
|
||||
"github.com/gonum/matrix/mat64"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"github.com/sjwhitworth/golearn/kdtree"
|
||||
"github.com/sjwhitworth/golearn/metrics/pairwise"
|
||||
"github.com/sjwhitworth/golearn/utilities"
|
||||
)
|
||||
|
||||
// A KNNClassifier consists of a data matrix, associated labels in the same order as the matrix, and a distance function.
|
||||
// The accepted distance functions at this time are 'euclidean' and 'manhattan'.
|
||||
// A KNNClassifier consists of a data matrix, associated labels in the same order as the matrix, searching algorithm, and a distance function.
|
||||
// The accepted distance functions at this time are 'euclidean', 'manhattan', and 'cosine'.
|
||||
// The accepted searching algorithms here are 'linear' and 'kdtree'.
|
||||
// Optimisations only occur when things are identically grouped into identical
|
||||
// AttributeGroups, which don't include the class variable, in the same order.
|
||||
type KNNClassifier struct {
|
||||
base.BaseEstimator
|
||||
TrainingData base.FixedDataGrid
|
||||
DistanceFunc string
|
||||
Algorithm string
|
||||
NearestNeighbours int
|
||||
AllowOptimisations bool
|
||||
}
|
||||
|
||||
// NewKnnClassifier returns a new classifier
|
||||
func NewKnnClassifier(distfunc string, neighbours int) *KNNClassifier {
|
||||
func NewKnnClassifier(distfunc, algorithm string, neighbours int) *KNNClassifier {
|
||||
KNN := KNNClassifier{}
|
||||
KNN.DistanceFunc = distfunc
|
||||
KNN.Algorithm = algorithm
|
||||
KNN.NearestNeighbours = neighbours
|
||||
KNN.AllowOptimisations = true
|
||||
return &KNN
|
||||
@ -105,6 +109,12 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
|
||||
default:
|
||||
return nil, errors.New("unsupported distance function")
|
||||
}
|
||||
|
||||
// Check which searching algorithm we are using
|
||||
if KNN.Algorithm != "linear" && KNN.Algorithm != "kdtree" {
|
||||
return nil, errors.New("unsupported searching algorithm")
|
||||
}
|
||||
|
||||
// Check Compatibility
|
||||
allAttrs := base.CheckCompatible(what, KNN.TrainingData)
|
||||
if allAttrs == nil {
|
||||
@ -113,7 +123,7 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
|
||||
}
|
||||
|
||||
// Use optimised version if permitted
|
||||
if KNN.AllowOptimisations {
|
||||
if KNN.Algorithm == "linear" && KNN.AllowOptimisations {
|
||||
if KNN.DistanceFunc == "euclidean" {
|
||||
if KNN.canUseOptimisations(what) {
|
||||
return KNN.optimisedEuclideanPredict(what.(*base.DenseInstances)), nil
|
||||
@ -156,6 +166,25 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
|
||||
_, maxRow := what.Size()
|
||||
curRow := 0
|
||||
|
||||
// build kdtree if algorithm is 'kdtree'
|
||||
kd := kdtree.New()
|
||||
if KNN.Algorithm == "kdtree" {
|
||||
buildData := make([][]float64, 0)
|
||||
KNN.TrainingData.MapOverRows(trainAttrSpecs, func(trainRow [][]byte, srcRowNo int) (bool, error) {
|
||||
oneData := make([]float64, len(allNumericAttrs))
|
||||
// Read the float values out
|
||||
for i, _ := range allNumericAttrs {
|
||||
oneData[i] = base.UnpackBytesToFloat(trainRow[i])
|
||||
}
|
||||
buildData = append(buildData, oneData)
|
||||
return true, nil
|
||||
})
|
||||
|
||||
err := kd.Build(buildData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// Iterate over all outer rows
|
||||
what.MapOverRows(whatAttrSpecs, func(predRow [][]byte, predRowNo int) (bool, error) {
|
||||
|
||||
@ -171,6 +200,8 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
|
||||
|
||||
predMat := utilities.FloatsToMatrix(predRowBuf)
|
||||
|
||||
switch KNN.Algorithm {
|
||||
case "linear":
|
||||
// Find the closest match in the training data
|
||||
KNN.TrainingData.MapOverRows(trainAttrSpecs, func(trainRow [][]byte, srcRowNo int) (bool, error) {
|
||||
// Read the float values out
|
||||
@ -186,10 +217,18 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
|
||||
|
||||
sorted := utilities.SortIntMap(distances)
|
||||
values := sorted[:KNN.NearestNeighbours]
|
||||
|
||||
maxClass := KNN.vote(maxmap, values)
|
||||
|
||||
base.SetClass(ret, predRowNo, maxClass)
|
||||
|
||||
case "kdtree":
|
||||
values, err := kd.Search(KNN.NearestNeighbours, distanceFunc, predRowBuf)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
maxClass := KNN.vote(maxmap, values)
|
||||
base.SetClass(ret, predRowNo, maxClass)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
|
||||
})
|
||||
|
@ -43,7 +43,7 @@ func readMnist() (*base.DenseInstances, *base.DenseInstances) {
|
||||
func BenchmarkKNNWithOpts(b *testing.B) {
|
||||
// Load
|
||||
train, test := readMnist()
|
||||
cls := NewKnnClassifier("euclidean", 1)
|
||||
cls := NewKnnClassifier("euclidean", "linear", 1)
|
||||
cls.AllowOptimisations = true
|
||||
cls.Fit(train)
|
||||
predictions, err := cls.Predict(test)
|
||||
@ -61,7 +61,7 @@ func BenchmarkKNNWithOpts(b *testing.B) {
|
||||
func BenchmarkKNNWithNoOpts(b *testing.B) {
|
||||
// Load
|
||||
train, test := readMnist()
|
||||
cls := NewKnnClassifier("euclidean", 1)
|
||||
cls := NewKnnClassifier("euclidean", "linear", 1)
|
||||
cls.AllowOptimisations = false
|
||||
cls.Fit(train)
|
||||
predictions, err := cls.Predict(test)
|
||||
|
77
knn/knn_kdtree_test.go
Normal file
77
knn/knn_kdtree_test.go
Normal file
@ -0,0 +1,77 @@
|
||||
package knn
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
)
|
||||
|
||||
func TestKnnClassifierWithoutOptimisationsWithKdtree(t *testing.T) {
|
||||
Convey("Given labels, a classifier and data", t, func() {
|
||||
trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
cls := NewKnnClassifier("euclidean", "kdtree", 2)
|
||||
cls.AllowOptimisations = false
|
||||
cls.Fit(trainingData)
|
||||
predictions, err := cls.Predict(testingData)
|
||||
So(err, ShouldBeNil)
|
||||
So(predictions, ShouldNotEqual, nil)
|
||||
|
||||
Convey("When predicting the label for our first vector", func() {
|
||||
result := base.GetClass(predictions, 0)
|
||||
Convey("The result should be 'blue", func() {
|
||||
So(result, ShouldEqual, "blue")
|
||||
})
|
||||
})
|
||||
|
||||
Convey("When predicting the label for our second vector", func() {
|
||||
result2 := base.GetClass(predictions, 1)
|
||||
Convey("The result should be 'red", func() {
|
||||
So(result2, ShouldEqual, "red")
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestKnnClassifierWithTemplatedInstances1WithKdtree(t *testing.T) {
|
||||
Convey("Given two basically identical files...", t, func() {
|
||||
trainingData, err := base.ParseCSVToInstances("knn_train_2.csv", true)
|
||||
So(err, ShouldBeNil)
|
||||
testingData, err := base.ParseCSVToTemplatedInstances("knn_test_2.csv", true, trainingData)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
cls := NewKnnClassifier("euclidean", "kdtree", 2)
|
||||
cls.Fit(trainingData)
|
||||
predictions, err := cls.Predict(testingData)
|
||||
So(err, ShouldBeNil)
|
||||
So(predictions, ShouldNotBeNil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestKnnClassifierWithTemplatedInstances1SubsetWithKdtree(t *testing.T) {
|
||||
Convey("Given two basically identical files...", t, func() {
|
||||
trainingData, err := base.ParseCSVToInstances("knn_train_2.csv", true)
|
||||
So(err, ShouldBeNil)
|
||||
testingData, err := base.ParseCSVToTemplatedInstances("knn_test_2_subset.csv", true, trainingData)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
cls := NewKnnClassifier("euclidean", "kdtree", 2)
|
||||
cls.Fit(trainingData)
|
||||
predictions, err := cls.Predict(testingData)
|
||||
So(err, ShouldBeNil)
|
||||
So(predictions, ShouldNotBeNil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestKnnClassifierImplementsClassifierWithKdtree(t *testing.T) {
|
||||
cls := NewKnnClassifier("euclidean", "kdtree", 2)
|
||||
var c base.Classifier = cls
|
||||
if len(c.String()) < 1 {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
@ -15,7 +15,7 @@ func TestKnnClassifierWithoutOptimisations(t *testing.T) {
|
||||
testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
cls := NewKnnClassifier("euclidean", 2)
|
||||
cls := NewKnnClassifier("euclidean", "linear", 2)
|
||||
cls.AllowOptimisations = false
|
||||
cls.Fit(trainingData)
|
||||
predictions, err := cls.Predict(testingData)
|
||||
@ -46,7 +46,7 @@ func TestKnnClassifierWithOptimisations(t *testing.T) {
|
||||
testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
cls := NewKnnClassifier("euclidean", 2)
|
||||
cls := NewKnnClassifier("euclidean", "linear", 2)
|
||||
cls.AllowOptimisations = true
|
||||
cls.Fit(trainingData)
|
||||
predictions, err := cls.Predict(testingData)
|
||||
@ -76,7 +76,7 @@ func TestKnnClassifierWithTemplatedInstances1(t *testing.T) {
|
||||
testingData, err := base.ParseCSVToTemplatedInstances("knn_test_2.csv", true, trainingData)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
cls := NewKnnClassifier("euclidean", 2)
|
||||
cls := NewKnnClassifier("euclidean", "linear", 2)
|
||||
cls.Fit(trainingData)
|
||||
predictions, err := cls.Predict(testingData)
|
||||
So(err, ShouldBeNil)
|
||||
@ -91,7 +91,7 @@ func TestKnnClassifierWithTemplatedInstances1Subset(t *testing.T) {
|
||||
testingData, err := base.ParseCSVToTemplatedInstances("knn_test_2_subset.csv", true, trainingData)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
cls := NewKnnClassifier("euclidean", 2)
|
||||
cls := NewKnnClassifier("euclidean", "linear", 2)
|
||||
cls.Fit(trainingData)
|
||||
predictions, err := cls.Predict(testingData)
|
||||
So(err, ShouldBeNil)
|
||||
@ -100,7 +100,7 @@ func TestKnnClassifierWithTemplatedInstances1Subset(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestKnnClassifierImplementsClassifier(t *testing.T) {
|
||||
cls := NewKnnClassifier("euclidean", 2)
|
||||
cls := NewKnnClassifier("euclidean", "linear", 2)
|
||||
var c base.Classifier = cls
|
||||
if len(c.String()) < 1 {
|
||||
t.Fail()
|
||||
|
Loading…
x
Reference in New Issue
Block a user