1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00

Examples for RandomForest, ID3 and Random trees

This commit is contained in:
Richard Townsend 2014-05-19 12:42:03 +01:00
parent 45ca6063f1
commit 889fec4419
4 changed files with 84 additions and 4 deletions

View File

@ -4,6 +4,7 @@ import (
base "github.com/sjwhitworth/golearn/base"
meta "github.com/sjwhitworth/golearn/meta"
trees "github.com/sjwhitworth/golearn/trees"
"fmt"
)
// RandomForest classifies instances using an ensemble
@ -18,8 +19,8 @@ type RandomForest struct {
// NewRandomForests generates and return a new random forests
// forestSize controls the number of trees that get built
// features controls the number of features used to build each tree
func NewRandomForest(forestSize int, features int) RandomForest {
ret := RandomForest{
func NewRandomForest(forestSize int, features int) *RandomForest {
ret := &RandomForest{
base.BaseClassifier{},
forestSize,
features,
@ -43,3 +44,7 @@ func (f *RandomForest) Fit(on *base.Instances) {
func (f *RandomForest) Predict(with *base.Instances) *base.Instances {
return f.Model.Predict(with)
}
func (f *RandomForest) String() string {
return fmt.Sprintf("RandomForest(ForestSize: %d, Features:%d, %s\n)", f.ForestSize, f.Features, f.Model)
}

75
examples/trees/trees.go Normal file
View File

@ -0,0 +1,75 @@
// Demonstrates decision tree classification
package main
import (
"fmt"
base "github.com/sjwhitworth/golearn/base"
eval "github.com/sjwhitworth/golearn/evaluation"
filters "github.com/sjwhitworth/golearn/filters"
ensemble "github.com/sjwhitworth/golearn/ensemble"
trees "github.com/sjwhitworth/golearn/trees"
"math/rand"
"time"
)
func main () {
var tree base.Classifier
rand.Seed(time.Now().UTC().UnixNano())
// Load in the iris dataset
iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
if err != nil {
panic(err)
}
// Discretise the iris dataset with Chi-Merge
filt := filters.NewChiMergeFilter(iris, 0.99)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(iris)
// Create a 60-40 training-test split
insts := base.InstancesTrainTestSplit(iris, 0.60)
//
// First up, use ID3
//
tree = trees.NewID3DecisionTree(0.6)
// (Parameter controls train-prune split.)
// Train the ID3 tree
tree.Fit(insts[0])
// Generate predictions
predictions := tree.Predict(insts[1])
// Evaluate
fmt.Println("ID3 Performance")
cf := eval.GetConfusionMatrix(insts[1], predictions)
fmt.Println(eval.GetSummary(cf))
//
// Next up, Random Trees
//
// Consider two randomly-chosen attributes
tree = trees.NewRandomTree(2)
tree.Fit(insts[0])
predictions = tree.Predict(insts[1])
fmt.Println("RandomTree Performance")
cf = eval.GetConfusionMatrix(insts[1], predictions)
fmt.Println(eval.GetSummary(cf))
//
// Finally, Random Forests
//
tree = ensemble.NewRandomForest(100, 3)
tree.Fit(insts[0])
predictions = tree.Predict(insts[1])
fmt.Println("RandomForest Performance")
cf = eval.GetConfusionMatrix(insts[1], predictions)
fmt.Println(eval.GetSummary(cf))
}

View File

@ -121,7 +121,7 @@ func TestInformationGain(testEnv *testing.T) {
func TestID3Inference(testEnv *testing.T) {
// Import the "PlayTennis" dataset
inst, err := base.ParseCSVToInstances("./tennis.csv", true)
inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
if err != nil {
panic(err)
}
@ -198,7 +198,7 @@ func TestID3Classification(testEnv *testing.T) {
func TestID3(testEnv *testing.T) {
// Import the "PlayTennis" dataset
inst, err := base.ParseCSVToInstances("./tennis.csv", true)
inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
if err != nil {
panic(err)
}