mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-28 13:48:56 +08:00
Merge branch 'master' into feature/naive
This commit is contained in:
commit 0cf6d258e6
.travis.yml

@@ -4,6 +4,9 @@ go:
 - 1.2
 - release
 - tip
+before_install:
+- sudo apt-get update -qq
+- sudo apt-get install -qq libatlas-base-dev
 install:
 - go get github.com/smartystreets/goconvey/convey
 - go get -v ./...
README.md

@@ -2,7 +2,8 @@ GoLearn
 =======
 
 <img src="http://talks.golang.org/2013/advconc/gopherhat.jpg" width=125><br>
-[](https://godoc.org/github.com/sjwhitworth/golearn)<br>
+[](https://godoc.org/github.com/sjwhitworth/golearn)
+[](https://travis-ci.org/sjwhitworth/golearn)<br>
 
 GoLearn is a 'batteries included' machine learning library for Go. **Simplicity**, paired with customisability, is the goal.
 We are in active development, and would love comments from users out in the wild. Drop us a line on Twitter.
base/attributes.go

@@ -219,7 +219,7 @@ func (Attr *CategoricalAttribute) GetSysValFromString(rawVal string) float64 {
 // Returns a string containing the list of human-readable values this
 // CategoricalAttribute can take.
 func (Attr *CategoricalAttribute) String() string {
-	return fmt.Sprintf("CategoricalAttribute(%s)", Attr.values)
+	return fmt.Sprintf("CategoricalAttribute(\"%s\", %s)", Attr.Name, Attr.values)
 }
 
 // GetStringFromSysVal returns a human-readable value from the given system-representation
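Editor's note: the effect of this one-line change, using hypothetical values. A CategoricalAttribute named "outlook" whose values are [sunny overcast rainy] previously printed as CategoricalAttribute([sunny overcast rainy]) and now prints as CategoricalAttribute("outlook", [sunny overcast rainy]), so the attribute's name becomes visible in debug output.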
base/instances.go

@@ -5,8 +5,9 @@ import (
 	"encoding/binary"
 	"errors"
 	"fmt"
-	"github.com/gonum/matrix/mat64"
 	"math/rand"
+
+	"github.com/gonum/matrix/mat64"
 )
 
 // SortDirection specifies sorting direction...
@@ -171,10 +172,11 @@ func NewInstancesFromDense(attrs []Attribute, rows int, mat *mat64.Dense) *Instances
 //
 // IMPORTANT: this function is only meaningful when prop is between 0.0 and 1.0.
 // Using any other values may result in odd behaviour.
-func InstancesTrainTestSplit(src *Instances, prop float64) [2](*Instances) {
+func InstancesTrainTestSplit(src *Instances, prop float64) (*Instances, *Instances) {
 	trainingRows := make([]int, 0)
 	testingRows := make([]int, 0)
+	numAttrs := len(src.attributes)
 	src.Shuffle()
 	for i := 0; i < src.Rows; i++ {
 		trainOrTest := rand.Intn(101)
 		if trainOrTest > int(100*prop) {
@@ -196,10 +198,10 @@ func InstancesTrainTestSplit(src *Instances, prop float64) [2](*Instances) {
 		rawTestMatrix.SetRow(i, rowDat)
 	}
 
-	var ret [2]*Instances
-	ret[0] = NewInstancesFromDense(src.attributes, len(trainingRows), rawTrainMatrix)
-	ret[1] = NewInstancesFromDense(src.attributes, len(testingRows), rawTestMatrix)
-	return ret
+	trainingRet := NewInstancesFromDense(src.attributes, len(trainingRows), rawTrainMatrix)
+	testRet := NewInstancesFromDense(src.attributes, len(testingRows), rawTestMatrix)
+	return trainingRet, testRet
 }
 
 // CountAttrValues returns the distribution of values of a given
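Editor's note: a minimal sketch of what this signature change means at the call site (the dataset path is hypothetical; ParseCSVToInstances and the import alias are as used elsewhere in this commit). The array-indexing style it replaces appears as removed lines in the knn example further down.

    package main

    import (
    	"fmt"

    	base "github.com/sjwhitworth/golearn/base"
    )

    func main() {
    	rawData, err := base.ParseCSVToInstances("iris_headers.csv", true) // hypothetical path
    	if err != nil {
    		panic(err)
    	}
    	// New style: two return values rather than a [2]*Instances array
    	trainData, testData := base.InstancesTrainTestSplit(rawData, 0.50)
    	fmt.Println(trainData.Rows, testData.Rows)
    }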
@@ -273,6 +275,33 @@ func (inst *Instances) DecomposeOnAttributeValues(at Attribute) map[string]*Instances
 	return ret
 }
 
+func (inst *Instances) GetClassDistributionAfterSplit(at Attribute) map[string]map[string]int {
+
+	ret := make(map[string]map[string]int)
+
+	// Find the attribute we're decomposing on
+	attrIndex := inst.GetAttrIndex(at)
+	if attrIndex == -1 {
+		panic("Invalid attribute index")
+	}
+
+	// Get the class index
+	classAttr := inst.GetAttr(inst.ClassIndex)
+
+	for i := 0; i < inst.Rows; i++ {
+		splitVar := at.GetStringFromSysVal(inst.Get(i, attrIndex))
+		classVar := classAttr.GetStringFromSysVal(inst.Get(i, inst.ClassIndex))
+		if _, ok := ret[splitVar]; !ok {
+			ret[splitVar] = make(map[string]int)
+			i--
+			continue
+		}
+		ret[splitVar][classVar]++
+	}
+
+	return ret
+}
+
 // Get returns the system representation (float64) of the value
 // stored at the given row and col coordinate.
 func (inst *Instances) Get(row int, col int) float64 {
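Editor's note: the nested map this new method returns is keyed first by the split attribute's value, then by class. A sketch of the shape (not from the commit), using counts taken from the tennis.csv dataset added later in this commit, assuming "outlook" as the split attribute and "play" as the class:

    // Expected result of inst.GetClassDistributionAfterSplit(outlookAttr)
    // on the PlayTennis data: split value -> class value -> count.
    dist := map[string]map[string]int{
    	"sunny":    {"no": 3, "yes": 2},
    	"overcast": {"yes": 4},
    	"rainy":    {"yes": 3, "no": 2},
    }
    _ = dist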
@@ -308,6 +337,20 @@ func (inst *Instances) GetClass(row int) string {
 	return attr.GetStringFromSysVal(val)
 }
 
+// GetClassDistribution returns a map containing the count of each
+// class type (indexed by the class' string representation)
+func (inst *Instances) GetClassDistribution() map[string]int {
+	ret := make(map[string]int)
+	attr := inst.GetAttr(inst.ClassIndex)
+	for i := 0; i < inst.Rows; i++ {
+		val := inst.Get(i, inst.ClassIndex)
+		cls := attr.GetStringFromSysVal(val)
+		ret[cls]++
+	}
+
+	return ret
+}
+
 func (Inst *Instances) GetClassAttrPtr() *Attribute {
 	attr := Inst.GetAttr(Inst.ClassIndex)
 	return &attr
@@ -465,7 +508,7 @@ func (inst *Instances) GeneratePredictionVector() *Instances {
 // Shuffle randomizes the row order in place
 func (inst *Instances) Shuffle() {
 	for i := 0; i < inst.Rows; i++ {
-		j := rand.Intn(inst.Rows)
+		j := rand.Intn(i + 1)
 		inst.swapRows(i, j)
 	}
 }
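Editor's note: drawing j from [0, i] rather than from the whole row range turns Shuffle into a standard Fisher-Yates shuffle, which produces every permutation with equal probability; sampling j over all rows at every step is subtly biased. A minimal standalone sketch of the same scheme (not part of the commit):

    package main

    import (
    	"fmt"
    	"math/rand"
    )

    // fisherYates shuffles idx uniformly; mirrors the fixed Shuffle above.
    func fisherYates(idx []int) {
    	for i := range idx {
    		j := rand.Intn(i + 1) // j uniform over [0, i]
    		idx[i], idx[j] = idx[j], idx[i]
    	}
    }

    func main() {
    	rows := []int{0, 1, 2, 3, 4}
    	fisherYates(rows)
    	fmt.Println(rows)
    }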
ensemble/ensemble.go (Normal file, 13 lines)

@@ -0,0 +1,13 @@
/*

	Ensemble contains classifiers which combine other classifiers.

	RandomForest:
		Generates ForestSize bagged decision trees (currently ID3-based)
		each considering a fixed number of random features.

		Built on meta.Bagging

*/

package ensemble
ensemble/randomforest.go (Normal file, 50 lines)

@@ -0,0 +1,50 @@
package ensemble

import (
	base "github.com/sjwhitworth/golearn/base"
	meta "github.com/sjwhitworth/golearn/meta"
	trees "github.com/sjwhitworth/golearn/trees"
	"fmt"
)

// RandomForest classifies instances using an ensemble
// of bagged random decision trees
type RandomForest struct {
	base.BaseClassifier
	ForestSize int
	Features   int
	Model      *meta.BaggedModel
}

// NewRandomForest generates and returns a new random forest
// forestSize controls the number of trees that get built
// features controls the number of features used to build each tree
func NewRandomForest(forestSize int, features int) *RandomForest {
	ret := &RandomForest{
		base.BaseClassifier{},
		forestSize,
		features,
		nil,
	}
	return ret
}

// Fit builds the RandomForest on the specified instances
func (f *RandomForest) Fit(on *base.Instances) {
	f.Model = new(meta.BaggedModel)
	f.Model.RandomFeatures = f.Features
	for i := 0; i < f.ForestSize; i++ {
		tree := trees.NewID3DecisionTree(0.00)
		f.Model.AddModel(tree)
	}
	f.Model.Fit(on)
}

// Predict generates predictions from a trained RandomForest
func (f *RandomForest) Predict(with *base.Instances) *base.Instances {
	return f.Model.Predict(with)
}

func (f *RandomForest) String() string {
	return fmt.Sprintf("RandomForest(ForestSize: %d, Features: %d, %s\n)", f.ForestSize, f.Features, f.Model)
}
ensemble/randomforest_test.go (Normal file, 29 lines)

@@ -0,0 +1,29 @@
package ensemble

import (
	"fmt"
	base "github.com/sjwhitworth/golearn/base"
	eval "github.com/sjwhitworth/golearn/evaluation"
	filters "github.com/sjwhitworth/golearn/filters"
	"testing"
)

func TestRandomForest1(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}
	trainData, testData := base.InstancesTrainTestSplit(inst, 0.60)
	filt := filters.NewChiMergeFilter(trainData, 0.90)
	filt.AddAllNumericAttributes()
	filt.Build()
	filt.Run(testData)
	filt.Run(trainData)
	rf := NewRandomForest(10, 3)
	rf.Fit(trainData)
	predictions := rf.Predict(testData)
	fmt.Println(predictions)
	confusionMat := eval.GetConfusionMatrix(testData, predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetSummary(confusionMat))
}
examples/datasets/tennis.csv (Normal file, 15 lines)

@@ -0,0 +1,15 @@
outlook,temp,humidity,windy,play
sunny,hot,high,false,no
sunny,hot,high,true,no
overcast,hot,high,false,yes
rainy,mild,high,false,yes
rainy,cool,normal,false,yes
rainy,cool,normal,true,no
overcast,cool,normal,true,yes
sunny,mild,high,false,no
sunny,cool,normal,false,yes
rainy,mild,normal,false,yes
sunny,mild,normal,true,yes
overcast,mild,high,true,yes
overcast,hot,normal,false,yes
rainy,mild,high,true,no
@@ -17,9 +17,7 @@ func main() {
 	cls := knn.NewKnnClassifier("euclidean", 2)
 
 	//Do a training-test split
-	trainTest := base.InstancesTrainTestSplit(rawData, 0.50)
-	trainData := trainTest[0]
-	testData := trainTest[1]
+	trainData, testData := base.InstancesTrainTestSplit(rawData, 0.50)
 	cls.Fit(trainData)
 
 	//Calculates the Euclidean distance and returns the most popular label
examples/trees/trees.go (Normal file, 75 lines)

@@ -0,0 +1,75 @@
// Demonstrates decision tree classification

package main

import (
	"fmt"
	base "github.com/sjwhitworth/golearn/base"
	eval "github.com/sjwhitworth/golearn/evaluation"
	filters "github.com/sjwhitworth/golearn/filters"
	ensemble "github.com/sjwhitworth/golearn/ensemble"
	trees "github.com/sjwhitworth/golearn/trees"
	"math/rand"
	"time"
)

func main() {

	var tree base.Classifier

	rand.Seed(time.Now().UTC().UnixNano())

	// Load in the iris dataset
	iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}

	// Discretise the iris dataset with Chi-Merge
	filt := filters.NewChiMergeFilter(iris, 0.99)
	filt.AddAllNumericAttributes()
	filt.Build()
	filt.Run(iris)

	// Create a 60-40 training-test split
	trainData, testData := base.InstancesTrainTestSplit(iris, 0.60)

	//
	// First up, use ID3
	//
	tree = trees.NewID3DecisionTree(0.6)
	// (Parameter controls train-prune split.)

	// Train the ID3 tree
	tree.Fit(trainData)

	// Generate predictions
	predictions := tree.Predict(testData)

	// Evaluate
	fmt.Println("ID3 Performance")
	cf := eval.GetConfusionMatrix(testData, predictions)
	fmt.Println(eval.GetSummary(cf))

	//
	// Next up, Random Trees
	//

	// Consider two randomly-chosen attributes
	tree = trees.NewRandomTree(2)
	tree.Fit(testData)
	predictions = tree.Predict(testData)
	fmt.Println("RandomTree Performance")
	cf = eval.GetConfusionMatrix(testData, predictions)
	fmt.Println(eval.GetSummary(cf))

	//
	// Finally, Random Forests
	//
	tree = ensemble.NewRandomForest(100, 3)
	tree.Fit(trainData)
	predictions = tree.Predict(testData)
	fmt.Println("RandomForest Performance")
	cf = eval.GetConfusionMatrix(testData, predictions)
	fmt.Println(eval.GetSummary(cf))
}
filters/chimerge.go

@@ -43,6 +43,21 @@ func (c *ChiMergeFilter) Build() {
 	}
 }
 
+// AddAllNumericAttributes adds every suitable attribute
+// to the ChiMergeFilter for discretisation
+func (b *ChiMergeFilter) AddAllNumericAttributes() {
+	for i := 0; i < b.Instances.Cols; i++ {
+		if i == b.Instances.ClassIndex {
+			continue
+		}
+		attr := b.Instances.GetAttr(i)
+		if attr.GetType() != base.Float64Type {
+			continue
+		}
+		b.Attributes = append(b.Attributes, i)
+	}
+}
+
 // Run discretises the set of Instances `on'
 //
 // IMPORTANT: ChiMergeFilter discretises in place.
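Editor's note: a sketch of the filter lifecycle this new method slots into, assembled from the calls the tests in this commit make; the dataset path is hypothetical and 0.90 is the significance value used in those tests.

    package main

    import (
    	base "github.com/sjwhitworth/golearn/base"
    	filters "github.com/sjwhitworth/golearn/filters"
    )

    func main() {
    	inst, err := base.ParseCSVToInstances("iris_headers.csv", true) // hypothetical path
    	if err != nil {
    		panic(err)
    	}
    	filt := filters.NewChiMergeFilter(inst, 0.90) // significance, as in this commit's tests
    	filt.AddAllNumericAttributes()                // queue every Float64Type attribute
    	filt.Build()                                  // compute the bin boundaries
    	filt.Run(inst)                                // discretise in place
    }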
meta/bagging.go (Normal file, 190 lines)

@@ -0,0 +1,190 @@
package meta

import (
	"fmt"
	base "github.com/sjwhitworth/golearn/base"
	"math/rand"
	"runtime"
	"strings"
	"sync"
)

// BaggedModel trains base.Classifiers on subsets of the original
// Instances and combines the results through voting
type BaggedModel struct {
	base.BaseClassifier
	Models             []base.Classifier
	RandomFeatures     int
	lock               sync.Mutex
	selectedAttributes map[int][]base.Attribute
}

// generateTrainingAttrs selects RandomFeatures number of base.Attributes from
// the provided base.Instances.
func (b *BaggedModel) generateTrainingAttrs(model int, from *base.Instances) []base.Attribute {
	ret := make([]base.Attribute, 0)
	if b.RandomFeatures == 0 {
		for j := 0; j < from.Cols; j++ {
			attr := from.GetAttr(j)
			ret = append(ret, attr)
		}
	} else {
		for {
			if len(ret) >= b.RandomFeatures {
				break
			}
			attrIndex := rand.Intn(from.Cols)
			if attrIndex == from.ClassIndex {
				continue
			}
			attr := from.GetAttr(attrIndex)
			matched := false
			for _, a := range ret {
				if a.Equals(attr) {
					matched = true
					break
				}
			}
			if !matched {
				ret = append(ret, attr)
			}
		}
	}
	ret = append(ret, from.GetClassAttr())
	b.lock.Lock()
	b.selectedAttributes[model] = ret
	b.lock.Unlock()
	return ret
}

// generatePredictionInstances returns a modified version of the
// requested base.Instances with only the base.Attributes selected
// for training the model.
func (b *BaggedModel) generatePredictionInstances(model int, from *base.Instances) *base.Instances {
	selected := b.selectedAttributes[model]
	return from.SelectAttributes(selected)
}

// generateTrainingInstances generates RandomFeatures number of
// attributes and returns a modified version of base.Instances
// for training the model
func (b *BaggedModel) generateTrainingInstances(model int, from *base.Instances) *base.Instances {
	insts := from.SampleWithReplacement(from.Rows)
	selected := b.generateTrainingAttrs(model, from)
	return insts.SelectAttributes(selected)
}

// AddModel adds a base.Classifier to the current model
func (b *BaggedModel) AddModel(m base.Classifier) {
	b.Models = append(b.Models, m)
}

// Fit generates and trains each model on a randomised subset of
// Instances.
func (b *BaggedModel) Fit(from *base.Instances) {
	var wait sync.WaitGroup
	b.selectedAttributes = make(map[int][]base.Attribute)
	for i, m := range b.Models {
		wait.Add(1)
		go func(c base.Classifier, f *base.Instances, model int) {
			l := b.generateTrainingInstances(model, f)
			c.Fit(l)
			wait.Done()
		}(m, from, i)
	}
	wait.Wait()
}

// Predict gathers predictions from all the classifiers
// and outputs the most common (majority) class
//
// IMPORTANT: in the event of a tie, the first class which
// achieved the tie value is output.
func (b *BaggedModel) Predict(from *base.Instances) *base.Instances {
	n := runtime.NumCPU()
	// Channel to receive the results as they come in
	votes := make(chan *base.Instances, n)
	// Count the votes for each class
	voting := make(map[int](map[string]int))

	// Create a goroutine to collect the votes
	var votingwait sync.WaitGroup
	votingwait.Add(1)
	go func() {
		for {
			incoming, ok := <-votes
			if ok {
				// Step through each prediction
				for j := 0; j < incoming.Rows; j++ {
					// Check if we've seen this class before...
					if _, ok := voting[j]; !ok {
						// If we haven't, create an entry
						voting[j] = make(map[string]int)
						// Continue on the current row
						j--
						continue
					}
					voting[j][incoming.GetClass(j)]++
				}
			} else {
				votingwait.Done()
				break
			}
		}
	}()

	// Create workers to process the predictions
	processpipe := make(chan int, n)
	var processwait sync.WaitGroup
	for i := 0; i < n; i++ {
		processwait.Add(1)
		go func() {
			for {
				if i, ok := <-processpipe; ok {
					c := b.Models[i]
					l := b.generatePredictionInstances(i, from)
					votes <- c.Predict(l)
				} else {
					processwait.Done()
					break
				}
			}
		}()
	}

	// Send all the models to the workers for prediction
	for i, _ := range b.Models {
		processpipe <- i
	}
	close(processpipe) // Finished sending models to be predicted
	processwait.Wait() // Predictors all finished processing
	close(votes)       // Close the vote channel and allow it to drain
	votingwait.Wait()  // All the votes are in

	// Generate the overall consensus
	ret := from.GeneratePredictionVector()
	for i := range voting {
		maxClass := ""
		maxCount := 0
		// Find the most popular class
		for c := range voting[i] {
			votes := voting[i][c]
			if votes > maxCount {
				maxClass = c
				maxCount = votes
			}
		}
		ret.SetAttrStr(i, 0, maxClass)
	}
	return ret
}

// String returns a human-readable representation of the
// BaggedModel and everything it contains
func (b *BaggedModel) String() string {
	children := make([]string, 0)
	for i, m := range b.Models {
		children = append(children, fmt.Sprintf("%d: %s", i, m))
	}
	return fmt.Sprintf("BaggedModel(\n%s)", strings.Join(children, "\n\t"))
}
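Editor's note: Predict above is a fan-out/fan-in pipeline: a pool of NumCPU workers drains a channel of model indices, a single collector goroutine tallies the votes, and closing each channel signals the next stage to stop. A stripped-down sketch of the same pattern with hypothetical names (not from the commit):

    package main

    import (
    	"fmt"
    	"runtime"
    	"sync"
    )

    func main() {
    	n := runtime.NumCPU()
    	jobs := make(chan int, n)    // fan-out: work items (model indices in bagging.go)
    	results := make(chan int, n) // fan-in: per-job results (prediction sets)

    	var workers sync.WaitGroup
    	for w := 0; w < n; w++ {
    		workers.Add(1)
    		go func() {
    			defer workers.Done()
    			for j := range jobs {
    				results <- j * j // stand-in for c.Predict(...)
    			}
    		}()
    	}

    	var collector sync.WaitGroup
    	collector.Add(1)
    	total := 0
    	go func() {
    		defer collector.Done()
    		for r := range results {
    			total += r // stand-in for tallying votes
    		}
    	}()

    	for i := 0; i < 10; i++ {
    		jobs <- i
    	}
    	close(jobs)      // finished sending work: workers exit
    	workers.Wait()   // all results sent
    	close(results)   // collector drains and exits
    	collector.Wait() // all votes are in
    	fmt.Println(total)
    }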
meta/bagging_test.go (Normal file, 83 lines)

@@ -0,0 +1,83 @@
package meta

import (
	"fmt"
	base "github.com/sjwhitworth/golearn/base"
	eval "github.com/sjwhitworth/golearn/evaluation"
	filters "github.com/sjwhitworth/golearn/filters"
	trees "github.com/sjwhitworth/golearn/trees"
	"math/rand"
	"testing"
	"time"
)

func BenchmarkBaggingRandomForestFit(testEnv *testing.B) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}

	rand.Seed(time.Now().UnixNano())
	filt := filters.NewChiMergeFilter(inst, 0.90)
	filt.AddAllNumericAttributes()
	filt.Build()
	filt.Run(inst)
	rf := new(BaggedModel)
	for i := 0; i < 10; i++ {
		rf.AddModel(trees.NewRandomTree(2))
	}
	testEnv.ResetTimer()
	for i := 0; i < 20; i++ {
		rf.Fit(inst)
	}
}

func BenchmarkBaggingRandomForestPredict(testEnv *testing.B) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}

	rand.Seed(time.Now().UnixNano())
	filt := filters.NewChiMergeFilter(inst, 0.90)
	filt.AddAllNumericAttributes()
	filt.Build()
	filt.Run(inst)
	rf := new(BaggedModel)
	for i := 0; i < 10; i++ {
		rf.AddModel(trees.NewRandomTree(2))
	}
	rf.Fit(inst)
	testEnv.ResetTimer()
	for i := 0; i < 20; i++ {
		rf.Predict(inst)
	}
}

func TestRandomForest1(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}

	rand.Seed(time.Now().UnixNano())
	trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
	filt := filters.NewChiMergeFilter(inst, 0.90)
	filt.AddAllNumericAttributes()
	filt.Build()
	filt.Run(testData)
	filt.Run(trainData)
	rf := new(BaggedModel)
	for i := 0; i < 10; i++ {
		rf.AddModel(trees.NewRandomTree(2))
	}
	rf.Fit(trainData)
	fmt.Println(rf)
	predictions := rf.Predict(testData)
	fmt.Println(predictions)
	confusionMat := eval.GetConfusionMatrix(testData, predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetMacroPrecision(confusionMat))
	fmt.Println(eval.GetMacroRecall(confusionMat))
	fmt.Println(eval.GetSummary(confusionMat))
}
@@ -1,12 +0,0 @@
-package trees
-
-import base "github.com/sjwhitworth/golearn/base"
-
-type DecisionTree struct {
-	base.BaseEstimator
-}
-
-type Branch struct {
-	LeftBranch  Branch
-	RightBranch Branch
-}
trees/entropy.go (Normal file, 100 lines)

@@ -0,0 +1,100 @@
package trees

import (
	base "github.com/sjwhitworth/golearn/base"
	"math"
)

//
// Information gain rule generator
//

type InformationGainRuleGenerator struct {
}

// GenerateSplitAttribute returns the non-class Attribute which maximises the
// information gain.
//
// IMPORTANT: passing a base.Instances with no Attributes other than the class
// variable will panic()
func (r *InformationGainRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
	allAttributes := make([]int, 0)
	for i := 0; i < f.Cols; i++ {
		if i != f.ClassIndex {
			allAttributes = append(allAttributes, i)
		}
	}
	return r.GetSplitAttributeFromSelection(allAttributes, f)
}

// GetSplitAttributeFromSelection returns the non-class Attribute amongst
// consideredAttributes which maximises the information gain
//
// IMPORTANT: passing a zero-length consideredAttributes parameter will panic()
func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(consideredAttributes []int, f *base.Instances) base.Attribute {

	// Next step is to compute the information gain at this node
	// for each randomly chosen attribute, and pick the one
	// which maximises it
	maxGain := math.Inf(-1)
	selectedAttribute := -1

	// Compute the base entropy
	classDist := f.GetClassDistribution()
	baseEntropy := getBaseEntropy(classDist)

	// Compute the information gain for each attribute
	for _, s := range consideredAttributes {
		proposedClassDist := f.GetClassDistributionAfterSplit(f.GetAttr(s))
		localEntropy := getSplitEntropy(proposedClassDist)
		informationGain := baseEntropy - localEntropy
		if informationGain > maxGain {
			maxGain = informationGain
			selectedAttribute = s
		}
	}

	// Pick the one which maximises IG
	return f.GetAttr(selectedAttribute)
}

//
// Entropy functions
//

// getSplitEntropy determines the entropy of the target
// class distribution after splitting on a base.Attribute
func getSplitEntropy(s map[string]map[string]int) float64 {
	ret := 0.0
	count := 0
	for a := range s {
		for c := range s[a] {
			count += s[a][c]
		}
	}
	for a := range s {
		total := 0.0
		for c := range s[a] {
			total += float64(s[a][c])
		}
		for c := range s[a] {
			ret -= float64(s[a][c]) / float64(count) * math.Log(float64(s[a][c])/float64(count)) / math.Log(2)
		}
		ret += total / float64(count) * math.Log(total/float64(count)) / math.Log(2)
	}
	return ret
}

// getBaseEntropy determines the entropy of the target
// class distribution before splitting on a base.Attribute
func getBaseEntropy(s map[string]int) float64 {
	ret := 0.0
	count := 0
	for k := range s {
		count += s[k]
	}
	for k := range s {
		ret -= float64(s[k]) / float64(count) * math.Log(float64(s[k])/float64(count)) / math.Log(2)
	}
	return ret
}
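Editor's note: a worked check of getSplitEntropy, using the outlook distribution from TestInformationGain further down. The weighted entropy after splitting on outlook is 5/14 * H(2/5, 3/5) + 4/14 * H(1) + 5/14 * H(3/5, 2/5) = 5/14 * 0.971 + 0 + 5/14 * 0.971, which is approximately 0.694, the value the test asserts. A standalone sketch of the same arithmetic (not from the commit):

    package main

    import (
    	"fmt"
    	"math"
    )

    // entropy computes -sum(p * log2(p)) over a slice of class counts.
    func entropy(counts []int, total int) float64 {
    	e := 0.0
    	for _, c := range counts {
    		p := float64(c) / float64(total)
    		e -= p * math.Log2(p)
    	}
    	return e
    }

    func main() {
    	// outlook split from the PlayTennis data: class counts per value
    	sunny := []int{2, 3}  // play, noplay
    	overcast := []int{4}  // play only, entropy 0
    	rain := []int{3, 2}   // play, noplay

    	split := 5.0/14.0*entropy(sunny, 5) +
    		4.0/14.0*entropy(overcast, 4) +
    		5.0/14.0*entropy(rain, 5)
    	fmt.Printf("%.3f\n", split) // prints 0.694
    }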
trees/id3.go (Normal file, 267 lines)

@@ -0,0 +1,267 @@
package trees

import (
	"bytes"
	"fmt"
	base "github.com/sjwhitworth/golearn/base"
	eval "github.com/sjwhitworth/golearn/evaluation"
	"sort"
)

// NodeType determines whether a DecisionTreeNode is a leaf or not
type NodeType int

const (
	// LeafNode means there are no children
	LeafNode NodeType = 1
	// RuleNode means we should look at the next attribute value
	RuleNode NodeType = 2
)

// RuleGenerator implementations analyse instances and determine
// the best value to split on
type RuleGenerator interface {
	GenerateSplitAttribute(*base.Instances) base.Attribute
}

// DecisionTreeNode represents a given portion of a decision tree
type DecisionTreeNode struct {
	Type      NodeType
	Children  map[string]*DecisionTreeNode
	SplitAttr base.Attribute
	ClassDist map[string]int
	Class     string
	ClassAttr *base.Attribute
}

// InferID3Tree builds a decision tree using a RuleGenerator
// from a set of Instances (implements the ID3 algorithm)
func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
	// Count the number of classes at this node
	classes := from.CountClassValues()
	// If there's only one class, return a DecisionTreeLeaf with
	// the only class available
	if len(classes) == 1 {
		maxClass := ""
		for i := range classes {
			maxClass = i
		}
		ret := &DecisionTreeNode{
			LeafNode,
			nil,
			nil,
			classes,
			maxClass,
			from.GetClassAttrPtr(),
		}
		return ret
	}

	// Only have the class attribute
	maxVal := 0
	maxClass := ""
	for i := range classes {
		if classes[i] > maxVal {
			maxClass = i
			maxVal = classes[i]
		}
	}

	// If there are no more Attributes left to split on,
	// return a DecisionTreeLeaf with the majority class
	if from.GetAttributeCount() == 2 {
		ret := &DecisionTreeNode{
			LeafNode,
			nil,
			nil,
			classes,
			maxClass,
			from.GetClassAttrPtr(),
		}
		return ret
	}

	ret := &DecisionTreeNode{
		RuleNode,
		nil,
		nil,
		classes,
		maxClass,
		from.GetClassAttrPtr(),
	}

	// Generate a return structure
	// Generate the splitting attribute
	splitOnAttribute := with.GenerateSplitAttribute(from)
	if splitOnAttribute == nil {
		// Can't determine, just return what we have
		return ret
	}
	// Split the attributes based on this attribute's value
	splitInstances := from.DecomposeOnAttributeValues(splitOnAttribute)
	// Create new children from these attributes
	ret.Children = make(map[string]*DecisionTreeNode)
	for k := range splitInstances {
		newInstances := splitInstances[k]
		ret.Children[k] = InferID3Tree(newInstances, with)
	}
	ret.SplitAttr = splitOnAttribute
	return ret
}

// getNestedString returns the contents of node d
// prefixed by level number of tags (also prints children)
func (d *DecisionTreeNode) getNestedString(level int) string {
	buf := bytes.NewBuffer(nil)
	tmp := bytes.NewBuffer(nil)
	for i := 0; i < level; i++ {
		tmp.WriteString("\t")
	}
	buf.WriteString(tmp.String())
	if d.Children == nil {
		buf.WriteString(fmt.Sprintf("Leaf(%s)", d.Class))
	} else {
		buf.WriteString(fmt.Sprintf("Rule(%s)", d.SplitAttr.GetName()))
		keys := make([]string, 0)
		for k := range d.Children {
			keys = append(keys, k)
		}
		sort.Strings(keys)
		for _, k := range keys {
			buf.WriteString("\n")
			buf.WriteString(tmp.String())
			buf.WriteString("\t")
			buf.WriteString(k)
			buf.WriteString("\n")
			buf.WriteString(d.Children[k].getNestedString(level + 1))
		}
	}
	return buf.String()
}

// String returns a human-readable representation of a given node
// and its children
func (d *DecisionTreeNode) String() string {
	return d.getNestedString(0)
}

// computeAccuracy is a helper method for Prune()
func computeAccuracy(predictions *base.Instances, from *base.Instances) float64 {
	cf := eval.GetConfusionMatrix(from, predictions)
	return eval.GetAccuracy(cf)
}

// Prune eliminates branches which hurt accuracy
func (d *DecisionTreeNode) Prune(using *base.Instances) {
	// If you're a leaf, you're already pruned
	if d.Children == nil {
		return
	} else {
		if d.SplitAttr == nil {
			return
		}
		// Recursively prune children of this node
		sub := using.DecomposeOnAttributeValues(d.SplitAttr)
		for k := range d.Children {
			if sub[k] == nil {
				continue
			}
			d.Children[k].Prune(sub[k])
		}
	}

	// Get a baseline accuracy
	baselineAccuracy := computeAccuracy(d.Predict(using), using)

	// Speculatively remove the children and re-evaluate
	tmpChildren := d.Children
	d.Children = nil
	newAccuracy := computeAccuracy(d.Predict(using), using)

	// Keep the children removed if better, else restore
	if newAccuracy < baselineAccuracy {
		d.Children = tmpChildren
	}
}

// Predict outputs a base.Instances containing predictions from this tree
func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances {
	outputAttrs := make([]base.Attribute, 1)
	outputAttrs[0] = what.GetClassAttr()
	predictions := base.NewInstances(outputAttrs, what.Rows)
	for i := 0; i < what.Rows; i++ {
		cur := d
		for {
			if cur.Children == nil {
				predictions.SetAttrStr(i, 0, cur.Class)
				break
			} else {
				at := cur.SplitAttr
				j := what.GetAttrIndex(at)
				if j == -1 {
					predictions.SetAttrStr(i, 0, cur.Class)
					break
				}
				classVar := at.GetStringFromSysVal(what.Get(i, j))
				if next, ok := cur.Children[classVar]; ok {
					cur = next
				} else {
					var bestChild string
					for c := range cur.Children {
						bestChild = c
						if c > classVar {
							break
						}
					}
					cur = cur.Children[bestChild]
				}
			}
		}
	}
	return predictions
}

//
// ID3 Tree type
//

// ID3DecisionTree represents an ID3-based decision tree
// using the Information Gain metric to select which attributes
// to split on at each node.
type ID3DecisionTree struct {
	base.BaseClassifier
	Root       *DecisionTreeNode
	PruneSplit float64
}

// NewID3DecisionTree returns a new ID3DecisionTree with the specified
// test-prune ratio. If the ratio is less than 0.001, the tree isn't pruned
func NewID3DecisionTree(prune float64) *ID3DecisionTree {
	return &ID3DecisionTree{
		base.BaseClassifier{},
		nil,
		prune,
	}
}

// Fit builds the ID3 decision tree
func (t *ID3DecisionTree) Fit(on *base.Instances) {
	rule := new(InformationGainRuleGenerator)
	if t.PruneSplit > 0.001 {
		trainData, testData := base.InstancesTrainTestSplit(on, t.PruneSplit)
		t.Root = InferID3Tree(trainData, rule)
		t.Root.Prune(testData)
	} else {
		t.Root = InferID3Tree(on, rule)
	}
}

// Predict outputs predictions from the ID3 decision tree
func (t *ID3DecisionTree) Predict(what *base.Instances) *base.Instances {
	return t.Root.Predict(what)
}

// String returns a human-readable version of this ID3 tree
func (t *ID3DecisionTree) String() string {
	return fmt.Sprintf("ID3DecisionTree(%s\n)", t.Root)
}
trees/random.go (Normal file, 88 lines)

@@ -0,0 +1,88 @@
package trees

import (
	"fmt"
	base "github.com/sjwhitworth/golearn/base"
	"math/rand"
)

// RandomTreeRuleGenerator is used to generate decision rules for Random Trees
type RandomTreeRuleGenerator struct {
	Attributes   int
	internalRule InformationGainRuleGenerator
}

// GenerateSplitAttribute returns the best attribute out of those randomly chosen
// which maximises Information Gain
func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {

	// First step is to generate the random attributes that we'll consider
	maximumAttribute := f.GetAttributeCount()
	consideredAttributes := make([]int, r.Attributes)
	attrCounter := 0
	for {
		if len(consideredAttributes) >= r.Attributes {
			break
		}
		selectedAttribute := rand.Intn(maximumAttribute)
		fmt.Println(selectedAttribute, attrCounter, consideredAttributes, len(consideredAttributes))
		if selectedAttribute != f.ClassIndex {
			matched := false
			for _, a := range consideredAttributes {
				if a == selectedAttribute {
					matched = true
					break
				}
			}
			if matched {
				continue
			}
			consideredAttributes = append(consideredAttributes, selectedAttribute)
			attrCounter++
		}
	}

	return r.internalRule.GetSplitAttributeFromSelection(consideredAttributes, f)
}

// RandomTree builds a decision tree by considering a fixed number
// of randomly-chosen attributes at each node
type RandomTree struct {
	base.BaseClassifier
	Root *DecisionTreeNode
	Rule *RandomTreeRuleGenerator
}

// NewRandomTree returns a new RandomTree which considers attrs randomly
// chosen attributes at each node.
func NewRandomTree(attrs int) *RandomTree {
	return &RandomTree{
		base.BaseClassifier{},
		nil,
		&RandomTreeRuleGenerator{
			attrs,
			InformationGainRuleGenerator{},
		},
	}
}

// Fit builds a RandomTree suitable for prediction
func (rt *RandomTree) Fit(from *base.Instances) {
	rt.Root = InferID3Tree(from, rt.Rule)
}

// Predict returns a set of Instances containing predictions
func (rt *RandomTree) Predict(from *base.Instances) *base.Instances {
	return rt.Root.Predict(from)
}

// String returns a human-readable representation of this structure
func (rt *RandomTree) String() string {
	return fmt.Sprintf("RandomTree(%s)", rt.Root)
}

// Prune removes nodes from the tree which are detrimental
// to determining the accuracy of the test set (with)
func (rt *RandomTree) Prune(with *base.Instances) {
	rt.Root.Prune(with)
}
trees/tree_test.go (Normal file, 250 lines)

@@ -0,0 +1,250 @@
package trees

import (
	"fmt"
	base "github.com/sjwhitworth/golearn/base"
	eval "github.com/sjwhitworth/golearn/evaluation"
	filters "github.com/sjwhitworth/golearn/filters"
	"math"
	"testing"
)

func TestRandomTree(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}

	filt := filters.NewChiMergeFilter(inst, 0.90)
	filt.AddAllNumericAttributes()
	filt.Build()
	filt.Run(inst)
	fmt.Println(inst)
	r := new(RandomTreeRuleGenerator)
	r.Attributes = 2
	root := InferID3Tree(inst, r)
	fmt.Println(root)
}

func TestRandomTreeClassification(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}
	trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
	filt := filters.NewChiMergeFilter(inst, 0.90)
	filt.AddAllNumericAttributes()
	filt.Build()
	filt.Run(trainData)
	filt.Run(testData)
	fmt.Println(inst)
	r := new(RandomTreeRuleGenerator)
	r.Attributes = 2
	root := InferID3Tree(trainData, r)
	fmt.Println(root)
	predictions := root.Predict(testData)
	fmt.Println(predictions)
	confusionMat := eval.GetConfusionMatrix(testData, predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetMacroPrecision(confusionMat))
	fmt.Println(eval.GetMacroRecall(confusionMat))
	fmt.Println(eval.GetSummary(confusionMat))
}

func TestRandomTreeClassification2(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}
	trainData, testData := base.InstancesTrainTestSplit(inst, 0.4)
	filt := filters.NewChiMergeFilter(inst, 0.90)
	filt.AddAllNumericAttributes()
	filt.Build()
	fmt.Println(testData)
	filt.Run(testData)
	filt.Run(trainData)
	root := NewRandomTree(2)
	root.Fit(trainData)
	fmt.Println(root)
	predictions := root.Predict(testData)
	fmt.Println(predictions)
	confusionMat := eval.GetConfusionMatrix(testData, predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetMacroPrecision(confusionMat))
	fmt.Println(eval.GetMacroRecall(confusionMat))
	fmt.Println(eval.GetSummary(confusionMat))
}

func TestPruning(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}
	trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
	filt := filters.NewChiMergeFilter(inst, 0.90)
	filt.AddAllNumericAttributes()
	filt.Build()
	fmt.Println(testData)
	filt.Run(testData)
	filt.Run(trainData)
	root := NewRandomTree(2)
	fittrainData, fittestData := base.InstancesTrainTestSplit(trainData, 0.6)
	root.Fit(fittrainData)
	root.Prune(fittestData)
	fmt.Println(root)
	predictions := root.Predict(testData)
	fmt.Println(predictions)
	confusionMat := eval.GetConfusionMatrix(testData, predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetMacroPrecision(confusionMat))
	fmt.Println(eval.GetMacroRecall(confusionMat))
	fmt.Println(eval.GetSummary(confusionMat))
}

func TestInformationGain(testEnv *testing.T) {
	outlook := make(map[string]map[string]int)
	outlook["sunny"] = make(map[string]int)
	outlook["overcast"] = make(map[string]int)
	outlook["rain"] = make(map[string]int)
	outlook["sunny"]["play"] = 2
	outlook["sunny"]["noplay"] = 3
	outlook["overcast"]["play"] = 4
	outlook["rain"]["play"] = 3
	outlook["rain"]["noplay"] = 2

	entropy := getSplitEntropy(outlook)
	if math.Abs(entropy-0.694) > 0.001 {
		testEnv.Error(entropy)
	}
}

func TestID3Inference(testEnv *testing.T) {

	// Import the "PlayTennis" dataset
	inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
	if err != nil {
		panic(err)
	}

	// Build the decision tree
	rule := new(InformationGainRuleGenerator)
	root := InferID3Tree(inst, rule)

	// Verify the tree
	// First attribute should be "outlook"
	if root.SplitAttr.GetName() != "outlook" {
		testEnv.Error(root)
	}
	sunnyChild := root.Children["sunny"]
	overcastChild := root.Children["overcast"]
	rainyChild := root.Children["rainy"]
	if sunnyChild.SplitAttr.GetName() != "humidity" {
		testEnv.Error(sunnyChild)
	}
	if rainyChild.SplitAttr.GetName() != "windy" {
		testEnv.Error(rainyChild)
	}
	if overcastChild.SplitAttr != nil {
		testEnv.Error(overcastChild)
	}

	sunnyLeafHigh := sunnyChild.Children["high"]
	sunnyLeafNormal := sunnyChild.Children["normal"]
	if sunnyLeafHigh.Class != "no" {
		testEnv.Error(sunnyLeafHigh)
	}
	if sunnyLeafNormal.Class != "yes" {
		testEnv.Error(sunnyLeafNormal)
	}

	windyLeafFalse := rainyChild.Children["false"]
	windyLeafTrue := rainyChild.Children["true"]
	if windyLeafFalse.Class != "yes" {
		testEnv.Error(windyLeafFalse)
	}
	if windyLeafTrue.Class != "no" {
		testEnv.Error(windyLeafTrue)
	}

	if overcastChild.Class != "yes" {
		testEnv.Error(overcastChild)
	}
}

func TestID3Classification(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}
	filt := filters.NewBinningFilter(inst, 10)
	filt.AddAllNumericAttributes()
	filt.Build()
	filt.Run(inst)
	fmt.Println(inst)
	trainData, testData := base.InstancesTrainTestSplit(inst, 0.70)
	// Build the decision tree
	rule := new(InformationGainRuleGenerator)
	root := InferID3Tree(trainData, rule)
	fmt.Println(root)
	predictions := root.Predict(testData)
	fmt.Println(predictions)
	confusionMat := eval.GetConfusionMatrix(testData, predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetMacroPrecision(confusionMat))
	fmt.Println(eval.GetMacroRecall(confusionMat))
	fmt.Println(eval.GetSummary(confusionMat))
}

func TestID3(testEnv *testing.T) {

	// Import the "PlayTennis" dataset
	inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
	if err != nil {
		panic(err)
	}

	// Build the decision tree
	tree := NewID3DecisionTree(0.0)
	tree.Fit(inst)
	root := tree.Root

	// Verify the tree
	// First attribute should be "outlook"
	if root.SplitAttr.GetName() != "outlook" {
		testEnv.Error(root)
	}
	sunnyChild := root.Children["sunny"]
	overcastChild := root.Children["overcast"]
	rainyChild := root.Children["rainy"]
	if sunnyChild.SplitAttr.GetName() != "humidity" {
		testEnv.Error(sunnyChild)
	}
	if rainyChild.SplitAttr.GetName() != "windy" {
		testEnv.Error(rainyChild)
	}
	if overcastChild.SplitAttr != nil {
		testEnv.Error(overcastChild)
	}

	sunnyLeafHigh := sunnyChild.Children["high"]
	sunnyLeafNormal := sunnyChild.Children["normal"]
	if sunnyLeafHigh.Class != "no" {
		testEnv.Error(sunnyLeafHigh)
	}
	if sunnyLeafNormal.Class != "yes" {
		testEnv.Error(sunnyLeafNormal)
	}

	windyLeafFalse := rainyChild.Children["false"]
	windyLeafTrue := rainyChild.Children["true"]
	if windyLeafFalse.Class != "yes" {
		testEnv.Error(windyLeafFalse)
	}
	if windyLeafTrue.Class != "no" {
		testEnv.Error(windyLeafTrue)
	}

	if overcastChild.Class != "yes" {
		testEnv.Error(overcastChild)
	}
}
@@ -1,2 +1,26 @@
-// Package trees provides a number of tree based ensemble learners.
-package trees
+/*
+
+	This package implements decision trees.
+
+	ID3DecisionTree:
+		Builds a decision tree using the ID3 algorithm
+		by picking the Attribute which maximises
+		Information Gain at each node.
+
+		Attributes must be CategoricalAttributes at
+		present, so discretise beforehand (see
+		filters)
+
+	RandomTree:
+		Builds a decision tree using the ID3 algorithm
+		by picking the Attribute amongst those
+		randomly selected that maximises Information
+		Gain
+
+		Attributes must be CategoricalAttributes at
+		present, so discretise beforehand (see
+		filters)
+
+*/
+
+package trees