mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-28 13:48:56 +08:00
knn: tests now passing
This commit is contained in:
parent
72c2005e70
commit
43f04021af
@ -40,8 +40,20 @@ func (g *GoLearnError) attachFormattedStack() {
|
||||
stackFrames := strings.Split(stackString, "\n")
|
||||
|
||||
stackFmt := make([]string, 0)
|
||||
for i := 3; i < len(stackFrames); i++ {
|
||||
for i := 0; i < len(stackFrames); i++ {
|
||||
if strings.Contains(stackFrames[i], "golearn") {
|
||||
if strings.Contains(stackFrames[i], "golearn/base/error.go") {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(stackFrames[i], "base.WrapError") {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(stackFrames[i], "base.DescribeError") {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(stackFrames[i], "golearn/base.(*GoLearnError).attachFormattedStack") {
|
||||
continue
|
||||
}
|
||||
stackFmt = append(stackFmt, stackFrames[i])
|
||||
}
|
||||
}
|
||||
|
@ -114,7 +114,7 @@ func DeserializeInstancesFromTarReader(tr *FunctionalTarReader, prefix string) (
|
||||
if err != nil {
|
||||
return nil, DescribeError("Class Attribute deserialization error", err)
|
||||
}
|
||||
attrBytes, err = tr.GetNamedFile("ATTRS")
|
||||
attrBytes, err = tr.GetNamedFile(p("ATTRS"))
|
||||
if err != nil {
|
||||
return nil, DescribeError("Unable to read ATTRS", err)
|
||||
}
|
||||
|
92
knn/knn.go
92
knn/knn.go
@ -326,6 +326,96 @@ func (KNN *KNNClassifier) weightedVote(maxmap map[string]float64, values []int,
|
||||
return maxClass
|
||||
}
|
||||
|
||||
// GetMetadata returns required serialization information for this classifier
|
||||
func (KNN *KNNClassifier) GetMetadata() base.ClassifierMetadataV1 {
|
||||
|
||||
classifierParams := make(map[string]interface{})
|
||||
classifierParams["distance_func"] = KNN.DistanceFunc
|
||||
classifierParams["algorithm"] = KNN.Algorithm
|
||||
classifierParams["neighbours"] = KNN.NearestNeighbours
|
||||
classifierParams["weighted"] = KNN.Weighted
|
||||
classifierParams["allow_optimizations"] = KNN.AllowOptimisations
|
||||
|
||||
return base.ClassifierMetadataV1{
|
||||
FormatVersion: 1,
|
||||
ClassifierName: "KNN",
|
||||
ClassifierVersion: "1.0",
|
||||
ClassifierMetadata: classifierParams,
|
||||
}
|
||||
}
|
||||
|
||||
// Save outputs a given KNN classifier.
|
||||
func (KNN *KNNClassifier) Save(filePath string) error {
|
||||
writer, err := base.CreateSerializedClassifierStub(filePath, KNN.GetMetadata())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("writer: %v", writer)
|
||||
return KNN.SaveWithPrefix(writer, "")
|
||||
}
|
||||
|
||||
// SaveWithPrefix outputs KNN as part of another file.
|
||||
func (KNN *KNNClassifier) SaveWithPrefix(writer *base.ClassifierSerializer, prefix string) error {
|
||||
fmt.Printf("writer: %v", writer)
|
||||
err := writer.WriteInstancesForKey(writer.Prefix(prefix, "TrainingInstances"), KNN.TrainingData, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = writer.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
// Load reloads a given KNN classifier when it's the only thing in the output file.
|
||||
func (KNN *KNNClassifier) Load(filePath string) error {
|
||||
reader, err := base.ReadSerializedClassifierStub(filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return KNN.LoadWithPrefix(reader, "")
|
||||
}
|
||||
|
||||
// LoadWithPrefix reloads a given KNN classifier when it's part of another file.
|
||||
func (KNN *KNNClassifier) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
|
||||
|
||||
clsMetadata, err := reader.ReadMetadataAtPrefix(prefix)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if clsMetadata.ClassifierName != "KNN" {
|
||||
return fmt.Errorf("This file doesn't contain a KNN classifier")
|
||||
}
|
||||
if clsMetadata.ClassifierVersion != "1.0" {
|
||||
return fmt.Errorf("Can't understand this file format")
|
||||
}
|
||||
|
||||
metadata := clsMetadata.ClassifierMetadata
|
||||
KNN.DistanceFunc = metadata["distance_func"].(string)
|
||||
KNN.Algorithm = metadata["algorithm"].(string)
|
||||
//KNN.NearestNeighbours = metadata["neighbours"].(int)
|
||||
KNN.Weighted = metadata["weighted"].(bool)
|
||||
KNN.AllowOptimisations = metadata["allow_optimizations"].(bool)
|
||||
|
||||
// 101 on why JSON is a bad serialization format
|
||||
floatNeighbours := metadata["neighbours"].(float64)
|
||||
KNN.NearestNeighbours = int(floatNeighbours)
|
||||
|
||||
KNN.TrainingData, err = reader.GetInstancesForKey(reader.Prefix(prefix, "TrainingInstances"))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ReloadKNNClassifier reloads a KNNClassifier when it's the only thing in an output file.
|
||||
func ReloadKNNClassifier(filePath string) (*KNNClassifier, error) {
|
||||
stub := &KNNClassifier{}
|
||||
err := stub.Load(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return stub, nil
|
||||
}
|
||||
|
||||
// A KNNRegressor consists of a data matrix, associated result variables in the same order as the matrix, and a name.
|
||||
type KNNRegressor struct {
|
||||
base.BaseEstimator
|
||||
@ -384,4 +474,4 @@ func (KNN *KNNRegressor) Predict(vector *mat64.Dense, K int) float64 {
|
||||
|
||||
average := sum / float64(K)
|
||||
return average
|
||||
}
|
||||
}
|
@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func TestKnnClassifierWithoutOptimisations(t *testing.T) {
|
||||
@ -38,6 +39,36 @@ func TestKnnClassifierWithoutOptimisations(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestKnnSaveAndReload(t *testing.T) {
|
||||
Convey("Given labels, a classifier and data", t, func() {
|
||||
trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
cls := NewKnnClassifier("euclidean", "linear", 2)
|
||||
cls.AllowOptimisations = false
|
||||
cls.Fit(trainingData)
|
||||
predictions, err := cls.Predict(testingData)
|
||||
So(err, ShouldBeNil)
|
||||
So(predictions, ShouldNotEqual, nil)
|
||||
|
||||
Convey("So saving the classifier should work...", func(){
|
||||
err := cls.Save("temp.cls")
|
||||
So(err, ShouldBeNil)
|
||||
Convey("So loading the classifier should work...", func(){
|
||||
clsR, err := ReloadKNNClassifier("temp.cls")
|
||||
So(err, ShouldBeNil)
|
||||
So(cls.String(), ShouldEqual, clsR.String())
|
||||
predictionsR, err := clsR.Predict(testingData)
|
||||
So(err, ShouldBeNil)
|
||||
So(fmt.Sprintf("%v", predictionsR), ShouldEqual, fmt.Sprintf("%v", predictions))
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestKnnClassifierWithOptimisations(t *testing.T) {
|
||||
Convey("Given labels, a classifier and data", t, func() {
|
||||
trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false)
|
||||
|
Loading…
x
Reference in New Issue
Block a user