mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
Examples for RandomForest, ID3 and Random trees
This commit is contained in:
parent
45ca6063f1
commit
889fec4419
@ -4,6 +4,7 @@ import (
|
||||
base "github.com/sjwhitworth/golearn/base"
|
||||
meta "github.com/sjwhitworth/golearn/meta"
|
||||
trees "github.com/sjwhitworth/golearn/trees"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// RandomForest classifies instances using an ensemble
|
||||
@ -18,8 +19,8 @@ type RandomForest struct {
|
||||
// NewRandomForest generates and returns a new random forest.
|
||||
// forestSize controls the number of trees that get built
|
||||
// features controls the number of features used to build each tree
|
||||
func NewRandomForest(forestSize int, features int) RandomForest {
|
||||
ret := RandomForest{
|
||||
func NewRandomForest(forestSize int, features int) *RandomForest {
|
||||
ret := &RandomForest{
|
||||
base.BaseClassifier{},
|
||||
forestSize,
|
||||
features,
|
||||
@ -43,3 +44,7 @@ func (f *RandomForest) Fit(on *base.Instances) {
|
||||
func (f *RandomForest) Predict(with *base.Instances) *base.Instances {
|
||||
return f.Model.Predict(with)
|
||||
}
|
||||
|
||||
func (f *RandomForest) String() string {
|
||||
return fmt.Sprintf("RandomForest(ForestSize: %d, Features:%d, %s\n)", f.ForestSize, f.Features, f.Model)
|
||||
}
|
75
examples/trees/trees.go
Normal file
75
examples/trees/trees.go
Normal file
@ -0,0 +1,75 @@
|
||||
// Demonstrates decision tree classification
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
base "github.com/sjwhitworth/golearn/base"
|
||||
eval "github.com/sjwhitworth/golearn/evaluation"
|
||||
filters "github.com/sjwhitworth/golearn/filters"
|
||||
ensemble "github.com/sjwhitworth/golearn/ensemble"
|
||||
trees "github.com/sjwhitworth/golearn/trees"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main () {
|
||||
|
||||
var tree base.Classifier
|
||||
|
||||
rand.Seed(time.Now().UTC().UnixNano())
|
||||
|
||||
// Load in the iris dataset
|
||||
iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Discretise the iris dataset with Chi-Merge
|
||||
filt := filters.NewChiMergeFilter(iris, 0.99)
|
||||
filt.AddAllNumericAttributes()
|
||||
filt.Build()
|
||||
filt.Run(iris)
|
||||
|
||||
// Create a 60-40 training-test split
|
||||
insts := base.InstancesTrainTestSplit(iris, 0.60)
|
||||
|
||||
//
|
||||
// First up, use ID3
|
||||
//
|
||||
tree = trees.NewID3DecisionTree(0.6)
|
||||
// (Parameter controls train-prune split.)
|
||||
|
||||
// Train the ID3 tree
|
||||
tree.Fit(insts[0])
|
||||
|
||||
// Generate predictions
|
||||
predictions := tree.Predict(insts[1])
|
||||
|
||||
// Evaluate
|
||||
fmt.Println("ID3 Performance")
|
||||
cf := eval.GetConfusionMatrix(insts[1], predictions)
|
||||
fmt.Println(eval.GetSummary(cf))
|
||||
|
||||
//
|
||||
// Next up, Random Trees
|
||||
//
|
||||
|
||||
// Consider two randomly-chosen attributes
|
||||
tree = trees.NewRandomTree(2)
|
||||
tree.Fit(insts[0])
|
||||
predictions = tree.Predict(insts[1])
|
||||
fmt.Println("RandomTree Performance")
|
||||
cf = eval.GetConfusionMatrix(insts[1], predictions)
|
||||
fmt.Println(eval.GetSummary(cf))
|
||||
|
||||
//
|
||||
// Finally, Random Forests
|
||||
//
|
||||
tree = ensemble.NewRandomForest(100, 3)
|
||||
tree.Fit(insts[0])
|
||||
predictions = tree.Predict(insts[1])
|
||||
fmt.Println("RandomForest Performance")
|
||||
cf = eval.GetConfusionMatrix(insts[1], predictions)
|
||||
fmt.Println(eval.GetSummary(cf))
|
||||
}
|
@ -121,7 +121,7 @@ func TestInformationGain(testEnv *testing.T) {
|
||||
func TestID3Inference(testEnv *testing.T) {
|
||||
|
||||
// Import the "PlayTennis" dataset
|
||||
inst, err := base.ParseCSVToInstances("./tennis.csv", true)
|
||||
inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@ -198,7 +198,7 @@ func TestID3Classification(testEnv *testing.T) {
|
||||
func TestID3(testEnv *testing.T) {
|
||||
|
||||
// Import the "PlayTennis" dataset
|
||||
inst, err := base.ParseCSVToInstances("./tennis.csv", true)
|
||||
inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user