mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
add kdtree to knn
This commit is contained in:
parent
041b0b2590
commit
7b765a2f18
@ -48,7 +48,7 @@ func main() {
|
|||||||
fmt.Println(rawData)
|
fmt.Println(rawData)
|
||||||
|
|
||||||
//Initialises a new KNN classifier
|
//Initialises a new KNN classifier
|
||||||
cls := knn.NewKnnClassifier("euclidean", 2)
|
cls := knn.NewKnnClassifier("euclidean", "linear", 2)
|
||||||
|
|
||||||
//Do a training-test split
|
//Do a training-test split
|
||||||
trainData, testData := base.InstancesTrainTestSplit(rawData, 0.50)
|
trainData, testData := base.InstancesTrainTestSplit(rawData, 0.50)
|
||||||
|
@ -15,7 +15,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
//Initialises a new KNN classifier
|
//Initialises a new KNN classifier
|
||||||
cls := knn.NewKnnClassifier("euclidean", 2)
|
cls := knn.NewKnnClassifier("euclidean", "linear", 2)
|
||||||
|
|
||||||
//Do a training-test split
|
//Do a training-test split
|
||||||
trainData, testData := base.InstancesTrainTestSplit(rawData, 0.50)
|
trainData, testData := base.InstancesTrainTestSplit(rawData, 0.50)
|
||||||
|
51
knn/knn.go
51
knn/knn.go
@ -10,26 +10,30 @@ import (
|
|||||||
"github.com/gonum/matrix"
|
"github.com/gonum/matrix"
|
||||||
"github.com/gonum/matrix/mat64"
|
"github.com/gonum/matrix/mat64"
|
||||||
"github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
|
"github.com/sjwhitworth/golearn/kdtree"
|
||||||
"github.com/sjwhitworth/golearn/metrics/pairwise"
|
"github.com/sjwhitworth/golearn/metrics/pairwise"
|
||||||
"github.com/sjwhitworth/golearn/utilities"
|
"github.com/sjwhitworth/golearn/utilities"
|
||||||
)
|
)
|
||||||
|
|
||||||
// A KNNClassifier consists of a data matrix, associated labels in the same order as the matrix, and a distance function.
|
// A KNNClassifier consists of a data matrix, associated labels in the same order as the matrix, searching algorithm, and a distance function.
|
||||||
// The accepted distance functions at this time are 'euclidean' and 'manhattan'.
|
// The accepted distance functions at this time are 'euclidean', 'manhattan', and 'cosine'.
|
||||||
|
// The accepted searching algorithm here are 'linear', and 'kdtree'.
|
||||||
// Optimisations only occur when things are identically group into identical
|
// Optimisations only occur when things are identically group into identical
|
||||||
// AttributeGroups, which don't include the class variable, in the same order.
|
// AttributeGroups, which don't include the class variable, in the same order.
|
||||||
type KNNClassifier struct {
|
type KNNClassifier struct {
|
||||||
base.BaseEstimator
|
base.BaseEstimator
|
||||||
TrainingData base.FixedDataGrid
|
TrainingData base.FixedDataGrid
|
||||||
DistanceFunc string
|
DistanceFunc string
|
||||||
|
Algorithm string
|
||||||
NearestNeighbours int
|
NearestNeighbours int
|
||||||
AllowOptimisations bool
|
AllowOptimisations bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewKnnClassifier returns a new classifier
|
// NewKnnClassifier returns a new classifier
|
||||||
func NewKnnClassifier(distfunc string, neighbours int) *KNNClassifier {
|
func NewKnnClassifier(distfunc, algorithm string, neighbours int) *KNNClassifier {
|
||||||
KNN := KNNClassifier{}
|
KNN := KNNClassifier{}
|
||||||
KNN.DistanceFunc = distfunc
|
KNN.DistanceFunc = distfunc
|
||||||
|
KNN.Algorithm = algorithm
|
||||||
KNN.NearestNeighbours = neighbours
|
KNN.NearestNeighbours = neighbours
|
||||||
KNN.AllowOptimisations = true
|
KNN.AllowOptimisations = true
|
||||||
return &KNN
|
return &KNN
|
||||||
@ -105,6 +109,12 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
|
|||||||
default:
|
default:
|
||||||
return nil, errors.New("unsupported distance function")
|
return nil, errors.New("unsupported distance function")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
	// Check which search algorithm we are using
|
||||||
|
if KNN.Algorithm != "linear" && KNN.Algorithm != "kdtree" {
|
||||||
|
return nil, errors.New("unsupported searching algorithm")
|
||||||
|
}
|
||||||
|
|
||||||
// Check Compatibility
|
// Check Compatibility
|
||||||
allAttrs := base.CheckCompatible(what, KNN.TrainingData)
|
allAttrs := base.CheckCompatible(what, KNN.TrainingData)
|
||||||
if allAttrs == nil {
|
if allAttrs == nil {
|
||||||
@ -113,7 +123,7 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Use optimised version if permitted
|
// Use optimised version if permitted
|
||||||
if KNN.AllowOptimisations {
|
if KNN.Algorithm == "linear" && KNN.AllowOptimisations {
|
||||||
if KNN.DistanceFunc == "euclidean" {
|
if KNN.DistanceFunc == "euclidean" {
|
||||||
if KNN.canUseOptimisations(what) {
|
if KNN.canUseOptimisations(what) {
|
||||||
return KNN.optimisedEuclideanPredict(what.(*base.DenseInstances)), nil
|
return KNN.optimisedEuclideanPredict(what.(*base.DenseInstances)), nil
|
||||||
@ -156,6 +166,25 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
|
|||||||
_, maxRow := what.Size()
|
_, maxRow := what.Size()
|
||||||
curRow := 0
|
curRow := 0
|
||||||
|
|
||||||
|
// build kdtree if algorithm is 'kdtree'
|
||||||
|
kd := kdtree.New()
|
||||||
|
if KNN.Algorithm == "kdtree" {
|
||||||
|
buildData := make([][]float64, 0)
|
||||||
|
KNN.TrainingData.MapOverRows(trainAttrSpecs, func(trainRow [][]byte, srcRowNo int) (bool, error) {
|
||||||
|
oneData := make([]float64, len(allNumericAttrs))
|
||||||
|
// Read the float values out
|
||||||
|
for i, _ := range allNumericAttrs {
|
||||||
|
oneData[i] = base.UnpackBytesToFloat(trainRow[i])
|
||||||
|
}
|
||||||
|
buildData = append(buildData, oneData)
|
||||||
|
return true, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
err := kd.Build(buildData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
// Iterate over all outer rows
|
// Iterate over all outer rows
|
||||||
what.MapOverRows(whatAttrSpecs, func(predRow [][]byte, predRowNo int) (bool, error) {
|
what.MapOverRows(whatAttrSpecs, func(predRow [][]byte, predRowNo int) (bool, error) {
|
||||||
|
|
||||||
@ -171,6 +200,8 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
|
|||||||
|
|
||||||
predMat := utilities.FloatsToMatrix(predRowBuf)
|
predMat := utilities.FloatsToMatrix(predRowBuf)
|
||||||
|
|
||||||
|
switch KNN.Algorithm {
|
||||||
|
case "linear":
|
||||||
// Find the closest match in the training data
|
// Find the closest match in the training data
|
||||||
KNN.TrainingData.MapOverRows(trainAttrSpecs, func(trainRow [][]byte, srcRowNo int) (bool, error) {
|
KNN.TrainingData.MapOverRows(trainAttrSpecs, func(trainRow [][]byte, srcRowNo int) (bool, error) {
|
||||||
// Read the float values out
|
// Read the float values out
|
||||||
@ -186,10 +217,18 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid,
|
|||||||
|
|
||||||
sorted := utilities.SortIntMap(distances)
|
sorted := utilities.SortIntMap(distances)
|
||||||
values := sorted[:KNN.NearestNeighbours]
|
values := sorted[:KNN.NearestNeighbours]
|
||||||
|
|
||||||
maxClass := KNN.vote(maxmap, values)
|
maxClass := KNN.vote(maxmap, values)
|
||||||
|
|
||||||
base.SetClass(ret, predRowNo, maxClass)
|
base.SetClass(ret, predRowNo, maxClass)
|
||||||
|
|
||||||
|
case "kdtree":
|
||||||
|
values, err := kd.Search(KNN.NearestNeighbours, distanceFunc, predRowBuf)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
maxClass := KNN.vote(maxmap, values)
|
||||||
|
base.SetClass(ret, predRowNo, maxClass)
|
||||||
|
}
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
|
|
||||||
})
|
})
|
||||||
|
@ -43,7 +43,7 @@ func readMnist() (*base.DenseInstances, *base.DenseInstances) {
|
|||||||
func BenchmarkKNNWithOpts(b *testing.B) {
|
func BenchmarkKNNWithOpts(b *testing.B) {
|
||||||
// Load
|
// Load
|
||||||
train, test := readMnist()
|
train, test := readMnist()
|
||||||
cls := NewKnnClassifier("euclidean", 1)
|
cls := NewKnnClassifier("euclidean", "linear", 1)
|
||||||
cls.AllowOptimisations = true
|
cls.AllowOptimisations = true
|
||||||
cls.Fit(train)
|
cls.Fit(train)
|
||||||
predictions, err := cls.Predict(test)
|
predictions, err := cls.Predict(test)
|
||||||
@ -61,7 +61,7 @@ func BenchmarkKNNWithOpts(b *testing.B) {
|
|||||||
func BenchmarkKNNWithNoOpts(b *testing.B) {
|
func BenchmarkKNNWithNoOpts(b *testing.B) {
|
||||||
// Load
|
// Load
|
||||||
train, test := readMnist()
|
train, test := readMnist()
|
||||||
cls := NewKnnClassifier("euclidean", 1)
|
cls := NewKnnClassifier("euclidean", "linear", 1)
|
||||||
cls.AllowOptimisations = false
|
cls.AllowOptimisations = false
|
||||||
cls.Fit(train)
|
cls.Fit(train)
|
||||||
predictions, err := cls.Predict(test)
|
predictions, err := cls.Predict(test)
|
||||||
|
77
knn/knn_kdtree_test.go
Normal file
77
knn/knn_kdtree_test.go
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
package knn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sjwhitworth/golearn/base"
|
||||||
|
. "github.com/smartystreets/goconvey/convey"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestKnnClassifierWithoutOptimisationsWithKdtree(t *testing.T) {
|
||||||
|
Convey("Given labels, a classifier and data", t, func() {
|
||||||
|
trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
|
||||||
|
testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
|
||||||
|
cls := NewKnnClassifier("euclidean", "kdtree", 2)
|
||||||
|
cls.AllowOptimisations = false
|
||||||
|
cls.Fit(trainingData)
|
||||||
|
predictions, err := cls.Predict(testingData)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(predictions, ShouldNotEqual, nil)
|
||||||
|
|
||||||
|
Convey("When predicting the label for our first vector", func() {
|
||||||
|
result := base.GetClass(predictions, 0)
|
||||||
|
Convey("The result should be 'blue", func() {
|
||||||
|
So(result, ShouldEqual, "blue")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Convey("When predicting the label for our second vector", func() {
|
||||||
|
result2 := base.GetClass(predictions, 1)
|
||||||
|
Convey("The result should be 'red", func() {
|
||||||
|
So(result2, ShouldEqual, "red")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKnnClassifierWithTemplatedInstances1WithKdtree(t *testing.T) {
|
||||||
|
Convey("Given two basically identical files...", t, func() {
|
||||||
|
trainingData, err := base.ParseCSVToInstances("knn_train_2.csv", true)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
testingData, err := base.ParseCSVToTemplatedInstances("knn_test_2.csv", true, trainingData)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
|
||||||
|
cls := NewKnnClassifier("euclidean", "kdtree", 2)
|
||||||
|
cls.Fit(trainingData)
|
||||||
|
predictions, err := cls.Predict(testingData)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(predictions, ShouldNotBeNil)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKnnClassifierWithTemplatedInstances1SubsetWithKdtree(t *testing.T) {
|
||||||
|
Convey("Given two basically identical files...", t, func() {
|
||||||
|
trainingData, err := base.ParseCSVToInstances("knn_train_2.csv", true)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
testingData, err := base.ParseCSVToTemplatedInstances("knn_test_2_subset.csv", true, trainingData)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
|
||||||
|
cls := NewKnnClassifier("euclidean", "kdtree", 2)
|
||||||
|
cls.Fit(trainingData)
|
||||||
|
predictions, err := cls.Predict(testingData)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(predictions, ShouldNotBeNil)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKnnClassifierImplementsClassifierWithKdtree(t *testing.T) {
|
||||||
|
cls := NewKnnClassifier("euclidean", "kdtree", 2)
|
||||||
|
var c base.Classifier = cls
|
||||||
|
if len(c.String()) < 1 {
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
}
|
@ -15,7 +15,7 @@ func TestKnnClassifierWithoutOptimisations(t *testing.T) {
|
|||||||
testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
|
testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
|
|
||||||
cls := NewKnnClassifier("euclidean", 2)
|
cls := NewKnnClassifier("euclidean", "linear", 2)
|
||||||
cls.AllowOptimisations = false
|
cls.AllowOptimisations = false
|
||||||
cls.Fit(trainingData)
|
cls.Fit(trainingData)
|
||||||
predictions, err := cls.Predict(testingData)
|
predictions, err := cls.Predict(testingData)
|
||||||
@ -46,7 +46,7 @@ func TestKnnClassifierWithOptimisations(t *testing.T) {
|
|||||||
testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
|
testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
|
|
||||||
cls := NewKnnClassifier("euclidean", 2)
|
cls := NewKnnClassifier("euclidean", "linear", 2)
|
||||||
cls.AllowOptimisations = true
|
cls.AllowOptimisations = true
|
||||||
cls.Fit(trainingData)
|
cls.Fit(trainingData)
|
||||||
predictions, err := cls.Predict(testingData)
|
predictions, err := cls.Predict(testingData)
|
||||||
@ -76,7 +76,7 @@ func TestKnnClassifierWithTemplatedInstances1(t *testing.T) {
|
|||||||
testingData, err := base.ParseCSVToTemplatedInstances("knn_test_2.csv", true, trainingData)
|
testingData, err := base.ParseCSVToTemplatedInstances("knn_test_2.csv", true, trainingData)
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
|
|
||||||
cls := NewKnnClassifier("euclidean", 2)
|
cls := NewKnnClassifier("euclidean", "linear", 2)
|
||||||
cls.Fit(trainingData)
|
cls.Fit(trainingData)
|
||||||
predictions, err := cls.Predict(testingData)
|
predictions, err := cls.Predict(testingData)
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
@ -91,7 +91,7 @@ func TestKnnClassifierWithTemplatedInstances1Subset(t *testing.T) {
|
|||||||
testingData, err := base.ParseCSVToTemplatedInstances("knn_test_2_subset.csv", true, trainingData)
|
testingData, err := base.ParseCSVToTemplatedInstances("knn_test_2_subset.csv", true, trainingData)
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
|
|
||||||
cls := NewKnnClassifier("euclidean", 2)
|
cls := NewKnnClassifier("euclidean", "linear", 2)
|
||||||
cls.Fit(trainingData)
|
cls.Fit(trainingData)
|
||||||
predictions, err := cls.Predict(testingData)
|
predictions, err := cls.Predict(testingData)
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
@ -100,7 +100,7 @@ func TestKnnClassifierWithTemplatedInstances1Subset(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestKnnClassifierImplementsClassifier(t *testing.T) {
|
func TestKnnClassifierImplementsClassifier(t *testing.T) {
|
||||||
cls := NewKnnClassifier("euclidean", 2)
|
cls := NewKnnClassifier("euclidean", "linear", 2)
|
||||||
var c base.Classifier = cls
|
var c base.Classifier = cls
|
||||||
if len(c.String()) < 1 {
|
if len(c.String()) < 1 {
|
||||||
t.Fail()
|
t.Fail()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user