mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
commit
dbf1c9a6b3
53
README.md
53
README.md
@ -25,31 +25,46 @@ GoLearn implements the scikit-learn interface of Fit/Predict, so you can easily
|
||||
GoLearn also includes helper functions for data, like cross validation, and train and test splitting.
|
||||
|
||||
```go
|
||||
// Load in a dataset, with headers. Header attributes will be stored.
|
||||
// Think of instances as a Data Frame structure in R or Pandas.
|
||||
// You can also create instances from scratch.
|
||||
data, err := base.ParseCSVToInstances("datasets/iris_headers.csv", true)
|
||||
package main
|
||||
|
||||
// Print a pleasant summary of your data.
|
||||
fmt.Println(data)
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
// Split your dataframe into a training set, and a test set, with an 80/20 proportion.
|
||||
trainTest := base.InstancesTrainTestSplit(rawData, 0.8)
|
||||
trainData := trainTest[0]
|
||||
testData := trainTest[1]
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"github.com/sjwhitworth/golearn/evaluation"
|
||||
"github.com/sjwhitworth/golearn/knn"
|
||||
)
|
||||
|
||||
// Instantiate a new KNN classifier. Euclidean distance, with 2 neighbours.
|
||||
cls := knn.NewKnnClassifier("euclidean", 2)
|
||||
func main() {
|
||||
// Load in a dataset, with headers. Header attributes will be stored.
|
||||
// Think of instances as a Data Frame structure in R or Pandas.
|
||||
// You can also create instances from scratch.
|
||||
rawData, err := base.ParseCSVToInstances("datasets/iris.csv", false)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Fit it on your training data.
|
||||
cls.Fit(trainData)
|
||||
// Print a pleasant summary of your data.
|
||||
fmt.Println(rawData)
|
||||
|
||||
// Get your predictions against test instances.
|
||||
predictions := cls.Predict(testData)
|
||||
//Initialises a new KNN classifier
|
||||
cls := knn.NewKnnClassifier("euclidean", 2)
|
||||
|
||||
// Print a confusion matrix with precision and recall metrics.
|
||||
confusionMat, _ := evaluation.GetConfusionMatrix(testData, predictions)
|
||||
fmt.Println(evaluation.GetSummary(confusionMat))
|
||||
//Do a training-test split
|
||||
trainData, testData := base.InstancesTrainTestSplit(rawData, 0.50)
|
||||
cls.Fit(trainData)
|
||||
|
||||
//Calculates the Euclidean distance and returns the most popular label
|
||||
predictions := cls.Predict(testData)
|
||||
fmt.Println(predictions)
|
||||
|
||||
// Prints precision/recall metrics
|
||||
confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
|
||||
}
|
||||
fmt.Println(evaluation.GetSummary(confusionMat))
|
||||
}
|
||||
```
|
||||
|
||||
```
|
||||
|
Loading…
x
Reference in New Issue
Block a user