
Return errors allowing KNNClassifier to implement Classifier

Tim Lebel 2016-10-10 19:45:20 -07:00
parent 72768fa9fc
commit 1d77ccdec6
5 changed files with 50 additions and 17 deletions
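
In summary: Fit now returns an error and Predict returns (base.FixedDataGrid, error), so a *KNNClassifier can be assigned to a base.Classifier variable and callers handle failures instead of relying on panics inside Predict. A minimal caller sketch against the new signatures follows; the CSV path and the train/test split ratio are illustrative assumptions, not part of this commit.

package main

import (
	"fmt"

	"github.com/sjwhitworth/golearn/base"
	"github.com/sjwhitworth/golearn/knn"
)

func main() {
	// Load a dataset; the file name here is an illustrative assumption.
	rawData, err := base.ParseCSVToInstances("datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}
	trainData, testData := base.InstancesTrainTestSplit(rawData, 0.50)

	cls := knn.NewKnnClassifier("euclidean", 2)

	// After this commit the concrete type satisfies base.Classifier.
	var _ base.Classifier = cls

	// Fit now reports failure through its return value rather than having none.
	if err := cls.Fit(trainData); err != nil {
		panic(err)
	}

	// Predict now returns (base.FixedDataGrid, error) instead of panicking on
	// an unsupported distance function or incompatible attributes.
	predictions, err := cls.Predict(testData)
	if err != nil {
		panic(err)
	}
	fmt.Println(predictions)
}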

View File

@@ -55,8 +55,10 @@ func main() {
 	cls.Fit(trainData)
 	//Calculates the Euclidean distance and returns the most popular label
-	predictions := cls.Predict(testData)
-	fmt.Println(predictions)
+	predictions, err := cls.Predict(testData)
+	if err != nil {
+		panic(err)
+	}
 	// Prints precision/recall metrics
 	confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)

View File

@@ -2,6 +2,7 @@ package main
 import (
 	"fmt"
+
 	"github.com/sjwhitworth/golearn/base"
 	"github.com/sjwhitworth/golearn/evaluation"
 	"github.com/sjwhitworth/golearn/knn"
@@ -21,7 +22,10 @@ func main() {
 	cls.Fit(trainData)
 	//Calculates the Euclidean distance and returns the most popular label
-	predictions := cls.Predict(testData)
+	predictions, err := cls.Predict(testData)
+	if err != nil {
+		panic(err)
+	}
 	fmt.Println(predictions)
 	// Prints precision/recall metrics

View File

@@ -4,7 +4,9 @@
 package knn
 import (
+	"errors"
+	"fmt"
 	"github.com/gonum/matrix"
 	"github.com/gonum/matrix/mat64"
 	"github.com/sjwhitworth/golearn/base"
@@ -34,8 +36,9 @@ func NewKnnClassifier(distfunc string, neighbours int) *KNNClassifier {
 }
 // Fit stores the training data for later
-func (KNN *KNNClassifier) Fit(trainingData base.FixedDataGrid) {
+func (KNN *KNNClassifier) Fit(trainingData base.FixedDataGrid) error {
 	KNN.TrainingData = trainingData
+	return nil
 }
 func (KNN *KNNClassifier) canUseOptimisations(what base.FixedDataGrid) bool {
@@ -89,7 +92,7 @@ func (KNN *KNNClassifier) canUseOptimisations(what base.FixedDataGrid) bool {
 }
 // Predict returns a classification for the vector, based on a vector input, using the KNN algorithm.
-func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
+func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid, error) {
 	// Check what distance function we are using
 	var distanceFunc pairwise.PairwiseDistanceFunc
 	switch KNN.DistanceFunc {
@@ -98,20 +101,20 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
 	case "manhattan":
 		distanceFunc = pairwise.NewManhattan()
 	default:
-		panic("unsupported distance function")
+		return nil, errors.New("unsupported distance function")
 	}
 	// Check Compatibility
 	allAttrs := base.CheckCompatible(what, KNN.TrainingData)
 	if allAttrs == nil {
 		// Don't have the same Attributes
-		return nil
+		return nil, errors.New("attributes not compatible")
 	}
 	// Use optimised version if permitted
 	if KNN.AllowOptimisations {
 		if KNN.DistanceFunc == "euclidean" {
 			if KNN.canUseOptimisations(what) {
-				return KNN.optimisedEuclideanPredict(what.(*base.DenseInstances))
+				return KNN.optimisedEuclideanPredict(what.(*base.DenseInstances)), nil
 			}
 		}
 	}
@@ -189,7 +192,11 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
 	})
-	return ret
+	return ret, nil
+}
+
+func (KNN *KNNClassifier) String() string {
+	return fmt.Sprintf("KNNClassifier(%s, %d)", KNN.DistanceFunc, KNN.NearestNeighbours)
 }
 func (KNN *KNNClassifier) vote(maxmap map[string]int, values []int) string {

View File

@@ -2,9 +2,10 @@ package knn
 import (
 	"fmt"
+	"testing"
+
 	"github.com/sjwhitworth/golearn/base"
 	"github.com/sjwhitworth/golearn/evaluation"
-	"testing"
 )
 func readMnist() (*base.DenseInstances, *base.DenseInstances) {
@@ -45,7 +46,10 @@ func BenchmarkKNNWithOpts(b *testing.B) {
 	cls := NewKnnClassifier("euclidean", 1)
 	cls.AllowOptimisations = true
 	cls.Fit(train)
-	predictions := cls.Predict(test)
+	predictions, err := cls.Predict(test)
+	if err != nil {
+		b.Error(err)
+	}
 	c, err := evaluation.GetConfusionMatrix(test, predictions)
 	if err != nil {
 		panic(err)
@@ -60,7 +64,10 @@ func BenchmarkKNNWithNoOpts(b *testing.B) {
 	cls := NewKnnClassifier("euclidean", 1)
 	cls.AllowOptimisations = false
 	cls.Fit(train)
-	predictions := cls.Predict(test)
+	predictions, err := cls.Predict(test)
+	if err != nil {
+		b.Error(err)
+	}
 	c, err := evaluation.GetConfusionMatrix(test, predictions)
 	if err != nil {
 		panic(err)

View File

@@ -1,9 +1,10 @@
 package knn
 import (
+	"testing"
+
 	"github.com/sjwhitworth/golearn/base"
 	. "github.com/smartystreets/goconvey/convey"
-	"testing"
 )
 func TestKnnClassifierWithoutOptimisations(t *testing.T) {
@@ -17,7 +18,8 @@ func TestKnnClassifierWithoutOptimisations(t *testing.T) {
 		cls := NewKnnClassifier("euclidean", 2)
 		cls.AllowOptimisations = false
 		cls.Fit(trainingData)
-		predictions := cls.Predict(testingData)
+		predictions, err := cls.Predict(testingData)
+		So(err, ShouldBeNil)
 		So(predictions, ShouldNotEqual, nil)
 		Convey("When predicting the label for our first vector", func() {
@@ -47,7 +49,8 @@ func TestKnnClassifierWithOptimisations(t *testing.T) {
 		cls := NewKnnClassifier("euclidean", 2)
 		cls.AllowOptimisations = true
 		cls.Fit(trainingData)
-		predictions := cls.Predict(testingData)
+		predictions, err := cls.Predict(testingData)
+		So(err, ShouldBeNil)
 		So(predictions, ShouldNotEqual, nil)
 		Convey("When predicting the label for our first vector", func() {
@@ -75,7 +78,8 @@ func TestKnnClassifierWithTemplatedInstances1(t *testing.T) {
 		cls := NewKnnClassifier("euclidean", 2)
 		cls.Fit(trainingData)
-		predictions := cls.Predict(testingData)
+		predictions, err := cls.Predict(testingData)
+		So(err, ShouldBeNil)
 		So(predictions, ShouldNotBeNil)
 	})
 }
@@ -89,7 +93,16 @@ func TestKnnClassifierWithTemplatedInstances1Subset(t *testing.T) {
 		cls := NewKnnClassifier("euclidean", 2)
 		cls.Fit(trainingData)
-		predictions := cls.Predict(testingData)
+		predictions, err := cls.Predict(testingData)
+		So(err, ShouldBeNil)
 		So(predictions, ShouldNotBeNil)
 	})
 }
+
+func TestKnnClassifierImplementsClassifier(t *testing.T) {
+	cls := NewKnnClassifier("euclidean", 2)
+	var c base.Classifier = cls
+	if len(c.String()) < 1 {
+		t.Fail()
+	}
+}
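
Design note: the runtime check in TestKnnClassifierImplementsClassifier could also be expressed as a compile-time assertion, so a future signature change would break the build rather than a test. A sketch of that alternative, not part of this commit:

package knn

import "github.com/sjwhitworth/golearn/base"

// Compile-time assertion that *KNNClassifier satisfies base.Classifier.
// The nil pointer conversion costs nothing at runtime; the build fails if
// the method set drifts from the interface.
var _ base.Classifier = (*KNNClassifier)(nil)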