mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-28 13:48:56 +08:00
Refactored KNN to implement the estimator interface
This commit is contained in:
parent
3eaeafd0dc
commit
1ade0afca6
@ -15,7 +15,6 @@ import (
|
||||
// An object that can ingest some data and train on it.
|
||||
type Estimator interface {
|
||||
Fit()
|
||||
Summarise()
|
||||
}
|
||||
|
||||
// An object that provides predictions.
|
||||
|
@ -3,7 +3,6 @@ package main
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
mat64 "github.com/gonum/matrix/mat64"
|
||||
data "github.com/sjwhitworth/golearn/data"
|
||||
knn "github.com/sjwhitworth/golearn/knn"
|
||||
util "github.com/sjwhitworth/golearn/utilities"
|
||||
@ -14,17 +13,15 @@ func main() {
|
||||
cols, rows, _, labels, data := data.ParseCsv("datasets/iris.csv", 4, []int{0, 1, 2})
|
||||
|
||||
//Initialises a new KNN classifier
|
||||
cls := knn.NewKnnClassifier(labels, data, rows, cols, "euclidean")
|
||||
cls := knn.NewKnnClassifier("euclidean")
|
||||
cls.Fit(labels, data, rows, cols)
|
||||
|
||||
for {
|
||||
//Creates a random array of N float64s between 0 and 7
|
||||
randArray := util.RandomArray(3, 7)
|
||||
|
||||
//Initialises a vector with this array
|
||||
random := mat64.NewDense(1, 3, randArray)
|
||||
|
||||
//Calculates the Euclidean distance and returns the most popular label
|
||||
labels := cls.Predict(random, 3)
|
||||
labels := cls.Predict(randArray, 3)
|
||||
fmt.Println(labels)
|
||||
}
|
||||
}
|
||||
|
@ -15,7 +15,8 @@ func main() {
|
||||
newlabels := util.ConvertLabelsToFloat(labels)
|
||||
|
||||
//Initialises a new KNN classifier
|
||||
cls := knn.NewKnnRegressor(newlabels, data, rows, cols, "euclidean")
|
||||
cls := knn.NewKnnRegressor("euclidean")
|
||||
cls.Fit(newlabels, data, rows, cols)
|
||||
|
||||
for {
|
||||
//Creates a random array of N float64s between 0 and Y
|
||||
|
24
knn/knn.go
24
knn/knn.go
@ -19,16 +19,19 @@ type KNNClassifier struct {
|
||||
}
|
||||
|
||||
// Returns a new classifier
|
||||
func NewKnnClassifier(labels []string, numbers []float64, rows int, cols int, distfunc string) *KNNClassifier {
|
||||
func NewKnnClassifier(distfunc string) *KNNClassifier {
|
||||
KNN := KNNClassifier{}
|
||||
KNN.DistanceFunc = distfunc
|
||||
return &KNN
|
||||
}
|
||||
|
||||
func (KNN *KNNClassifier) Fit(labels []string, numbers []float64, rows int, cols int) {
|
||||
if rows != len(labels) {
|
||||
panic(mat64.ErrShape)
|
||||
}
|
||||
|
||||
KNN := KNNClassifier{}
|
||||
KNN.Data = mat64.NewDense(rows, cols, numbers)
|
||||
KNN.Labels = labels
|
||||
KNN.DistanceFunc = distfunc
|
||||
return &KNN
|
||||
}
|
||||
|
||||
// Returns a classification for the vector, based on a vector input, using the KNN algorithm.
|
||||
@ -94,14 +97,21 @@ type KNNRegressor struct {
|
||||
}
|
||||
|
||||
// Mints a new classifier.
|
||||
func NewKnnRegressor(values []float64, numbers []float64, x int, y int, distfunc string) *KNNRegressor {
|
||||
func NewKnnRegressor(distfunc string) *KNNRegressor {
|
||||
KNN := KNNRegressor{}
|
||||
KNN.Data = mat64.NewDense(x, y, numbers)
|
||||
KNN.Values = values
|
||||
KNN.DistanceFunc = distfunc
|
||||
return &KNN
|
||||
}
|
||||
|
||||
func (KNN *KNNRegressor) Fit(values []float64, numbers []float64, rows int, cols int) {
|
||||
if rows != len(values) {
|
||||
panic(mat64.ErrShape)
|
||||
}
|
||||
|
||||
KNN.Data = mat64.NewDense(rows, cols, numbers)
|
||||
KNN.Values = values
|
||||
}
|
||||
|
||||
//Returns an average of the K nearest labels/variables, based on a vector input.
|
||||
func (KNN *KNNRegressor) Predict(vector *mat64.Dense, K int) float64 {
|
||||
|
||||
|
12
trees/decision_trees.go
Normal file
12
trees/decision_trees.go
Normal file
@ -0,0 +1,12 @@
|
||||
package trees
|
||||
|
||||
import base "github.com/sjwhitworth/golearn/base"
|
||||
|
||||
type DecisionTree struct {
|
||||
base.BaseEstimator
|
||||
}
|
||||
|
||||
type Branch struct {
|
||||
LeftBranch Branch
|
||||
RightBranch Branch
|
||||
}
|
0
trees/random_forests.go
Normal file
0
trees/random_forests.go
Normal file
2
trees/trees.go
Normal file
2
trees/trees.go
Normal file
@ -0,0 +1,2 @@
|
||||
// Package trees provides a number of tree based ensemble learners.
|
||||
package trees
|
Loading…
x
Reference in New Issue
Block a user