From 43f04021af2314939000b5eedd05cc0be116c766 Mon Sep 17 00:00:00 2001 From: Richard Townsend Date: Sat, 9 Sep 2017 20:07:56 +0100 Subject: [PATCH] knn: tests now passing --- base/error.go | 14 +++++- base/serialize_instances.go | 2 +- knn/knn.go | 92 ++++++++++++++++++++++++++++++++++++- knn/knn_test.go | 31 +++++++++++++ 4 files changed, 136 insertions(+), 3 deletions(-) diff --git a/base/error.go b/base/error.go index b3fe559..a70c49a 100644 --- a/base/error.go +++ b/base/error.go @@ -40,8 +40,20 @@ func (g *GoLearnError) attachFormattedStack() { stackFrames := strings.Split(stackString, "\n") stackFmt := make([]string, 0) - for i := 3; i < len(stackFrames); i++ { + for i := 0; i < len(stackFrames); i++ { if strings.Contains(stackFrames[i], "golearn") { + if strings.Contains(stackFrames[i], "golearn/base/error.go") { + continue + } + if strings.Contains(stackFrames[i], "base.WrapError") { + continue + } + if strings.Contains(stackFrames[i], "base.DescribeError") { + continue + } + if strings.Contains(stackFrames[i], "golearn/base.(*GoLearnError).attachFormattedStack") { + continue + } stackFmt = append(stackFmt, stackFrames[i]) } } diff --git a/base/serialize_instances.go b/base/serialize_instances.go index dca440e..f0f6e47 100644 --- a/base/serialize_instances.go +++ b/base/serialize_instances.go @@ -114,7 +114,7 @@ func DeserializeInstancesFromTarReader(tr *FunctionalTarReader, prefix string) ( if err != nil { return nil, DescribeError("Class Attribute deserialization error", err) } - attrBytes, err = tr.GetNamedFile("ATTRS") + attrBytes, err = tr.GetNamedFile(p("ATTRS")) if err != nil { return nil, DescribeError("Unable to read ATTRS", err) } diff --git a/knn/knn.go b/knn/knn.go index f41a665..268232b 100644 --- a/knn/knn.go +++ b/knn/knn.go @@ -326,6 +326,96 @@ func (KNN *KNNClassifier) weightedVote(maxmap map[string]float64, values []int, return maxClass } +// GetMetadata returns required serialization information for this classifier +func (KNN *KNNClassifier) GetMetadata() base.ClassifierMetadataV1 { + + classifierParams := make(map[string]interface{}) + classifierParams["distance_func"] = KNN.DistanceFunc + classifierParams["algorithm"] = KNN.Algorithm + classifierParams["neighbours"] = KNN.NearestNeighbours + classifierParams["weighted"] = KNN.Weighted + classifierParams["allow_optimizations"] = KNN.AllowOptimisations + + return base.ClassifierMetadataV1{ + FormatVersion: 1, + ClassifierName: "KNN", + ClassifierVersion: "1.0", + ClassifierMetadata: classifierParams, + } +} + +// Save outputs a given KNN classifier. +func (KNN *KNNClassifier) Save(filePath string) error { + writer, err := base.CreateSerializedClassifierStub(filePath, KNN.GetMetadata()) + if err != nil { + return err + } + fmt.Printf("writer: %v", writer) + return KNN.SaveWithPrefix(writer, "") +} + +// SaveWithPrefix outputs KNN as part of another file. +func (KNN *KNNClassifier) SaveWithPrefix(writer *base.ClassifierSerializer, prefix string) error { + fmt.Printf("writer: %v", writer) + err := writer.WriteInstancesForKey(writer.Prefix(prefix, "TrainingInstances"), KNN.TrainingData, true) + if err != nil { + return err + } + err = writer.Close() + return err +} + +// Load reloads a given KNN classifier when it's the only thing in the output file. +func (KNN *KNNClassifier) Load(filePath string) error { + reader, err := base.ReadSerializedClassifierStub(filePath) + if err != nil { + return err + } + + return KNN.LoadWithPrefix(reader, "") +} + +// LoadWithPrefix reloads a given KNN classifier when it's part of another file. +func (KNN *KNNClassifier) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error { + + clsMetadata, err := reader.ReadMetadataAtPrefix(prefix) + if err != nil { + return err + } + + if clsMetadata.ClassifierName != "KNN" { + return fmt.Errorf("This file doesn't contain a KNN classifier") + } + if clsMetadata.ClassifierVersion != "1.0" { + return fmt.Errorf("Can't understand this file format") + } + + metadata := clsMetadata.ClassifierMetadata + KNN.DistanceFunc = metadata["distance_func"].(string) + KNN.Algorithm = metadata["algorithm"].(string) + //KNN.NearestNeighbours = metadata["neighbours"].(int) + KNN.Weighted = metadata["weighted"].(bool) + KNN.AllowOptimisations = metadata["allow_optimizations"].(bool) + + // 101 on why JSON is a bad serialization format + floatNeighbours := metadata["neighbours"].(float64) + KNN.NearestNeighbours = int(floatNeighbours) + + KNN.TrainingData, err = reader.GetInstancesForKey(reader.Prefix(prefix, "TrainingInstances")) + + return err +} + +// ReloadKNNClassifier reloads a KNNClassifier when it's the only thing in an output file. +func ReloadKNNClassifier(filePath string) (*KNNClassifier, error) { + stub := &KNNClassifier{} + err := stub.Load(filePath) + if err != nil { + return nil, err + } + return stub, nil +} + // A KNNRegressor consists of a data matrix, associated result variables in the same order as the matrix, and a name. type KNNRegressor struct { base.BaseEstimator @@ -384,4 +474,4 @@ func (KNN *KNNRegressor) Predict(vector *mat64.Dense, K int) float64 { average := sum / float64(K) return average -} +} \ No newline at end of file diff --git a/knn/knn_test.go b/knn/knn_test.go index 6915c85..74691d3 100644 --- a/knn/knn_test.go +++ b/knn/knn_test.go @@ -5,6 +5,7 @@ import ( "github.com/sjwhitworth/golearn/base" . "github.com/smartystreets/goconvey/convey" + "fmt" ) func TestKnnClassifierWithoutOptimisations(t *testing.T) { @@ -38,6 +39,36 @@ func TestKnnClassifierWithoutOptimisations(t *testing.T) { }) } +func TestKnnSaveAndReload(t *testing.T) { + Convey("Given labels, a classifier and data", t, func() { + trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false) + So(err, ShouldBeNil) + + testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false) + So(err, ShouldBeNil) + + cls := NewKnnClassifier("euclidean", "linear", 2) + cls.AllowOptimisations = false + cls.Fit(trainingData) + predictions, err := cls.Predict(testingData) + So(err, ShouldBeNil) + So(predictions, ShouldNotEqual, nil) + + Convey("So saving the classifier should work...", func(){ + err := cls.Save("temp.cls") + So(err, ShouldBeNil) + Convey("So loading the classifier should work...", func(){ + clsR, err := ReloadKNNClassifier("temp.cls") + So(err, ShouldBeNil) + So(cls.String(), ShouldEqual, clsR.String()) + predictionsR, err := clsR.Predict(testingData) + So(err, ShouldBeNil) + So(fmt.Sprintf("%v", predictionsR), ShouldEqual, fmt.Sprintf("%v", predictions)) + }) + }) + }) +} + func TestKnnClassifierWithOptimisations(t *testing.T) { Convey("Given labels, a classifier and data", t, func() { trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false)