diff --git a/README.md b/README.md index 22339b9..c583ef3 100644 --- a/README.md +++ b/README.md @@ -55,8 +55,10 @@ func main() { cls.Fit(trainData) //Calculates the Euclidean distance and returns the most popular label - predictions := cls.Predict(testData) - fmt.Println(predictions) + predictions, err := cls.Predict(testData) + if err != nil { + panic(err) + } // Prints precision/recall metrics confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions) diff --git a/examples/knnclassifier/knnclassifier_iris.go b/examples/knnclassifier/knnclassifier_iris.go index 726d242..592e3ee 100644 --- a/examples/knnclassifier/knnclassifier_iris.go +++ b/examples/knnclassifier/knnclassifier_iris.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "github.com/sjwhitworth/golearn/base" "github.com/sjwhitworth/golearn/evaluation" "github.com/sjwhitworth/golearn/knn" @@ -21,7 +22,10 @@ func main() { cls.Fit(trainData) //Calculates the Euclidean distance and returns the most popular label - predictions := cls.Predict(testData) + predictions, err := cls.Predict(testData) + if err != nil { + panic(err) + } fmt.Println(predictions) // Prints precision/recall metrics diff --git a/knn/knn.go b/knn/knn.go index 1770c0d..bbbd43d 100644 --- a/knn/knn.go +++ b/knn/knn.go @@ -4,7 +4,9 @@ package knn import ( + "errors" "fmt" + "github.com/gonum/matrix" "github.com/gonum/matrix/mat64" "github.com/sjwhitworth/golearn/base" @@ -34,8 +36,9 @@ func NewKnnClassifier(distfunc string, neighbours int) *KNNClassifier { } // Fit stores the training data for later -func (KNN *KNNClassifier) Fit(trainingData base.FixedDataGrid) { +func (KNN *KNNClassifier) Fit(trainingData base.FixedDataGrid) error { KNN.TrainingData = trainingData + return nil } func (KNN *KNNClassifier) canUseOptimisations(what base.FixedDataGrid) bool { @@ -89,7 +92,7 @@ func (KNN *KNNClassifier) canUseOptimisations(what base.FixedDataGrid) bool { } // Predict returns a classification for the vector, based on a vector input, using the KNN algorithm. -func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid { +func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid, error) { // Check what distance function we are using var distanceFunc pairwise.PairwiseDistanceFunc switch KNN.DistanceFunc { @@ -98,20 +101,20 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid { case "manhattan": distanceFunc = pairwise.NewManhattan() default: - panic("unsupported distance function") + return nil, errors.New("unsupported distance function") } // Check Compatibility allAttrs := base.CheckCompatible(what, KNN.TrainingData) if allAttrs == nil { // Don't have the same Attributes - return nil + return nil, errors.New("attributes not compatible") } // Use optimised version if permitted if KNN.AllowOptimisations { if KNN.DistanceFunc == "euclidean" { if KNN.canUseOptimisations(what) { - return KNN.optimisedEuclideanPredict(what.(*base.DenseInstances)) + return KNN.optimisedEuclideanPredict(what.(*base.DenseInstances)), nil } } } @@ -189,7 +192,11 @@ func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid { }) - return ret + return ret, nil +} + +func (KNN *KNNClassifier) String() string { + return fmt.Sprintf("KNNClassifier(%s, %d)", KNN.DistanceFunc, KNN.NearestNeighbours) } func (KNN *KNNClassifier) vote(maxmap map[string]int, values []int) string { diff --git a/knn/knn_bench_test.go b/knn/knn_bench_test.go index 753602d..4d49631 100644 --- a/knn/knn_bench_test.go +++ b/knn/knn_bench_test.go @@ -2,9 +2,10 @@ package knn import ( "fmt" + "testing" + "github.com/sjwhitworth/golearn/base" "github.com/sjwhitworth/golearn/evaluation" - "testing" ) func readMnist() (*base.DenseInstances, *base.DenseInstances) { @@ -45,7 +46,10 @@ func BenchmarkKNNWithOpts(b *testing.B) { cls := NewKnnClassifier("euclidean", 1) cls.AllowOptimisations = true cls.Fit(train) - predictions := cls.Predict(test) + predictions, err := cls.Predict(test) + if err != nil { + b.Error(err) + } c, err := evaluation.GetConfusionMatrix(test, predictions) if err != nil { panic(err) @@ -60,7 +64,10 @@ func BenchmarkKNNWithNoOpts(b *testing.B) { cls := NewKnnClassifier("euclidean", 1) cls.AllowOptimisations = false cls.Fit(train) - predictions := cls.Predict(test) + predictions, err := cls.Predict(test) + if err != nil { + b.Error(err) + } c, err := evaluation.GetConfusionMatrix(test, predictions) if err != nil { panic(err) diff --git a/knn/knn_test.go b/knn/knn_test.go index 0ba6570..7b72453 100644 --- a/knn/knn_test.go +++ b/knn/knn_test.go @@ -1,9 +1,10 @@ package knn import ( + "testing" + "github.com/sjwhitworth/golearn/base" . "github.com/smartystreets/goconvey/convey" - "testing" ) func TestKnnClassifierWithoutOptimisations(t *testing.T) { @@ -17,7 +18,8 @@ func TestKnnClassifierWithoutOptimisations(t *testing.T) { cls := NewKnnClassifier("euclidean", 2) cls.AllowOptimisations = false cls.Fit(trainingData) - predictions := cls.Predict(testingData) + predictions, err := cls.Predict(testingData) + So(err, ShouldBeNil) So(predictions, ShouldNotEqual, nil) Convey("When predicting the label for our first vector", func() { @@ -47,7 +49,8 @@ func TestKnnClassifierWithOptimisations(t *testing.T) { cls := NewKnnClassifier("euclidean", 2) cls.AllowOptimisations = true cls.Fit(trainingData) - predictions := cls.Predict(testingData) + predictions, err := cls.Predict(testingData) + So(err, ShouldBeNil) So(predictions, ShouldNotEqual, nil) Convey("When predicting the label for our first vector", func() { @@ -75,7 +78,8 @@ func TestKnnClassifierWithTemplatedInstances1(t *testing.T) { cls := NewKnnClassifier("euclidean", 2) cls.Fit(trainingData) - predictions := cls.Predict(testingData) + predictions, err := cls.Predict(testingData) + So(err, ShouldBeNil) So(predictions, ShouldNotBeNil) }) } @@ -89,7 +93,16 @@ func TestKnnClassifierWithTemplatedInstances1Subset(t *testing.T) { cls := NewKnnClassifier("euclidean", 2) cls.Fit(trainingData) - predictions := cls.Predict(testingData) + predictions, err := cls.Predict(testingData) + So(err, ShouldBeNil) So(predictions, ShouldNotBeNil) }) } + +func TestKnnClassifierImplementsClassifier(t *testing.T) { + cls := NewKnnClassifier("euclidean", 2) + var c base.Classifier = cls + if len(c.String()) < 1 { + t.Fail() + } +}