1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-25 13:48:49 +08:00
golearn/knn/knn_opt_euclidean.go
Thomas Boudalier e219301900 Adapt for go 1.12 - Issue #225
Use now mandatory way to expose C structs with CGO
2019-03-22 10:57:42 +01:00

87 lines
2.3 KiB
Go

package knn
// #include "knn.h"
import "C"
import (
"github.com/sjwhitworth/golearn/base"
"sort"
"unsafe"
)
// Go-side views of the C distance record declared in knn.h.
type (
	// dist is a Go named type over C.struct_dist.
	dist C.struct_dist

	// distanceRecs is a slice of C distance records. It implements
	// sort.Interface, ordering records by ascending dist field.
	distanceRecs []C.struct_dist
)

// Len reports how many records the slice holds.
func (recs distanceRecs) Len() int {
	return len(recs)
}

// Swap exchanges the records at positions a and b.
func (recs distanceRecs) Swap(a, b int) {
	recs[a], recs[b] = recs[b], recs[a]
}

// Less reports whether record a has a strictly smaller dist than record b,
// so sorting places the nearest records first.
func (recs distanceRecs) Less(a, b int) bool {
	return recs[a].dist < recs[b].dist
}
// optimisedEuclideanPredict predicts a class for every row of d.
//
// For each prediction row it works as follows:
//   - The C routine euclidean_distance accumulates distances to every
//     training row, once per feature AttributeGroup.
//   - The records are sorted nearest-first.
//   - KNN.vote chooses the class from the k nearest records.
//
// A feature AttributeGroup is one that contains no class Attributes and
// exists in both the training data and d. Each group's storage is handed to
// C as a *C.double. NOTE(review): this assumes those groups hold contiguous
// float64 values; confirm callers only dispatch here for all-float data.
//
// Panics:
//   - KNN.TrainingData is not a *base.DenseInstances (type assertion).
//   - The training data has no rows.
//
// If KNN.NearestNeighbours exceeds the number of training rows, every
// training row votes.
func (KNN *KNNClassifier) optimisedEuclideanPredict(d *base.DenseInstances) base.FixedDataGrid {
	// Create return vector
	ret := base.GeneratePredictionVector(d)

	// Type-assert training data
	tr := KNN.TrainingData.(*base.DenseInstances)

	_, predRows := d.Size()
	if predRows == 0 {
		// Nothing to classify. Returning here also avoids indexing empty
		// storage below.
		return ret
	}
	_, trainRows := tr.Size()
	if trainRows == 0 {
		panic("knn: cannot predict without training rows")
	}

	agTrain := tr.AllAttributeGroups()
	agPred := d.AllAttributeGroups()
	classAttrs := tr.AllClassAttributes()

	// Select feature groups by name. Names are sorted so that distances are
	// always accumulated in the same order. Go map iteration order is
	// randomised and floating-point addition is not associative, so an
	// unordered walk could give slightly different distances (and
	// tie-breaks) from row to row.
	names := make([]string, 0, len(agTrain))
	for name, ag := range agTrain {
		if _, ok := agPred[name]; !ok {
			continue
		}
		if len(base.AttributeIntersect(classAttrs, ag.Attributes())) != 0 {
			continue
		}
		names = append(names, name)
	}
	sort.Strings(names)

	// Resolve the storage pointers and column counts once, rather than on
	// every prediction row. Appending to a slice (instead of indexing
	// arrays sized by len(agPred)) cannot go out of range when the two
	// datasets have different numbers of groups.
	type featureGroup struct {
		train, pred *C.double
		cols        C.int
	}
	groups := make([]featureGroup, 0, len(names))
	for _, name := range names {
		cols := len(agPred[name].Attributes())
		if cols == 0 {
			// An empty group contributes nothing to the distance and may
			// have no storage to point at.
			continue
		}
		groups = append(groups, featureGroup{
			train: (*C.double)(unsafe.Pointer(&agTrain[name].Storage()[0])),
			pred:  (*C.double)(unsafe.Pointer(&agPred[name].Storage()[0])),
			cols:  C.int(cols),
		})
	}

	k := KNN.NearestNeighbours
	if k > trainRows {
		// Vote among all training rows instead of slicing past the end of
		// distanceVec.
		k = trainRows
	}

	// Scratch buffers, reused across prediction rows.
	distanceVec := make(distanceRecs, trainRows)
	voteVec := make([]int, k)
	maxMap := make(map[string]int)

	for row := 0; row < predRows; row++ {
		// euclidean_distance is called once per group and the results are
		// combined in .dist, so clear it before each prediction row.
		for i := range distanceVec {
			distanceVec[i].dist = 0
		}
		for _, g := range groups {
			C.euclidean_distance(
				&distanceVec[0],
				C.int(trainRows),
				g.cols,
				C.int(row),
				g.train,
				g.pred,
			)
		}
		sort.Sort(distanceVec)
		// .p presumably holds the training-row index written by
		// euclidean_distance. TODO confirm against knn.c.
		for i, v := range distanceVec[:k] {
			voteVec[i] = int(v.p)
		}
		base.SetClass(ret, row, KNN.vote(maxMap, voteVec))
	}
	return ret
}