mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-30 13:48:57 +08:00
knn: tests now passing
This commit is contained in:
parent
72c2005e70
commit
43f04021af
@ -40,8 +40,20 @@ func (g *GoLearnError) attachFormattedStack() {
|
|||||||
stackFrames := strings.Split(stackString, "\n")
|
stackFrames := strings.Split(stackString, "\n")
|
||||||
|
|
||||||
stackFmt := make([]string, 0)
|
stackFmt := make([]string, 0)
|
||||||
for i := 3; i < len(stackFrames); i++ {
|
for i := 0; i < len(stackFrames); i++ {
|
||||||
if strings.Contains(stackFrames[i], "golearn") {
|
if strings.Contains(stackFrames[i], "golearn") {
|
||||||
|
if strings.Contains(stackFrames[i], "golearn/base/error.go") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.Contains(stackFrames[i], "base.WrapError") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.Contains(stackFrames[i], "base.DescribeError") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.Contains(stackFrames[i], "golearn/base.(*GoLearnError).attachFormattedStack") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
stackFmt = append(stackFmt, stackFrames[i])
|
stackFmt = append(stackFmt, stackFrames[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -114,7 +114,7 @@ func DeserializeInstancesFromTarReader(tr *FunctionalTarReader, prefix string) (
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, DescribeError("Class Attribute deserialization error", err)
|
return nil, DescribeError("Class Attribute deserialization error", err)
|
||||||
}
|
}
|
||||||
attrBytes, err = tr.GetNamedFile("ATTRS")
|
attrBytes, err = tr.GetNamedFile(p("ATTRS"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, DescribeError("Unable to read ATTRS", err)
|
return nil, DescribeError("Unable to read ATTRS", err)
|
||||||
}
|
}
|
||||||
|
90
knn/knn.go
90
knn/knn.go
@ -326,6 +326,96 @@ func (KNN *KNNClassifier) weightedVote(maxmap map[string]float64, values []int,
|
|||||||
return maxClass
|
return maxClass
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetMetadata returns required serialization information for this classifier
|
||||||
|
func (KNN *KNNClassifier) GetMetadata() base.ClassifierMetadataV1 {
|
||||||
|
|
||||||
|
classifierParams := make(map[string]interface{})
|
||||||
|
classifierParams["distance_func"] = KNN.DistanceFunc
|
||||||
|
classifierParams["algorithm"] = KNN.Algorithm
|
||||||
|
classifierParams["neighbours"] = KNN.NearestNeighbours
|
||||||
|
classifierParams["weighted"] = KNN.Weighted
|
||||||
|
classifierParams["allow_optimizations"] = KNN.AllowOptimisations
|
||||||
|
|
||||||
|
return base.ClassifierMetadataV1{
|
||||||
|
FormatVersion: 1,
|
||||||
|
ClassifierName: "KNN",
|
||||||
|
ClassifierVersion: "1.0",
|
||||||
|
ClassifierMetadata: classifierParams,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save outputs a given KNN classifier.
|
||||||
|
func (KNN *KNNClassifier) Save(filePath string) error {
|
||||||
|
writer, err := base.CreateSerializedClassifierStub(filePath, KNN.GetMetadata())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Printf("writer: %v", writer)
|
||||||
|
return KNN.SaveWithPrefix(writer, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveWithPrefix outputs KNN as part of another file.
|
||||||
|
func (KNN *KNNClassifier) SaveWithPrefix(writer *base.ClassifierSerializer, prefix string) error {
|
||||||
|
fmt.Printf("writer: %v", writer)
|
||||||
|
err := writer.WriteInstancesForKey(writer.Prefix(prefix, "TrainingInstances"), KNN.TrainingData, true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = writer.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load reloads a given KNN classifier when it's the only thing in the output file.
|
||||||
|
func (KNN *KNNClassifier) Load(filePath string) error {
|
||||||
|
reader, err := base.ReadSerializedClassifierStub(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return KNN.LoadWithPrefix(reader, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadWithPrefix reloads a given KNN classifier when it's part of another file.
|
||||||
|
func (KNN *KNNClassifier) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
|
||||||
|
|
||||||
|
clsMetadata, err := reader.ReadMetadataAtPrefix(prefix)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if clsMetadata.ClassifierName != "KNN" {
|
||||||
|
return fmt.Errorf("This file doesn't contain a KNN classifier")
|
||||||
|
}
|
||||||
|
if clsMetadata.ClassifierVersion != "1.0" {
|
||||||
|
return fmt.Errorf("Can't understand this file format")
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata := clsMetadata.ClassifierMetadata
|
||||||
|
KNN.DistanceFunc = metadata["distance_func"].(string)
|
||||||
|
KNN.Algorithm = metadata["algorithm"].(string)
|
||||||
|
//KNN.NearestNeighbours = metadata["neighbours"].(int)
|
||||||
|
KNN.Weighted = metadata["weighted"].(bool)
|
||||||
|
KNN.AllowOptimisations = metadata["allow_optimizations"].(bool)
|
||||||
|
|
||||||
|
// 101 on why JSON is a bad serialization format
|
||||||
|
floatNeighbours := metadata["neighbours"].(float64)
|
||||||
|
KNN.NearestNeighbours = int(floatNeighbours)
|
||||||
|
|
||||||
|
KNN.TrainingData, err = reader.GetInstancesForKey(reader.Prefix(prefix, "TrainingInstances"))
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReloadKNNClassifier reloads a KNNClassifier when it's the only thing in an output file.
|
||||||
|
func ReloadKNNClassifier(filePath string) (*KNNClassifier, error) {
|
||||||
|
stub := &KNNClassifier{}
|
||||||
|
err := stub.Load(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return stub, nil
|
||||||
|
}
|
||||||
|
|
||||||
// A KNNRegressor consists of a data matrix, associated result variables in the same order as the matrix, and a name.
|
// A KNNRegressor consists of a data matrix, associated result variables in the same order as the matrix, and a name.
|
||||||
type KNNRegressor struct {
|
type KNNRegressor struct {
|
||||||
base.BaseEstimator
|
base.BaseEstimator
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
|
|
||||||
"github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
. "github.com/smartystreets/goconvey/convey"
|
. "github.com/smartystreets/goconvey/convey"
|
||||||
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestKnnClassifierWithoutOptimisations(t *testing.T) {
|
func TestKnnClassifierWithoutOptimisations(t *testing.T) {
|
||||||
@ -38,6 +39,36 @@ func TestKnnClassifierWithoutOptimisations(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestKnnSaveAndReload(t *testing.T) {
|
||||||
|
Convey("Given labels, a classifier and data", t, func() {
|
||||||
|
trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
|
||||||
|
testingData, err := base.ParseCSVToInstances("knn_test_1.csv", false)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
|
||||||
|
cls := NewKnnClassifier("euclidean", "linear", 2)
|
||||||
|
cls.AllowOptimisations = false
|
||||||
|
cls.Fit(trainingData)
|
||||||
|
predictions, err := cls.Predict(testingData)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(predictions, ShouldNotEqual, nil)
|
||||||
|
|
||||||
|
Convey("So saving the classifier should work...", func(){
|
||||||
|
err := cls.Save("temp.cls")
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
Convey("So loading the classifier should work...", func(){
|
||||||
|
clsR, err := ReloadKNNClassifier("temp.cls")
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(cls.String(), ShouldEqual, clsR.String())
|
||||||
|
predictionsR, err := clsR.Predict(testingData)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(fmt.Sprintf("%v", predictionsR), ShouldEqual, fmt.Sprintf("%v", predictions))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestKnnClassifierWithOptimisations(t *testing.T) {
|
func TestKnnClassifierWithOptimisations(t *testing.T) {
|
||||||
Convey("Given labels, a classifier and data", t, func() {
|
Convey("Given labels, a classifier and data", t, func() {
|
||||||
trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false)
|
trainingData, err := base.ParseCSVToInstances("knn_train_1.csv", false)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user