mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-25 13:48:49 +08:00
78 lines
1.7 KiB
Go
78 lines
1.7 KiB
Go
package knn
|
|
|
|
import (
|
|
"fmt"
|
|
"testing"
|
|
|
|
"github.com/sjwhitworth/golearn/base"
|
|
"github.com/sjwhitworth/golearn/evaluation"
|
|
)
|
|
|
|
func readMnist() (*base.DenseInstances, *base.DenseInstances) {
|
|
// Create the class Attribute
|
|
classAttrs := make(map[int]base.Attribute)
|
|
classAttrs[0] = base.NewCategoricalAttribute()
|
|
classAttrs[0].SetName("label")
|
|
// Setup the class Attribute to be in its own group
|
|
classAttrGroups := make(map[string]string)
|
|
classAttrGroups["label"] = "ClassGroup"
|
|
// The rest can go in a default group
|
|
attrGroups := make(map[string]string)
|
|
|
|
inst1, err := base.ParseCSVToInstancesWithAttributeGroups(
|
|
"../examples/datasets/mnist_train.csv",
|
|
attrGroups,
|
|
classAttrGroups,
|
|
classAttrs,
|
|
true,
|
|
)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
inst2, err := base.ParseCSVToTemplatedInstances(
|
|
"../examples/datasets/mnist_test.csv",
|
|
true,
|
|
inst1,
|
|
)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return inst1, inst2
|
|
}
|
|
|
|
func BenchmarkKNNWithOpts(b *testing.B) {
|
|
// Load
|
|
train, test := readMnist()
|
|
cls := NewKnnClassifier("euclidean", "linear", 1)
|
|
cls.AllowOptimisations = true
|
|
cls.Fit(train)
|
|
predictions, err := cls.Predict(test)
|
|
if err != nil {
|
|
b.Error(err)
|
|
}
|
|
c, err := evaluation.GetConfusionMatrix(test, predictions)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
fmt.Println(evaluation.GetSummary(c))
|
|
fmt.Println(evaluation.GetAccuracy(c))
|
|
}
|
|
|
|
func BenchmarkKNNWithNoOpts(b *testing.B) {
|
|
// Load
|
|
train, test := readMnist()
|
|
cls := NewKnnClassifier("euclidean", "linear", 1)
|
|
cls.AllowOptimisations = false
|
|
cls.Fit(train)
|
|
predictions, err := cls.Predict(test)
|
|
if err != nil {
|
|
b.Error(err)
|
|
}
|
|
c, err := evaluation.GetConfusionMatrix(test, predictions)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
fmt.Println(evaluation.GetSummary(c))
|
|
fmt.Println(evaluation.GetAccuracy(c))
|
|
}
|