diff --git a/ensemble/randomforest.go b/ensemble/randomforest.go index 7490d94..a0a364b 100644 --- a/ensemble/randomforest.go +++ b/ensemble/randomforest.go @@ -4,6 +4,7 @@ import ( base "github.com/sjwhitworth/golearn/base" meta "github.com/sjwhitworth/golearn/meta" trees "github.com/sjwhitworth/golearn/trees" + "fmt" ) // RandomForest classifies instances using an ensemble @@ -18,8 +19,8 @@ type RandomForest struct { // NewRandomForests generates and return a new random forests // forestSize controls the number of trees that get built // features controls the number of features used to build each tree -func NewRandomForest(forestSize int, features int) RandomForest { - ret := RandomForest{ +func NewRandomForest(forestSize int, features int) *RandomForest { + ret := &RandomForest{ base.BaseClassifier{}, forestSize, features, @@ -43,3 +44,7 @@ func (f *RandomForest) Fit(on *base.Instances) { func (f *RandomForest) Predict(with *base.Instances) *base.Instances { return f.Model.Predict(with) } + +func (f *RandomForest) String() string { + return fmt.Sprintf("RandomForest(ForestSize: %d, Features:%d, %s\n)", f.ForestSize, f.Features, f.Model) +} \ No newline at end of file diff --git a/trees/tennis.csv b/examples/datasets/tennis.csv similarity index 100% rename from trees/tennis.csv rename to examples/datasets/tennis.csv diff --git a/examples/trees/trees.go b/examples/trees/trees.go new file mode 100644 index 0000000..676c414 --- /dev/null +++ b/examples/trees/trees.go @@ -0,0 +1,75 @@ +// Demonstrates decision tree classification + +package main + +import ( + "fmt" + base "github.com/sjwhitworth/golearn/base" + eval "github.com/sjwhitworth/golearn/evaluation" + filters "github.com/sjwhitworth/golearn/filters" + ensemble "github.com/sjwhitworth/golearn/ensemble" + trees "github.com/sjwhitworth/golearn/trees" + "math/rand" + "time" +) + +func main () { + + var tree base.Classifier + + rand.Seed(time.Now().UTC().UnixNano()) + + // Load in the iris dataset + iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true) + if err != nil { + panic(err) + } + + // Discretise the iris dataset with Chi-Merge + filt := filters.NewChiMergeFilter(iris, 0.99) + filt.AddAllNumericAttributes() + filt.Build() + filt.Run(iris) + + // Create a 60-40 training-test split + insts := base.InstancesTrainTestSplit(iris, 0.60) + + // + // First up, use ID3 + // + tree = trees.NewID3DecisionTree(0.6) + // (Parameter controls train-prune split.) + + // Train the ID3 tree + tree.Fit(insts[0]) + + // Generate predictions + predictions := tree.Predict(insts[1]) + + // Evaluate + fmt.Println("ID3 Performance") + cf := eval.GetConfusionMatrix(insts[1], predictions) + fmt.Println(eval.GetSummary(cf)) + + // + // Next up, Random Trees + // + + // Consider two randomly-chosen attributes + tree = trees.NewRandomTree(2) + tree.Fit(insts[0]) + predictions = tree.Predict(insts[1]) + fmt.Println("RandomTree Performance") + cf = eval.GetConfusionMatrix(insts[1], predictions) + fmt.Println(eval.GetSummary(cf)) + + // + // Finally, Random Forests + // + tree = ensemble.NewRandomForest(100, 3) + tree.Fit(insts[0]) + predictions = tree.Predict(insts[1]) + fmt.Println("RandomForest Performance") + cf = eval.GetConfusionMatrix(insts[1], predictions) + fmt.Println(eval.GetSummary(cf)) +} \ No newline at end of file diff --git a/trees/tree_test.go b/trees/tree_test.go index 04f96f2..11001bf 100644 --- a/trees/tree_test.go +++ b/trees/tree_test.go @@ -121,7 +121,7 @@ func TestInformationGain(testEnv *testing.T) { func TestID3Inference(testEnv *testing.T) { // Import the "PlayTennis" dataset - inst, err := base.ParseCSVToInstances("./tennis.csv", true) + inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true) if err != nil { panic(err) } @@ -198,7 +198,7 @@ func TestID3Classification(testEnv *testing.T) { func TestID3(testEnv *testing.T) { // Import the "PlayTennis" dataset - inst, err := base.ParseCSVToInstances("./tennis.csv", true) + inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true) if err != nil { panic(err) }