1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-25 13:48:49 +08:00
golearn/knn/knn_bench_test.go
Richard Townsend 527c6476e1 Optimised version of KNN for Euclidean distances
This patch also:
   * Completes removal of the edf/ package
   * Corrects an erroneous print statement
   * Introduces two new CSV functions
      * ParseCSVToTemplatedInstances makes sure that
        reading a second CSV file maintains strict Attribute
        compatibility with an existing DenseInstances
      * ParseCSVToInstancesWithAttributeGroups gives more control
        over where Attributes end up in memory, important for
        gaining predictable control over the KNN optimisation
   * Decouples BinaryAttributeGroup from FixedAttributeGroup for
     better casting support
2014-09-30 23:10:22 +01:00

71 lines
1.6 KiB
Go

package knn
import (
"fmt"
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/evaluation"
"testing"
)
// readMnist parses the MNIST training and test CSV files from the examples
// directory and returns them as (train, test).
//
// A CategoricalAttribute named "label" is supplied as the class Attribute
// (keyed by 0 — presumably the column index; confirm against
// ParseCSVToInstancesWithAttributeGroups) and is assigned its own attribute
// group, "ClassGroup". The map for the remaining Attributes is left empty, so
// they fall into the default group. The test file is parsed using the training
// instances as a template so the two sets keep compatible Attributes.
// Any parse error causes a panic.
func readMnist() (*base.DenseInstances, *base.DenseInstances) {
	const (
		trainPath = "../examples/datasets/mnist_train.csv"
		testPath  = "../examples/datasets/mnist_test.csv"
	)

	label := base.NewCategoricalAttribute()
	label.SetName("label")
	classAttrs := map[int]base.Attribute{0: label}

	// Keep the class Attribute in a group of its own.
	classGroups := map[string]string{"label": "ClassGroup"}
	// Intentionally empty: every other Attribute goes to the default group.
	otherGroups := map[string]string{}

	train, err := base.ParseCSVToInstancesWithAttributeGroups(trainPath, otherGroups, classGroups, classAttrs, true)
	if err != nil {
		panic(err)
	}

	test, err := base.ParseCSVToTemplatedInstances(testPath, true, train)
	if err != nil {
		panic(err)
	}
	return train, test
}
// BenchmarkKNNWithOpts measures 1-nearest-neighbour prediction ("euclidean"
// distance) over the MNIST test set with AllowOptimisations enabled.
//
// Parsing the CSV files and fitting the classifier happen before the timer is
// reset, so ns/op reflects Predict alone rather than dataset I/O. After the
// timed region, the confusion-matrix summary and accuracy of the last
// prediction run are printed as a sanity check on the result.
func BenchmarkKNNWithOpts(b *testing.B) {
	train, test := readMnist()
	cls := NewKnnClassifier("euclidean", 1)
	cls.AllowOptimisations = true
	cls.Fit(train)

	// Exclude dataset parsing and Fit from the measurement.
	b.ResetTimer()
	// The first iteration is peeled off so predictions keeps its inferred
	// type for the evaluation below; b.N is always at least 1.
	predictions := cls.Predict(test)
	for i := 1; i < b.N; i++ {
		predictions = cls.Predict(test)
	}
	// Evaluation and printing are reporting, not part of what is measured.
	b.StopTimer()

	c, err := evaluation.GetConfusionMatrix(test, predictions)
	if err != nil {
		b.Fatal(err)
	}
	fmt.Println(evaluation.GetSummary(c))
	fmt.Println(evaluation.GetAccuracy(c))
}
// BenchmarkKNNWithNoOpts measures 1-nearest-neighbour prediction ("euclidean"
// distance) over the MNIST test set with AllowOptimisations disabled. It is the
// baseline for comparison with BenchmarkKNNWithOpts.
//
// Parsing the CSV files and fitting the classifier happen before the timer is
// reset, so ns/op reflects Predict alone rather than dataset I/O. After the
// timed region, the confusion-matrix summary and accuracy of the last
// prediction run are printed as a sanity check on the result.
func BenchmarkKNNWithNoOpts(b *testing.B) {
	train, test := readMnist()
	cls := NewKnnClassifier("euclidean", 1)
	cls.AllowOptimisations = false
	cls.Fit(train)

	// Exclude dataset parsing and Fit from the measurement.
	b.ResetTimer()
	// The first iteration is peeled off so predictions keeps its inferred
	// type for the evaluation below; b.N is always at least 1.
	predictions := cls.Predict(test)
	for i := 1; i < b.N; i++ {
		predictions = cls.Predict(test)
	}
	// Evaluation and printing are reporting, not part of what is measured.
	b.StopTimer()

	c, err := evaluation.GetConfusionMatrix(test, predictions)
	if err != nil {
		b.Fatal(err)
	}
	fmt.Println(evaluation.GetSummary(c))
	fmt.Println(evaluation.GetAccuracy(c))
}