diff --git a/base/domain.go b/base/domain.go index 852d612..ff247e2 100644 --- a/base/domain.go +++ b/base/domain.go @@ -15,7 +15,6 @@ import ( // An object that can ingest some data and train on it. type Estimator interface { Fit() - Summarise() } // An object that provides predictions. diff --git a/examples/knnclassifier_iris.go b/examples/knnclassifier_iris.go index 5f655c7..60a0b95 100644 --- a/examples/knnclassifier_iris.go +++ b/examples/knnclassifier_iris.go @@ -3,7 +3,6 @@ package main import ( "fmt" - mat64 "github.com/gonum/matrix/mat64" data "github.com/sjwhitworth/golearn/data" knn "github.com/sjwhitworth/golearn/knn" util "github.com/sjwhitworth/golearn/utilities" @@ -14,17 +13,15 @@ func main() { cols, rows, _, labels, data := data.ParseCsv("datasets/iris.csv", 4, []int{0, 1, 2}) //Initialises a new KNN classifier - cls := knn.NewKnnClassifier(labels, data, rows, cols, "euclidean") + cls := knn.NewKnnClassifier("euclidean") + cls.Fit(labels, data, rows, cols) for { //Creates a random array of N float64s between 0 and 7 randArray := util.RandomArray(3, 7) - //Initialises a vector with this array - random := mat64.NewDense(1, 3, randArray) - //Calculates the Euclidean distance and returns the most popular label - labels := cls.Predict(random, 3) + labels := cls.Predict(randArray, 3) fmt.Println(labels) } } diff --git a/examples/knnregressor_random.go b/examples/knnregressor_random.go index 46b16da..c1876d5 100644 --- a/examples/knnregressor_random.go +++ b/examples/knnregressor_random.go @@ -15,7 +15,8 @@ func main() { newlabels := util.ConvertLabelsToFloat(labels) //Initialises a new KNN classifier - cls := knn.NewKnnRegressor(newlabels, data, rows, cols, "euclidean") + cls := knn.NewKnnRegressor("euclidean") + cls.Fit(newlabels, data, rows, cols) for { //Creates a random array of N float64s between 0 and Y diff --git a/knn/knn.go b/knn/knn.go index 4d00590..e63e72b 100644 --- a/knn/knn.go +++ b/knn/knn.go @@ -19,16 +19,19 @@ type KNNClassifier struct { } // Returns a new classifier -func NewKnnClassifier(labels []string, numbers []float64, rows int, cols int, distfunc string) *KNNClassifier { +func NewKnnClassifier(distfunc string) *KNNClassifier { + KNN := KNNClassifier{} + KNN.DistanceFunc = distfunc + return &KNN +} + +func (KNN *KNNClassifier) Fit(labels []string, numbers []float64, rows int, cols int) { if rows != len(labels) { panic(mat64.ErrShape) } - KNN := KNNClassifier{} KNN.Data = mat64.NewDense(rows, cols, numbers) KNN.Labels = labels - KNN.DistanceFunc = distfunc - return &KNN } // Returns a classification for the vector, based on a vector input, using the KNN algorithm. @@ -94,14 +97,21 @@ type KNNRegressor struct { } // Mints a new classifier. -func NewKnnRegressor(values []float64, numbers []float64, x int, y int, distfunc string) *KNNRegressor { +func NewKnnRegressor(distfunc string) *KNNRegressor { KNN := KNNRegressor{} - KNN.Data = mat64.NewDense(x, y, numbers) - KNN.Values = values KNN.DistanceFunc = distfunc return &KNN } +func (KNN *KNNRegressor) Fit(values []float64, numbers []float64, rows int, cols int) { + if rows != len(values) { + panic(mat64.ErrShape) + } + + KNN.Data = mat64.NewDense(rows, cols, numbers) + KNN.Values = values +} + //Returns an average of the K nearest labels/variables, based on a vector input. func (KNN *KNNRegressor) Predict(vector *mat64.Dense, K int) float64 { diff --git a/trees/decision_trees.go b/trees/decision_trees.go new file mode 100644 index 0000000..0bfc6b0 --- /dev/null +++ b/trees/decision_trees.go @@ -0,0 +1,12 @@ +package trees + +import base "github.com/sjwhitworth/golearn/base" + +type DecisionTree struct { + base.BaseEstimator +} + +type Branch struct { + LeftBranch Branch + RightBranch Branch +} diff --git a/trees/random_forests.go b/trees/random_forests.go new file mode 100644 index 0000000..e69de29 diff --git a/trees/trees.go b/trees/trees.go new file mode 100644 index 0000000..f847e5d --- /dev/null +++ b/trees/trees.go @@ -0,0 +1,2 @@ +// Package trees provides a number of tree based ensemble learners. +package trees