mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-28 13:48:56 +08:00

Refactored KNN to implement the estimator interface

Stephen Whitworth 2014-05-05 22:41:55 +01:00
parent 3eaeafd0dc
commit 1ade0afca6
7 changed files with 36 additions and 15 deletions

View File

@@ -15,7 +15,6 @@ import (
 // An object that can ingest some data and train on it.
 type Estimator interface {
 	Fit()
-	Summarise()
 }
 
 // An object that provides predictions.
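After this change, Fit() is the interface's only method, so any learner with a parameterless Fit method satisfies Estimator. A minimal sketch of that contract, assuming nothing beyond the interface shown above (ToyEstimator and the compile-time assertion are illustrative, not part of this commit):

package main

import "fmt"

// Estimator mirrors the trimmed interface above: an object that can
// ingest some data and train on it.
type Estimator interface {
	Fit()
}

// ToyEstimator is a hypothetical learner used only to show interface
// satisfaction; it is not part of golearn.
type ToyEstimator struct {
	fitted bool
}

// Fit records that training has happened.
func (t *ToyEstimator) Fit() {
	t.fitted = true
}

// Compile-time assertion that *ToyEstimator satisfies Estimator.
var _ Estimator = (*ToyEstimator)(nil)

func main() {
	var e Estimator = &ToyEstimator{}
	e.Fit()
	fmt.Println(e.(*ToyEstimator).fitted) // prints: true
}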

View File

@@ -3,7 +3,6 @@ package main
 import (
 	"fmt"
-	mat64 "github.com/gonum/matrix/mat64"
 	data "github.com/sjwhitworth/golearn/data"
 	knn "github.com/sjwhitworth/golearn/knn"
 	util "github.com/sjwhitworth/golearn/utilities"
@@ -14,17 +13,15 @@ func main() {
 	cols, rows, _, labels, data := data.ParseCsv("datasets/iris.csv", 4, []int{0, 1, 2})
 	//Initialises a new KNN classifier
-	cls := knn.NewKnnClassifier(labels, data, rows, cols, "euclidean")
+	cls := knn.NewKnnClassifier("euclidean")
+	cls.Fit(labels, data, rows, cols)
 	for {
 		//Creates a random array of N float64s between 0 and 7
 		randArray := util.RandomArray(3, 7)
-		//Initialises a vector with this array
-		random := mat64.NewDense(1, 3, randArray)
 		//Calculates the Euclidean distance and returns the most popular label
-		labels := cls.Predict(random, 3)
+		labels := cls.Predict(randArray, 3)
 		fmt.Println(labels)
 	}
 }
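Reading only the added lines, the updated example separates configuration from training: the distance function goes to the constructor, data to Fit, and queries to Predict. A consolidated sketch of the post-change flow, assuming the helper signatures exactly as the example shows (the variable is renamed from data to rawData only to avoid shadowing the data package):

package main

import (
	"fmt"

	data "github.com/sjwhitworth/golearn/data"
	knn "github.com/sjwhitworth/golearn/knn"
	util "github.com/sjwhitworth/golearn/utilities"
)

func main() {
	// Load the iris dataset exactly as the example above does.
	cols, rows, _, labels, rawData := data.ParseCsv("datasets/iris.csv", 4, []int{0, 1, 2})

	// Distance function at construction time, training data at Fit time.
	cls := knn.NewKnnClassifier("euclidean")
	cls.Fit(labels, rawData, rows, cols)

	// Classify one random 3-dimensional point, as in the example's loop body.
	randArray := util.RandomArray(3, 7)
	fmt.Println(cls.Predict(randArray, 3))
}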

View File

@@ -15,7 +15,8 @@ func main() {
 	newlabels := util.ConvertLabelsToFloat(labels)
 	//Initialises a new KNN classifier
-	cls := knn.NewKnnRegressor(newlabels, data, rows, cols, "euclidean")
+	cls := knn.NewKnnRegressor("euclidean")
+	cls.Fit(newlabels, data, rows, cols)
 	for {
 		//Creates a random array of N float64s between 0 and Y
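The regressor example gets the same treatment. Because the regressor's Predict (shown further down in knn.go) takes a *mat64.Dense and returns a float64, a self-contained sketch of the new flow might look like the following; the training data and query point are invented for illustration:

package main

import (
	"fmt"

	mat64 "github.com/gonum/matrix/mat64"
	knn "github.com/sjwhitworth/golearn/knn"
)

func main() {
	// Four one-dimensional training points and their target values.
	targets := []float64{1.0, 2.0, 3.0, 4.0}
	points := []float64{1.0, 2.0, 3.0, 4.0}

	reg := knn.NewKnnRegressor("euclidean")
	reg.Fit(targets, points, 4, 1) // rows must equal len(targets)

	// Average of the 2 nearest targets to the query point 2.4.
	query := mat64.NewDense(1, 1, []float64{2.4})
	fmt.Println(reg.Predict(query, 2))
}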

View File

@@ -19,16 +19,19 @@ type KNNClassifier struct {
 }
 
 // Returns a new classifier
-func NewKnnClassifier(labels []string, numbers []float64, rows int, cols int, distfunc string) *KNNClassifier {
+func NewKnnClassifier(distfunc string) *KNNClassifier {
+	KNN := KNNClassifier{}
+	KNN.DistanceFunc = distfunc
+	return &KNN
+}
+
+func (KNN *KNNClassifier) Fit(labels []string, numbers []float64, rows int, cols int) {
+	if rows != len(labels) {
+		panic(mat64.ErrShape)
+	}
-	KNN := KNNClassifier{}
 	KNN.Data = mat64.NewDense(rows, cols, numbers)
 	KNN.Labels = labels
-	KNN.DistanceFunc = distfunc
-	return &KNN
 }
 
 // Returns a classification for the vector, based on a vector input, using the KNN algorithm.
@@ -94,14 +97,21 @@ type KNNRegressor struct {
 }
 
 // Mints a new classifier.
-func NewKnnRegressor(values []float64, numbers []float64, x int, y int, distfunc string) *KNNRegressor {
+func NewKnnRegressor(distfunc string) *KNNRegressor {
 	KNN := KNNRegressor{}
-	KNN.Data = mat64.NewDense(x, y, numbers)
-	KNN.Values = values
 	KNN.DistanceFunc = distfunc
 	return &KNN
 }
+
+func (KNN *KNNRegressor) Fit(values []float64, numbers []float64, rows int, cols int) {
+	if rows != len(values) {
+		panic(mat64.ErrShape)
+	}
+	KNN.Data = mat64.NewDense(rows, cols, numbers)
+	KNN.Values = values
+}
 
 //Returns an average of the K nearest labels/variables, based on a vector input.
 func (KNN *KNNRegressor) Predict(vector *mat64.Dense, K int) float64 {
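Both new Fit implementations guard against mismatched dimensions by panicking with mat64.ErrShape when the row count disagrees with the number of labels or values. A sketch of how a caller could surface that panic as an ordinary error; safeFit is a hypothetical helper, not part of the package:

package main

import (
	"fmt"

	knn "github.com/sjwhitworth/golearn/knn"
)

// safeFit converts the mat64.ErrShape panic raised by Fit on mismatched
// dimensions into a returned error.
func safeFit(cls *knn.KNNClassifier, labels []string, numbers []float64, rows, cols int) (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("fit failed: %v", r)
		}
	}()
	cls.Fit(labels, numbers, rows, cols)
	return nil
}

func main() {
	cls := knn.NewKnnClassifier("euclidean")

	// Two rows of two features each, but three labels: rows != len(labels),
	// so Fit panics and safeFit reports the shape error.
	err := safeFit(cls, []string{"a", "b", "c"}, []float64{1, 2, 3, 4}, 2, 2)
	fmt.Println(err)
}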

trees/decision_trees.go (new file, 12 additions)
View File

@@ -0,0 +1,12 @@
+package trees
+
+import base "github.com/sjwhitworth/golearn/base"
+
+type DecisionTree struct {
+	base.BaseEstimator
+}
+
+type Branch struct {
+	LeftBranch  Branch
+	RightBranch Branch
+}
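As written, Branch embeds itself by value, which Go rejects as an invalid recursive type; recursive nodes normally hold pointers to their children. A compilable sketch of that shape, purely illustrative and not the repository's implementation:

package trees

// branchNode shows the usual pointer-based layout for a recursive
// binary-tree node; nil children mark a leaf.
type branchNode struct {
	Left  *branchNode
	Right *branchNode
}

// depth returns the height of the sketch tree rooted at b.
func depth(b *branchNode) int {
	if b == nil {
		return 0
	}
	l, r := depth(b.Left), depth(b.Right)
	if l > r {
		return l + 1
	}
	return r + 1
}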

trees/random_forests.go (new file, empty)
View File

trees/trees.go (new file, 2 additions)
View File

@@ -0,0 +1,2 @@
+// Package trees provides a number of tree based ensemble learners.
+package trees