1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-28 13:48:56 +08:00

knn: tests now passing

This commit is contained in:
Richard Townsend 2017-09-09 20:07:56 +01:00
parent 72c2005e70
commit 43f04021af
4 changed files with 136 additions and 3 deletions

View File

@ -40,8 +40,20 @@ func (g *GoLearnError) attachFormattedStack() {
stackFrames := strings.Split(stackString, "\n")
stackFmt := make([]string, 0)
for i := 3; i < len(stackFrames); i++ {
for i := 0; i < len(stackFrames); i++ {
if strings.Contains(stackFrames[i], "golearn") {
if strings.Contains(stackFrames[i], "golearn/base/error.go") {
continue
}
if strings.Contains(stackFrames[i], "base.WrapError") {
continue
}
if strings.Contains(stackFrames[i], "base.DescribeError") {
continue
}
if strings.Contains(stackFrames[i], "golearn/base.(*GoLearnError).attachFormattedStack") {
continue
}
stackFmt = append(stackFmt, stackFrames[i])
}
}

View File

@ -114,7 +114,7 @@ func DeserializeInstancesFromTarReader(tr *FunctionalTarReader, prefix string) (
if err != nil {
return nil, DescribeError("Class Attribute deserialization error", err)
}
attrBytes, err = tr.GetNamedFile("ATTRS")
attrBytes, err = tr.GetNamedFile(p("ATTRS"))
if err != nil {
return nil, DescribeError("Unable to read ATTRS", err)
}

View File

@ -326,6 +326,96 @@ func (KNN *KNNClassifier) weightedVote(maxmap map[string]float64, values []int,
return maxClass
}
// GetMetadata returns required serialization information for this classifier.
//
// The classifier-specific parameters (distance function, algorithm,
// neighbour count, weighting flag and optimisation flag) are stored under
// fixed string keys so that LoadWithPrefix can read them back.
func (KNN *KNNClassifier) GetMetadata() base.ClassifierMetadataV1 {
	params := map[string]interface{}{
		"distance_func":       KNN.DistanceFunc,
		"algorithm":           KNN.Algorithm,
		"neighbours":          KNN.NearestNeighbours,
		"weighted":            KNN.Weighted,
		"allow_optimizations": KNN.AllowOptimisations,
	}

	metadata := base.ClassifierMetadataV1{
		FormatVersion:      1,
		ClassifierName:     "KNN",
		ClassifierVersion:  "1.0",
		ClassifierMetadata: params,
	}
	return metadata
}
// Save outputs a given KNN classifier.
//
// It creates a serialized classifier stub at filePath using the metadata
// from GetMetadata and then delegates to SaveWithPrefix with an empty
// prefix. Any error from creating the stub or from SaveWithPrefix is
// returned unchanged.
func (KNN *KNNClassifier) Save(filePath string) error {
	writer, err := base.CreateSerializedClassifierStub(filePath, KNN.GetMetadata())
	if err != nil {
		return err
	}
	// NOTE: a leftover debug print of the writer was removed here; library
	// code should not write to stdout as a side effect of saving.
	return KNN.SaveWithPrefix(writer, "")
}
// SaveWithPrefix outputs KNN as part of another file.
//
// The training instances are written under the key obtained from
// writer.Prefix(prefix, "TrainingInstances"). If that write succeeds the
// writer is closed and the result of Close is returned; if the write fails
// its error is returned and Close is not called.
func (KNN *KNNClassifier) SaveWithPrefix(writer *base.ClassifierSerializer, prefix string) error {
	// NOTE: a leftover debug print of the writer was removed here; library
	// code should not write to stdout as a side effect of saving.
	key := writer.Prefix(prefix, "TrainingInstances")
	if err := writer.WriteInstancesForKey(key, KNN.TrainingData, true); err != nil {
		return err
	}
	return writer.Close()
}
// Load reloads a given KNN classifier when it's the only thing in the output file.
//
// It opens the serialized classifier at filePath and hands the resulting
// deserializer to LoadWithPrefix with an empty prefix.
func (KNN *KNNClassifier) Load(filePath string) error {
	deserializer, openErr := base.ReadSerializedClassifierStub(filePath)
	if openErr != nil {
		return openErr
	}

	return KNN.LoadWithPrefix(deserializer, "")
}
// LoadWithPrefix reloads a given KNN classifier when it's part of another file.
//
// The metadata stored at prefix must name a "KNN" classifier with version
// "1.0"; otherwise an error is returned. Every expected metadata entry is
// checked for presence and type, so a truncated or hand-edited file yields
// an error instead of a panic. The receiver is only modified once the
// metadata and the training instances have all been read successfully.
func (KNN *KNNClassifier) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
	clsMetadata, err := reader.ReadMetadataAtPrefix(prefix)
	if err != nil {
		return err
	}
	if clsMetadata.ClassifierName != "KNN" {
		return fmt.Errorf("This file doesn't contain a KNN classifier")
	}
	if clsMetadata.ClassifierVersion != "1.0" {
		return fmt.Errorf("Can't understand this file format")
	}

	metadata := clsMetadata.ClassifierMetadata

	distanceFunc, ok := metadata["distance_func"].(string)
	if !ok {
		return fmt.Errorf("KNN metadata: %q is missing or not a string", "distance_func")
	}
	algorithm, ok := metadata["algorithm"].(string)
	if !ok {
		return fmt.Errorf("KNN metadata: %q is missing or not a string", "algorithm")
	}
	weighted, ok := metadata["weighted"].(bool)
	if !ok {
		return fmt.Errorf("KNN metadata: %q is missing or not a bool", "weighted")
	}
	allowOptimisations, ok := metadata["allow_optimizations"].(bool)
	if !ok {
		return fmt.Errorf("KNN metadata: %q is missing or not a bool", "allow_optimizations")
	}
	// The neighbour count is saved as an int but is read back as a float64
	// (the original author attributes this to the JSON serialization
	// format), so it is asserted as float64 and converted.
	floatNeighbours, ok := metadata["neighbours"].(float64)
	if !ok {
		return fmt.Errorf("KNN metadata: %q is missing or not a number", "neighbours")
	}

	trainingData, err := reader.GetInstancesForKey(reader.Prefix(prefix, "TrainingInstances"))
	if err != nil {
		return err
	}

	KNN.DistanceFunc = distanceFunc
	KNN.Algorithm = algorithm
	KNN.Weighted = weighted
	KNN.AllowOptimisations = allowOptimisations
	KNN.NearestNeighbours = int(floatNeighbours)
	KNN.TrainingData = trainingData
	return nil
}
// ReloadKNNClassifier reloads a KNNClassifier when it's the only thing in an output file.
//
// A fresh, zero-valued classifier is populated via Load. On failure the
// partially-initialised classifier is discarded and only the error is
// returned.
func ReloadKNNClassifier(filePath string) (*KNNClassifier, error) {
	classifier := new(KNNClassifier)
	if loadErr := classifier.Load(filePath); loadErr != nil {
		return nil, loadErr
	}
	return classifier, nil
}
// A KNNRegressor consists of a data matrix, associated result variables in the same order as the matrix, and a name.
type KNNRegressor struct {
base.BaseEstimator
@ -384,4 +474,4 @@ func (KNN *KNNRegressor) Predict(vector *mat64.Dense, K int) float64 {
average := sum / float64(K)
return average
}
}

View File

@ -5,6 +5,7 @@ import (
"github.com/sjwhitworth/golearn/base"
. "github.com/smartystreets/goconvey/convey"
"fmt"
)
func TestKnnClassifierWithoutOptimisations(t *testing.T) {
@ -38,6 +39,36 @@ func TestKnnClassifierWithoutOptimisations(t *testing.T) {
})
}
// TestKnnSaveAndReload fits a KNN classifier, saves it to "temp.cls",
// reloads it, and checks that the reloaded classifier has the same string
// representation and produces the same predictions as the original.
func TestKnnSaveAndReload(t *testing.T) {
	Convey("Given labels, a classifier and data", t, func() {
		trainingData, trainErr := base.ParseCSVToInstances("knn_train_1.csv", false)
		So(trainErr, ShouldBeNil)

		testingData, testErr := base.ParseCSVToInstances("knn_test_1.csv", false)
		So(testErr, ShouldBeNil)

		cls := NewKnnClassifier("euclidean", "linear", 2)
		cls.AllowOptimisations = false
		cls.Fit(trainingData)

		// Baseline predictions from the in-memory classifier.
		predictions, predictErr := cls.Predict(testingData)
		So(predictErr, ShouldBeNil)
		So(predictions, ShouldNotEqual, nil)

		Convey("So saving the classifier should work...", func() {
			saveErr := cls.Save("temp.cls")
			So(saveErr, ShouldBeNil)

			Convey("So loading the classifier should work...", func() {
				reloaded, reloadErr := ReloadKNNClassifier("temp.cls")
				So(reloadErr, ShouldBeNil)
				So(cls.String(), ShouldEqual, reloaded.String())

				// Compare the predictions through their default formatting.
				reloadedPredictions, repredictErr := reloaded.Predict(testingData)
				So(repredictErr, ShouldBeNil)
				So(fmt.Sprint(reloadedPredictions), ShouldEqual, fmt.Sprint(predictions))
			})
		})
	})
}
func TestKnnClassifierWithOptimisations(t *testing.T) {
Convey("Given labels, a classifier and data", t, func() {
trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false)