mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-28 13:48:56 +08:00
Return errors allowing KNNClassifier to implement Classifier
This commit is contained in:
parent
72768fa9fc
commit
1d77ccdec6
@ -55,8 +55,10 @@ func main() {
|
||||
cls.Fit(trainData)
|
||||
|
||||
//Calculates the Euclidean distance and returns the most popular label
|
||||
predictions := cls.Predict(testData)
|
||||
fmt.Println(predictions)
|
||||
predictions, err := cls.Predict(testData)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Prints precision/recall metrics
|
||||
confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)
|
||||
|
@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"github.com/sjwhitworth/golearn/evaluation"
|
||||
"github.com/sjwhitworth/golearn/knn"
|
||||
@ -21,7 +22,10 @@ func main() {
|
||||
cls.Fit(trainData)
|
||||
|
||||
//Calculates the Euclidean distance and returns the most popular label
|
||||
predictions := cls.Predict(testData)
|
||||
predictions, err := cls.Predict(testData)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Println(predictions)
|
||||
|
||||
// Prints precision/recall metrics
|
||||
|
19
knn/knn.go
19
knn/knn.go
@ -4,7 +4,9 @@
|
||||
package knn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/gonum/matrix"
|
||||
"github.com/gonum/matrix/mat64"
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
@ -34,8 +36,9 @@ func NewKnnClassifier(distfunc string, neighbours int) *KNNClassifier {
|
||||
}
|
||||
|
||||
// Fit stores the training data for later
|
||||
func (KNN *KNNClassifier) Fit(trainingData base.FixedDataGrid) {
|
||||
func (KNN *KNNClassifier) Fit(trainingData base.FixedDataGrid) error {
|
||||
KNN.TrainingData = trainingData
|
||||
return nil
|
||||
}
|
||||
|
||||
func (KNN *KNNClassifier) canUseOptimisations(what base.FixedDataGrid) bool {
|
||||
@ -89,7 +92,7 @@ func (KNN *KNNClassifier) canUseOptimisations(what base.FixedDataGrid) bool {
|
||||
}
|
||||
|
||||
// Predict returns a classification for the vector, based on a vector input, using the KNN algorithm.
|
||||
func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
|
||||
func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid, error) {
|
||||
// Check what distance function we are using
|
||||
var distanceFunc pairwise.PairwiseDistanceFunc
|
||||
switch KNN.DistanceFunc {
|
||||
@ -98,20 +101,20 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
|
||||
case "manhattan":
|
||||
distanceFunc = pairwise.NewManhattan()
|
||||
default:
|
||||
panic("unsupported distance function")
|
||||
return nil, errors.New("unsupported distance function")
|
||||
}
|
||||
// Check Compatibility
|
||||
allAttrs := base.CheckCompatible(what, KNN.TrainingData)
|
||||
if allAttrs == nil {
|
||||
// Don't have the same Attributes
|
||||
return nil
|
||||
return nil, errors.New("attributes not compatible")
|
||||
}
|
||||
|
||||
// Use optimised version if permitted
|
||||
if KNN.AllowOptimisations {
|
||||
if KNN.DistanceFunc == "euclidean" {
|
||||
if KNN.canUseOptimisations(what) {
|
||||
return KNN.optimisedEuclideanPredict(what.(*base.DenseInstances))
|
||||
return KNN.optimisedEuclideanPredict(what.(*base.DenseInstances)), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -189,7 +192,11 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
|
||||
|
||||
})
|
||||
|
||||
return ret
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (KNN *KNNClassifier) String() string {
|
||||
return fmt.Sprintf("KNNClassifier(%s, %d)", KNN.DistanceFunc, KNN.NearestNeighbours)
|
||||
}
|
||||
|
||||
func (KNN *KNNClassifier) vote(maxmap map[string]int, values []int) string {
|
||||
|
@ -2,9 +2,10 @@ package knn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"github.com/sjwhitworth/golearn/evaluation"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func readMnist() (*base.DenseInstances, *base.DenseInstances) {
|
||||
@ -45,7 +46,10 @@ func BenchmarkKNNWithOpts(b *testing.B) {
|
||||
cls := NewKnnClassifier("euclidean", 1)
|
||||
cls.AllowOptimisations = true
|
||||
cls.Fit(train)
|
||||
predictions := cls.Predict(test)
|
||||
predictions, err := cls.Predict(test)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
c, err := evaluation.GetConfusionMatrix(test, predictions)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@ -60,7 +64,10 @@ func BenchmarkKNNWithNoOpts(b *testing.B) {
|
||||
cls := NewKnnClassifier("euclidean", 1)
|
||||
cls.AllowOptimisations = false
|
||||
cls.Fit(train)
|
||||
predictions := cls.Predict(test)
|
||||
predictions, err := cls.Predict(test)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
c, err := evaluation.GetConfusionMatrix(test, predictions)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
@ -1,9 +1,10 @@
|
||||
package knn
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestKnnClassifierWithoutOptimisations(t *testing.T) {
|
||||
@ -17,7 +18,8 @@ func TestKnnClassifierWithoutOptimisations(t *testing.T) {
|
||||
cls := NewKnnClassifier("euclidean", 2)
|
||||
cls.AllowOptimisations = false
|
||||
cls.Fit(trainingData)
|
||||
predictions := cls.Predict(testingData)
|
||||
predictions, err := cls.Predict(testingData)
|
||||
So(err, ShouldBeNil)
|
||||
So(predictions, ShouldNotEqual, nil)
|
||||
|
||||
Convey("When predicting the label for our first vector", func() {
|
||||
@ -47,7 +49,8 @@ func TestKnnClassifierWithOptimisations(t *testing.T) {
|
||||
cls := NewKnnClassifier("euclidean", 2)
|
||||
cls.AllowOptimisations = true
|
||||
cls.Fit(trainingData)
|
||||
predictions := cls.Predict(testingData)
|
||||
predictions, err := cls.Predict(testingData)
|
||||
So(err, ShouldBeNil)
|
||||
So(predictions, ShouldNotEqual, nil)
|
||||
|
||||
Convey("When predicting the label for our first vector", func() {
|
||||
@ -75,7 +78,8 @@ func TestKnnClassifierWithTemplatedInstances1(t *testing.T) {
|
||||
|
||||
cls := NewKnnClassifier("euclidean", 2)
|
||||
cls.Fit(trainingData)
|
||||
predictions := cls.Predict(testingData)
|
||||
predictions, err := cls.Predict(testingData)
|
||||
So(err, ShouldBeNil)
|
||||
So(predictions, ShouldNotBeNil)
|
||||
})
|
||||
}
|
||||
@ -89,7 +93,16 @@ func TestKnnClassifierWithTemplatedInstances1Subset(t *testing.T) {
|
||||
|
||||
cls := NewKnnClassifier("euclidean", 2)
|
||||
cls.Fit(trainingData)
|
||||
predictions := cls.Predict(testingData)
|
||||
predictions, err := cls.Predict(testingData)
|
||||
So(err, ShouldBeNil)
|
||||
So(predictions, ShouldNotBeNil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestKnnClassifierImplementsClassifier(t *testing.T) {
|
||||
cls := NewKnnClassifier("euclidean", 2)
|
||||
var c base.Classifier = cls
|
||||
if len(c.String()) < 1 {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user