mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-30 13:48:57 +08:00
Refactored KNN to implement the estimator interface
This commit is contained in:
parent
3eaeafd0dc
commit
1ade0afca6
@ -15,7 +15,6 @@ import (
|
|||||||
// An object that can ingest some data and train on it.
|
// An object that can ingest some data and train on it.
|
||||||
type Estimator interface {
|
type Estimator interface {
|
||||||
Fit()
|
Fit()
|
||||||
Summarise()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// An object that provides predictions.
|
// An object that provides predictions.
|
||||||
|
@ -3,7 +3,6 @@ package main
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
mat64 "github.com/gonum/matrix/mat64"
|
|
||||||
data "github.com/sjwhitworth/golearn/data"
|
data "github.com/sjwhitworth/golearn/data"
|
||||||
knn "github.com/sjwhitworth/golearn/knn"
|
knn "github.com/sjwhitworth/golearn/knn"
|
||||||
util "github.com/sjwhitworth/golearn/utilities"
|
util "github.com/sjwhitworth/golearn/utilities"
|
||||||
@ -14,17 +13,15 @@ func main() {
|
|||||||
cols, rows, _, labels, data := data.ParseCsv("datasets/iris.csv", 4, []int{0, 1, 2})
|
cols, rows, _, labels, data := data.ParseCsv("datasets/iris.csv", 4, []int{0, 1, 2})
|
||||||
|
|
||||||
//Initialises a new KNN classifier
|
//Initialises a new KNN classifier
|
||||||
cls := knn.NewKnnClassifier(labels, data, rows, cols, "euclidean")
|
cls := knn.NewKnnClassifier("euclidean")
|
||||||
|
cls.Fit(labels, data, rows, cols)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
//Creates a random array of N float64s between 0 and 7
|
//Creates a random array of N float64s between 0 and 7
|
||||||
randArray := util.RandomArray(3, 7)
|
randArray := util.RandomArray(3, 7)
|
||||||
|
|
||||||
//Initialises a vector with this array
|
|
||||||
random := mat64.NewDense(1, 3, randArray)
|
|
||||||
|
|
||||||
//Calculates the Euclidean distance and returns the most popular label
|
//Calculates the Euclidean distance and returns the most popular label
|
||||||
labels := cls.Predict(random, 3)
|
labels := cls.Predict(randArray, 3)
|
||||||
fmt.Println(labels)
|
fmt.Println(labels)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -15,7 +15,8 @@ func main() {
|
|||||||
newlabels := util.ConvertLabelsToFloat(labels)
|
newlabels := util.ConvertLabelsToFloat(labels)
|
||||||
|
|
||||||
//Initialises a new KNN classifier
|
//Initialises a new KNN classifier
|
||||||
cls := knn.NewKnnRegressor(newlabels, data, rows, cols, "euclidean")
|
cls := knn.NewKnnRegressor("euclidean")
|
||||||
|
cls.Fit(newlabels, data, rows, cols)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
//Creates a random array of N float64s between 0 and Y
|
//Creates a random array of N float64s between 0 and Y
|
||||||
|
24
knn/knn.go
24
knn/knn.go
@ -19,16 +19,19 @@ type KNNClassifier struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Returns a new classifier
|
// Returns a new classifier
|
||||||
func NewKnnClassifier(labels []string, numbers []float64, rows int, cols int, distfunc string) *KNNClassifier {
|
func NewKnnClassifier(distfunc string) *KNNClassifier {
|
||||||
|
KNN := KNNClassifier{}
|
||||||
|
KNN.DistanceFunc = distfunc
|
||||||
|
return &KNN
|
||||||
|
}
|
||||||
|
|
||||||
|
func (KNN *KNNClassifier) Fit(labels []string, numbers []float64, rows int, cols int) {
|
||||||
if rows != len(labels) {
|
if rows != len(labels) {
|
||||||
panic(mat64.ErrShape)
|
panic(mat64.ErrShape)
|
||||||
}
|
}
|
||||||
|
|
||||||
KNN := KNNClassifier{}
|
|
||||||
KNN.Data = mat64.NewDense(rows, cols, numbers)
|
KNN.Data = mat64.NewDense(rows, cols, numbers)
|
||||||
KNN.Labels = labels
|
KNN.Labels = labels
|
||||||
KNN.DistanceFunc = distfunc
|
|
||||||
return &KNN
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns a classification for the vector, based on a vector input, using the KNN algorithm.
|
// Returns a classification for the vector, based on a vector input, using the KNN algorithm.
|
||||||
@ -94,14 +97,21 @@ type KNNRegressor struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Mints a new classifier.
|
// Mints a new classifier.
|
||||||
func NewKnnRegressor(values []float64, numbers []float64, x int, y int, distfunc string) *KNNRegressor {
|
func NewKnnRegressor(distfunc string) *KNNRegressor {
|
||||||
KNN := KNNRegressor{}
|
KNN := KNNRegressor{}
|
||||||
KNN.Data = mat64.NewDense(x, y, numbers)
|
|
||||||
KNN.Values = values
|
|
||||||
KNN.DistanceFunc = distfunc
|
KNN.DistanceFunc = distfunc
|
||||||
return &KNN
|
return &KNN
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (KNN *KNNRegressor) Fit(values []float64, numbers []float64, rows int, cols int) {
|
||||||
|
if rows != len(values) {
|
||||||
|
panic(mat64.ErrShape)
|
||||||
|
}
|
||||||
|
|
||||||
|
KNN.Data = mat64.NewDense(rows, cols, numbers)
|
||||||
|
KNN.Values = values
|
||||||
|
}
|
||||||
|
|
||||||
//Returns an average of the K nearest labels/variables, based on a vector input.
|
//Returns an average of the K nearest labels/variables, based on a vector input.
|
||||||
func (KNN *KNNRegressor) Predict(vector *mat64.Dense, K int) float64 {
|
func (KNN *KNNRegressor) Predict(vector *mat64.Dense, K int) float64 {
|
||||||
|
|
||||||
|
12
trees/decision_trees.go
Normal file
12
trees/decision_trees.go
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
package trees
|
||||||
|
|
||||||
|
import base "github.com/sjwhitworth/golearn/base"
|
||||||
|
|
||||||
|
type DecisionTree struct {
|
||||||
|
base.BaseEstimator
|
||||||
|
}
|
||||||
|
|
||||||
|
type Branch struct {
|
||||||
|
LeftBranch Branch
|
||||||
|
RightBranch Branch
|
||||||
|
}
|
0
trees/random_forests.go
Normal file
0
trees/random_forests.go
Normal file
2
trees/trees.go
Normal file
2
trees/trees.go
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
// Package trees provides a number of tree based ensemble learners.
|
||||||
|
package trees
|
Loading…
x
Reference in New Issue
Block a user