mirror of https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
477 lines · 13 KiB · Go
// Package knn implements a K Nearest Neighbors object, capable of both classification
|
|
// and regression. It accepts data in the form of a slice of float64s, which are then reshaped
|
|
// into a X by Y matrix.
|
|
package knn
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
|
|
"github.com/gonum/matrix"
|
|
"github.com/gonum/matrix/mat64"
|
|
"github.com/sjwhitworth/golearn/base"
|
|
"github.com/sjwhitworth/golearn/kdtree"
|
|
"github.com/sjwhitworth/golearn/metrics/pairwise"
|
|
"github.com/sjwhitworth/golearn/utilities"
|
|
)
|
|
|
|
// A KNNClassifier consists of a data matrix, associated labels in the same order as the matrix, searching algorithm, and a distance function.
// The accepted distance functions at this time are 'euclidean', 'manhattan', and 'cosine'.
// The accepted searching algorithms here are 'linear' and 'kdtree'.
// Optimisations only occur when things are identically grouped into identical
// AttributeGroups, which don't include the class variable, in the same order.
// Weighted KNN is used when Weighted is set to true (default: false).
type KNNClassifier struct {
	base.BaseEstimator
	// TrainingData is stored verbatim by Fit and scanned row-by-row during Predict.
	TrainingData base.FixedDataGrid
	// DistanceFunc is one of "euclidean", "manhattan" or "cosine".
	DistanceFunc string
	// Algorithm is one of "linear" or "kdtree".
	Algorithm string
	// NearestNeighbours is the K in KNN.
	NearestNeighbours int
	// AllowOptimisations permits the optimised euclidean path when layouts match.
	AllowOptimisations bool
	// Weighted weights votes by inverse distance when true.
	Weighted bool
}
|
|
|
|
// NewKnnClassifier returns a new classifier
|
|
func NewKnnClassifier(distfunc, algorithm string, neighbours int) *KNNClassifier {
|
|
KNN := KNNClassifier{}
|
|
KNN.DistanceFunc = distfunc
|
|
KNN.Algorithm = algorithm
|
|
KNN.NearestNeighbours = neighbours
|
|
KNN.Weighted = false
|
|
KNN.AllowOptimisations = true
|
|
return &KNN
|
|
}
|
|
|
|
// Fit stores the training data for later use by Predict.
// It performs no computation up front (KNN is a lazy learner) and never fails.
func (KNN *KNNClassifier) Fit(trainingData base.FixedDataGrid) error {
	KNN.TrainingData = trainingData
	return nil
}
|
|
|
|
func (KNN *KNNClassifier) canUseOptimisations(what base.FixedDataGrid) bool {
|
|
// Check that the two have exactly the same layout
|
|
if !base.CheckStrictlyCompatible(what, KNN.TrainingData) {
|
|
return false
|
|
}
|
|
// Check that the two are DenseInstances
|
|
whatd, ok1 := what.(*base.DenseInstances)
|
|
_, ok2 := KNN.TrainingData.(*base.DenseInstances)
|
|
if !ok1 || !ok2 {
|
|
return false
|
|
}
|
|
// Check that no Class Attributes are mixed in with the data
|
|
classAttrs := whatd.AllClassAttributes()
|
|
normalAttrs := base.NonClassAttributes(whatd)
|
|
// Retrieve all the AGs
|
|
ags := whatd.AllAttributeGroups()
|
|
classAttrGroups := make([]base.AttributeGroup, 0)
|
|
for agName := range ags {
|
|
ag := ags[agName]
|
|
attrs := ag.Attributes()
|
|
matched := false
|
|
for _, a := range attrs {
|
|
for _, c := range classAttrs {
|
|
if a.Equals(c) {
|
|
matched = true
|
|
}
|
|
}
|
|
}
|
|
if matched {
|
|
classAttrGroups = append(classAttrGroups, ag)
|
|
}
|
|
}
|
|
for _, cag := range classAttrGroups {
|
|
attrs := cag.Attributes()
|
|
common := base.AttributeIntersect(normalAttrs, attrs)
|
|
if len(common) != 0 {
|
|
return false
|
|
}
|
|
}
|
|
|
|
// Check that all of the Attributes are numeric
|
|
for _, a := range normalAttrs {
|
|
if _, ok := a.(*base.FloatAttribute); !ok {
|
|
return false
|
|
}
|
|
}
|
|
// If that's fine, return true
|
|
return true
|
|
}
|
|
|
|
// Predict returns a classification for the vector, based on a vector input, using the KNN algorithm.
|
|
func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) (base.FixedDataGrid, error) {
|
|
// Check what distance function we are using
|
|
var distanceFunc pairwise.PairwiseDistanceFunc
|
|
switch KNN.DistanceFunc {
|
|
case "euclidean":
|
|
distanceFunc = pairwise.NewEuclidean()
|
|
case "manhattan":
|
|
distanceFunc = pairwise.NewManhattan()
|
|
case "cosine":
|
|
distanceFunc = pairwise.NewCosine()
|
|
default:
|
|
return nil, errors.New("unsupported distance function")
|
|
}
|
|
|
|
// Check what searching algorith, we are using
|
|
if KNN.Algorithm != "linear" && KNN.Algorithm != "kdtree" {
|
|
return nil, errors.New("unsupported searching algorithm")
|
|
}
|
|
|
|
// Check Compatibility
|
|
allAttrs := base.CheckCompatible(what, KNN.TrainingData)
|
|
if allAttrs == nil {
|
|
// Don't have the same Attributes
|
|
return nil, errors.New("attributes not compatible")
|
|
}
|
|
|
|
// Use optimised version if permitted
|
|
if KNN.Algorithm == "linear" && KNN.AllowOptimisations {
|
|
if KNN.DistanceFunc == "euclidean" {
|
|
if KNN.canUseOptimisations(what) {
|
|
return KNN.optimisedEuclideanPredict(what.(*base.DenseInstances)), nil
|
|
}
|
|
}
|
|
}
|
|
fmt.Println("Optimisations are switched off")
|
|
|
|
// Remove the Attributes which aren't numeric
|
|
allNumericAttrs := make([]base.Attribute, 0)
|
|
for _, a := range allAttrs {
|
|
if fAttr, ok := a.(*base.FloatAttribute); ok {
|
|
allNumericAttrs = append(allNumericAttrs, fAttr)
|
|
}
|
|
}
|
|
|
|
// If every Attribute is a FloatAttribute, then we remove the last one
|
|
// because that is the Attribute we are trying to predict.
|
|
if len(allNumericAttrs) == len(allAttrs) {
|
|
allNumericAttrs = allNumericAttrs[:len(allNumericAttrs)-1]
|
|
}
|
|
|
|
// Generate return vector
|
|
ret := base.GeneratePredictionVector(what)
|
|
|
|
// Resolve Attribute specifications for both
|
|
whatAttrSpecs := base.ResolveAttributes(what, allNumericAttrs)
|
|
trainAttrSpecs := base.ResolveAttributes(KNN.TrainingData, allNumericAttrs)
|
|
|
|
// Reserve storage for most the most similar items
|
|
distances := make(map[int]float64)
|
|
|
|
// Reserve storage for voting map
|
|
maxmapInt := make(map[string]int)
|
|
maxmapFloat := make(map[string]float64)
|
|
|
|
// Reserve storage for row computations
|
|
trainRowBuf := make([]float64, len(allNumericAttrs))
|
|
predRowBuf := make([]float64, len(allNumericAttrs))
|
|
|
|
_, maxRow := what.Size()
|
|
curRow := 0
|
|
|
|
// build kdtree if algorithm is 'kdtree'
|
|
kd := kdtree.New()
|
|
srcRowNoMap := make([]int, 0)
|
|
if KNN.Algorithm == "kdtree" {
|
|
buildData := make([][]float64, 0)
|
|
KNN.TrainingData.MapOverRows(trainAttrSpecs, func(trainRow [][]byte, srcRowNo int) (bool, error) {
|
|
oneData := make([]float64, len(allNumericAttrs))
|
|
// Read the float values out
|
|
for i, _ := range allNumericAttrs {
|
|
oneData[i] = base.UnpackBytesToFloat(trainRow[i])
|
|
}
|
|
srcRowNoMap = append(srcRowNoMap, srcRowNo)
|
|
buildData = append(buildData, oneData)
|
|
return true, nil
|
|
})
|
|
|
|
err := kd.Build(buildData)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// Iterate over all outer rows
|
|
what.MapOverRows(whatAttrSpecs, func(predRow [][]byte, predRowNo int) (bool, error) {
|
|
|
|
if (curRow%1) == 0 && curRow > 0 {
|
|
fmt.Printf("KNN: %.2f %% done\r", float64(curRow)*100.0/float64(maxRow))
|
|
}
|
|
curRow++
|
|
|
|
// Read the float values out
|
|
for i, _ := range allNumericAttrs {
|
|
predRowBuf[i] = base.UnpackBytesToFloat(predRow[i])
|
|
}
|
|
|
|
predMat := utilities.FloatsToMatrix(predRowBuf)
|
|
|
|
switch KNN.Algorithm {
|
|
case "linear":
|
|
// Find the closest match in the training data
|
|
KNN.TrainingData.MapOverRows(trainAttrSpecs, func(trainRow [][]byte, srcRowNo int) (bool, error) {
|
|
// Read the float values out
|
|
for i, _ := range allNumericAttrs {
|
|
trainRowBuf[i] = base.UnpackBytesToFloat(trainRow[i])
|
|
}
|
|
|
|
// Compute the distance
|
|
trainMat := utilities.FloatsToMatrix(trainRowBuf)
|
|
distances[srcRowNo] = distanceFunc.Distance(predMat, trainMat)
|
|
return true, nil
|
|
})
|
|
|
|
sorted := utilities.SortIntMap(distances)
|
|
values := sorted[:KNN.NearestNeighbours]
|
|
|
|
length := make([]float64, KNN.NearestNeighbours)
|
|
for k, v := range values {
|
|
length[k] = distances[v]
|
|
}
|
|
|
|
var maxClass string
|
|
if KNN.Weighted {
|
|
maxClass = KNN.weightedVote(maxmapFloat, values, length)
|
|
} else {
|
|
maxClass = KNN.vote(maxmapInt, values)
|
|
}
|
|
base.SetClass(ret, predRowNo, maxClass)
|
|
|
|
case "kdtree":
|
|
// search kdtree
|
|
values, length, err := kd.Search(KNN.NearestNeighbours, distanceFunc, predRowBuf)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
// map values to srcRowNo
|
|
for k, v := range values {
|
|
values[k] = srcRowNoMap[v]
|
|
}
|
|
|
|
var maxClass string
|
|
if KNN.Weighted {
|
|
maxClass = KNN.weightedVote(maxmapFloat, values, length)
|
|
} else {
|
|
maxClass = KNN.vote(maxmapInt, values)
|
|
}
|
|
base.SetClass(ret, predRowNo, maxClass)
|
|
}
|
|
|
|
return true, nil
|
|
|
|
})
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
// String returns a human-readable summary of the classifier's
// distance function and neighbour count.
func (KNN *KNNClassifier) String() string {
	return fmt.Sprintf("KNNClassifier(%s, %d)", KNN.DistanceFunc, KNN.NearestNeighbours)
}
|
|
|
|
func (KNN *KNNClassifier) vote(maxmap map[string]int, values []int) string {
|
|
// Reset maxMap
|
|
for a := range maxmap {
|
|
maxmap[a] = 0
|
|
}
|
|
|
|
// Refresh maxMap
|
|
for _, elem := range values {
|
|
label := base.GetClass(KNN.TrainingData, elem)
|
|
if _, ok := maxmap[label]; ok {
|
|
maxmap[label]++
|
|
} else {
|
|
maxmap[label] = 1
|
|
}
|
|
}
|
|
|
|
// Sort the maxMap
|
|
var maxClass string
|
|
maxVal := -1
|
|
for a := range maxmap {
|
|
if maxmap[a] > maxVal {
|
|
maxVal = maxmap[a]
|
|
maxClass = a
|
|
}
|
|
}
|
|
return maxClass
|
|
}
|
|
|
|
func (KNN *KNNClassifier) weightedVote(maxmap map[string]float64, values []int, length []float64) string {
|
|
// Reset maxMap
|
|
for a := range maxmap {
|
|
maxmap[a] = 0
|
|
}
|
|
|
|
// Refresh maxMap
|
|
for k, elem := range values {
|
|
label := base.GetClass(KNN.TrainingData, elem)
|
|
if _, ok := maxmap[label]; ok {
|
|
maxmap[label] += (1 / length[k])
|
|
} else {
|
|
maxmap[label] = (1 / length[k])
|
|
}
|
|
}
|
|
|
|
// Sort the maxMap
|
|
var maxClass string
|
|
maxVal := -1.0
|
|
for a := range maxmap {
|
|
if maxmap[a] > maxVal {
|
|
maxVal = maxmap[a]
|
|
maxClass = a
|
|
}
|
|
}
|
|
return maxClass
|
|
}
|
|
|
|
// GetMetadata returns required serialization information for this classifier
|
|
func (KNN *KNNClassifier) GetMetadata() base.ClassifierMetadataV1 {
|
|
|
|
classifierParams := make(map[string]interface{})
|
|
classifierParams["distance_func"] = KNN.DistanceFunc
|
|
classifierParams["algorithm"] = KNN.Algorithm
|
|
classifierParams["neighbours"] = KNN.NearestNeighbours
|
|
classifierParams["weighted"] = KNN.Weighted
|
|
classifierParams["allow_optimizations"] = KNN.AllowOptimisations
|
|
|
|
return base.ClassifierMetadataV1{
|
|
FormatVersion: 1,
|
|
ClassifierName: "KNN",
|
|
ClassifierVersion: "1.0",
|
|
ClassifierMetadata: classifierParams,
|
|
}
|
|
}
|
|
|
|
// Save outputs a given KNN classifier.
|
|
func (KNN *KNNClassifier) Save(filePath string) error {
|
|
writer, err := base.CreateSerializedClassifierStub(filePath, KNN.GetMetadata())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
fmt.Printf("writer: %v", writer)
|
|
return KNN.SaveWithPrefix(writer, "")
|
|
}
|
|
|
|
// SaveWithPrefix outputs KNN as part of another file.
|
|
func (KNN *KNNClassifier) SaveWithPrefix(writer *base.ClassifierSerializer, prefix string) error {
|
|
err := writer.WriteInstancesForKey(writer.Prefix(prefix, "TrainingInstances"), KNN.TrainingData, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = writer.Close()
|
|
return err
|
|
}
|
|
|
|
// Load reloads a given KNN classifier when it's the only thing in the output file.
|
|
func (KNN *KNNClassifier) Load(filePath string) error {
|
|
reader, err := base.ReadSerializedClassifierStub(filePath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return KNN.LoadWithPrefix(reader, "")
|
|
}
|
|
|
|
// LoadWithPrefix reloads a given KNN classifier when it's part of another file.
|
|
func (KNN *KNNClassifier) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
|
|
|
|
clsMetadata, err := reader.ReadMetadataAtPrefix(prefix)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if clsMetadata.ClassifierName != "KNN" {
|
|
return fmt.Errorf("This file doesn't contain a KNN classifier")
|
|
}
|
|
if clsMetadata.ClassifierVersion != "1.0" {
|
|
return fmt.Errorf("Can't understand this file format")
|
|
}
|
|
|
|
metadata := clsMetadata.ClassifierMetadata
|
|
KNN.DistanceFunc = metadata["distance_func"].(string)
|
|
KNN.Algorithm = metadata["algorithm"].(string)
|
|
//KNN.NearestNeighbours = metadata["neighbours"].(int)
|
|
KNN.Weighted = metadata["weighted"].(bool)
|
|
KNN.AllowOptimisations = metadata["allow_optimizations"].(bool)
|
|
|
|
// 101 on why JSON is a bad serialization format
|
|
floatNeighbours := metadata["neighbours"].(float64)
|
|
KNN.NearestNeighbours = int(floatNeighbours)
|
|
|
|
KNN.TrainingData, err = reader.GetInstancesForKey(reader.Prefix(prefix, "TrainingInstances"))
|
|
|
|
return err
|
|
}
|
|
|
|
// ReloadKNNClassifier reloads a KNNClassifier when it's the only thing in an output file.
|
|
func ReloadKNNClassifier(filePath string) (*KNNClassifier, error) {
|
|
stub := &KNNClassifier{}
|
|
err := stub.Load(filePath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return stub, nil
|
|
}
|
|
|
|
// A KNNRegressor consists of a data matrix, associated result variables in the same order as the matrix, and a name.
type KNNRegressor struct {
	base.BaseEstimator
	// Values holds the target value for each row of the data matrix,
	// in the same order as the rows.
	Values []float64
	// DistanceFunc selects the metric: "euclidean" or "manhattan".
	DistanceFunc string
}
|
|
|
|
// NewKnnRegressor mints a new classifier.
|
|
func NewKnnRegressor(distfunc string) *KNNRegressor {
|
|
KNN := KNNRegressor{}
|
|
KNN.DistanceFunc = distfunc
|
|
return &KNN
|
|
}
|
|
|
|
func (KNN *KNNRegressor) Fit(values []float64, numbers []float64, rows int, cols int) {
|
|
if rows != len(values) {
|
|
panic(matrix.ErrShape)
|
|
}
|
|
|
|
KNN.Data = mat64.NewDense(rows, cols, numbers)
|
|
KNN.Values = values
|
|
}
|
|
|
|
func (KNN *KNNRegressor) Predict(vector *mat64.Dense, K int) float64 {
|
|
// Get the number of rows
|
|
rows, _ := KNN.Data.Dims()
|
|
rownumbers := make(map[int]float64)
|
|
labels := make([]float64, 0)
|
|
|
|
// Check what distance function we are using
|
|
var distanceFunc pairwise.PairwiseDistanceFunc
|
|
switch KNN.DistanceFunc {
|
|
case "euclidean":
|
|
distanceFunc = pairwise.NewEuclidean()
|
|
case "manhattan":
|
|
distanceFunc = pairwise.NewManhattan()
|
|
default:
|
|
panic("unsupported distance function")
|
|
}
|
|
|
|
for i := 0; i < rows; i++ {
|
|
row := KNN.Data.RowView(i)
|
|
distance := distanceFunc.Distance(utilities.VectorToMatrix(row), vector)
|
|
rownumbers[i] = distance
|
|
}
|
|
|
|
sorted := utilities.SortIntMap(rownumbers)
|
|
values := sorted[:K]
|
|
|
|
var sum float64
|
|
for _, elem := range values {
|
|
value := KNN.Values[elem]
|
|
labels = append(labels, value)
|
|
sum += value
|
|
}
|
|
|
|
average := sum / float64(K)
|
|
return average
|
|
}
|