
Merge branch 'master' into feature/naive

Thiago Cardoso 2014-06-07 23:46:07 -03:00
commit 0cf6d258e6
20 changed files with 1259 additions and 27 deletions

.travis.yml

@ -4,6 +4,9 @@ go:
- 1.2
- release
- tip
before_install:
- sudo apt-get update -qq
- sudo apt-get install -qq libatlas-base-dev
install:
- go get github.com/smartystreets/goconvey/convey
- go get -v ./...

README.md

@ -2,7 +2,8 @@ GoLearn
=======
<img src="http://talks.golang.org/2013/advconc/gopherhat.jpg" width=125><br>
- [![GoDoc](https://godoc.org/github.com/sjwhitworth/golearn?status.png)](https://godoc.org/github.com/sjwhitworth/golearn)<br>
+ [![GoDoc](https://godoc.org/github.com/sjwhitworth/golearn?status.png)](https://godoc.org/github.com/sjwhitworth/golearn)
+ [![Build Status](https://travis-ci.org/sjwhitworth/golearn.png?branch=master)](https://travis-ci.org/sjwhitworth/golearn)<br>
GoLearn is a 'batteries included' machine learning library for Go. **Simplicity**, paired with customisability, is the goal.
We are in active development, and would love comments from users out in the wild. Drop us a line on Twitter.

base/attributes.go

@ -219,7 +219,7 @@ func (Attr *CategoricalAttribute) GetSysValFromString(rawVal string) float64 {
// Returns a string containing the list of human-readable values this
// CategoricalAttribute can take.
func (Attr *CategoricalAttribute) String() string {
return fmt.Sprintf("CategoricalAttribute(%s)", Attr.values)
return fmt.Sprintf("CategoricalAttribute(\"%s\", %s)", Attr.Name, Attr.values)
}
// GetStringFromSysVal returns a human-readable value from the given system-representation
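The change threads the attribute's name into its String() output, which disambiguates debug output when several categorical attributes are in play. A self-contained sketch of the before/after format (the name and values here are hypothetical, chosen to mirror the iris class attribute used in the tests below):

```go
package main

import "fmt"

func main() {
	// Hypothetical attribute name and values, mirroring iris.
	name := "Species"
	values := []string{"Iris-setosa", "Iris-versicolor", "Iris-virginica"}
	// Old output: CategoricalAttribute([Iris-setosa Iris-versicolor Iris-virginica])
	fmt.Printf("CategoricalAttribute(%s)\n", values)
	// New output: CategoricalAttribute("Species", [Iris-setosa Iris-versicolor Iris-virginica])
	fmt.Printf("CategoricalAttribute(%q, %s)\n", name, values)
}
```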

base/instances.go

@ -5,8 +5,9 @@ import (
"encoding/binary"
"errors"
"fmt"
"github.com/gonum/matrix/mat64"
"math/rand"
"github.com/gonum/matrix/mat64"
)
// SortDirection specifies sorting direction...
@ -171,10 +172,11 @@ func NewInstancesFromDense(attrs []Attribute, rows int, mat *mat64.Dense) *Insta
//
// IMPORTANT: this function is only meaningful when prop is between 0.0 and 1.0.
// Using any other values may result in odd behaviour.
- func InstancesTrainTestSplit(src *Instances, prop float64) [2]*Instances {
+ func InstancesTrainTestSplit(src *Instances, prop float64) (*Instances, *Instances) {
trainingRows := make([]int, 0)
testingRows := make([]int, 0)
numAttrs := len(src.attributes)
src.Shuffle()
for i := 0; i < src.Rows; i++ {
trainOrTest := rand.Intn(101)
if trainOrTest > int(100*prop) {
@ -196,10 +198,10 @@ func InstancesTrainTestSplit(src *Instances, prop float64) [2]*Instances {
rawTestMatrix.SetRow(i, rowDat)
}
- var ret [2]*Instances
- ret[0] = NewInstancesFromDense(src.attributes, len(trainingRows), rawTrainMatrix)
- ret[1] = NewInstancesFromDense(src.attributes, len(testingRows), rawTestMatrix)
- return ret
+ trainingRet := NewInstancesFromDense(src.attributes, len(trainingRows), rawTrainMatrix)
+ testRet := NewInstancesFromDense(src.attributes, len(testingRows), rawTestMatrix)
+ return trainingRet, testRet
}
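The signature change replaces the two-element array with an idiomatic pair of return values. A minimal sketch of the new call-site shape, assuming the iris CSV that ships in examples/datasets is reachable from the working directory:

```go
package main

import (
	"fmt"

	base "github.com/sjwhitworth/golearn/base"
)

func main() {
	inst, err := base.ParseCSVToInstances("examples/datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}
	// Previously: parts := base.InstancesTrainTestSplit(inst, 0.60)
	// with parts[0]/parts[1]; now the two sets come back directly.
	trainData, testData := base.InstancesTrainTestSplit(inst, 0.60)
	fmt.Println(trainData.Rows, testData.Rows)
}
```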
// CountAttrValues returns the distribution of values of a given
@ -273,6 +275,33 @@ func (inst *Instances) DecomposeOnAttributeValues(at Attribute) map[string]*Instances {
return ret
}
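// GetClassDistributionAfterSplit returns, for each value of the given
// Attribute, a map from class value to row count: the class distribution
// within each branch of a prospective split.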
func (inst *Instances) GetClassDistributionAfterSplit(at Attribute) map[string]map[string]int {
ret := make(map[string]map[string]int)
// Find the attribute we're decomposing on
attrIndex := inst.GetAttrIndex(at)
if attrIndex == -1 {
panic("Invalid attribute index")
}
// Get the class index
classAttr := inst.GetAttr(inst.ClassIndex)
for i := 0; i < inst.Rows; i++ {
splitVar := at.GetStringFromSysVal(inst.Get(i, attrIndex))
classVar := classAttr.GetStringFromSysVal(inst.Get(i, inst.ClassIndex))
if _, ok := ret[splitVar]; !ok {
ret[splitVar] = make(map[string]int)
i--
continue
}
ret[splitVar][classVar]++
}
return ret
}
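For intuition, the result for the tennis dataset added later in this commit, split on outlook with play as the class, has this shape (counts read directly off the CSV; a hand-built sketch, not captured library output):

```go
package main

import "fmt"

func main() {
	// Class counts within each outlook branch of the tennis data.
	dist := map[string]map[string]int{
		"sunny":    {"no": 3, "yes": 2},
		"overcast": {"yes": 4},
		"rainy":    {"yes": 3, "no": 2},
	}
	fmt.Println(dist)
}
```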
// Get returns the system representation (float64) of the value
// stored at the given row and col coordinate.
func (inst *Instances) Get(row int, col int) float64 {
@ -308,6 +337,20 @@ func (inst *Instances) GetClass(row int) string {
return attr.GetStringFromSysVal(val)
}
// GetClassDistribution returns a map containing the count of each
// class value (indexed by the class' string representation)
func (inst *Instances) GetClassDistribution() map[string]int {
ret := make(map[string]int)
attr := inst.GetAttr(inst.ClassIndex)
for i := 0; i < inst.Rows; i++ {
val := inst.Get(i, inst.ClassIndex)
cls := attr.GetStringFromSysVal(val)
ret[cls]++
}
return ret
}
func (Inst *Instances) GetClassAttrPtr() *Attribute {
attr := Inst.GetAttr(Inst.ClassIndex)
return &attr
@ -465,7 +508,7 @@ func (inst *Instances) GeneratePredictionVector() *Instances {
// Shuffle randomizes the row order in place. Drawing j from [0, i]
// makes this a Fisher-Yates shuffle, which yields uniformly distributed
// permutations; the previous bound of rand.Intn(inst.Rows) produced a
// biased shuffle.
func (inst *Instances) Shuffle() {
for i := 0; i < inst.Rows; i++ {
- j := rand.Intn(inst.Rows)
+ j := rand.Intn(i + 1)
inst.swapRows(i, j)
}
}

ensemble/ensemble.go Normal file (13 lines)

@ -0,0 +1,13 @@
/*
Package ensemble contains classifiers which combine other classifiers.
RandomForest:
Generates ForestSize bagged decision trees (currently ID3-based),
each considering a fixed number of random features.
Built on meta.BaggedModel
*/
package ensemble

ensemble/randomforest.go Normal file (50 lines)

@ -0,0 +1,50 @@
package ensemble
import (
"fmt"
base "github.com/sjwhitworth/golearn/base"
meta "github.com/sjwhitworth/golearn/meta"
trees "github.com/sjwhitworth/golearn/trees"
)
// RandomForest classifies instances using an ensemble
// of bagged random decision trees
type RandomForest struct {
base.BaseClassifier
ForestSize int
Features int
Model *meta.BaggedModel
}
// NewRandomForest generates and returns a new random forest.
// forestSize controls the number of trees that get built;
// features controls the number of features used to build each tree.
func NewRandomForest(forestSize int, features int) *RandomForest {
ret := &RandomForest{
base.BaseClassifier{},
forestSize,
features,
nil,
}
return ret
}
// Fit builds the RandomForest on the specified instances
func (f *RandomForest) Fit(on *base.Instances) {
f.Model = new(meta.BaggedModel)
f.Model.RandomFeatures = f.Features
for i := 0; i < f.ForestSize; i++ {
tree := trees.NewID3DecisionTree(0.00)
f.Model.AddModel(tree)
}
f.Model.Fit(on)
}
// Predict generates predictions from a trained RandomForest
func (f *RandomForest) Predict(with *base.Instances) *base.Instances {
return f.Model.Predict(with)
}
func (f *RandomForest) String() string {
return fmt.Sprintf("RandomForest(ForestSize: %d, Features:%d, %s\n)", f.ForestSize, f.Features, f.Model)
}

ensemble/randomforest_test.go Normal file

@ -0,0 +1,29 @@
package ensemble
import (
"fmt"
base "github.com/sjwhitworth/golearn/base"
eval "github.com/sjwhitworth/golearn/evaluation"
filters "github.com/sjwhitworth/golearn/filters"
"testing"
)
func TestRandomForest1(testEnv *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
panic(err)
}
trainData, testData := base.InstancesTrainTestSplit(inst, 0.60)
filt := filters.NewChiMergeFilter(trainData, 0.90)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(testData)
filt.Run(trainData)
rf := NewRandomForest(10, 3)
rf.Fit(trainData)
predictions := rf.Predict(testData)
fmt.Println(predictions)
confusionMat := eval.GetConfusionMatrix(testData, predictions)
fmt.Println(confusionMat)
fmt.Println(eval.GetSummary(confusionMat))
}

examples/datasets/tennis.csv Normal file

@ -0,0 +1,15 @@
outlook,temp,humidity,windy,play
sunny,hot,high,false,no
sunny,hot,high,true,no
overcast,hot,high,false,yes
rainy,mild,high,false,yes
rainy,cool,normal,false,yes
rainy,cool,normal,true,no
overcast,cool,normal,true,yes
sunny,mild,high,false,no
sunny,cool,normal,false,yes
rainy,mild,normal,false,yes
sunny,mild,normal,true,yes
overcast,mild,high,true,yes
overcast,hot,normal,false,yes
rainy,mild,high,true,no

examples (KNN classifier demo)

@ -17,9 +17,7 @@ func main() {
cls := knn.NewKnnClassifier("euclidean", 2)
//Do a training-test split
- trainTest := base.InstancesTrainTestSplit(rawData, 0.50)
- trainData := trainTest[0]
- testData := trainTest[1]
+ trainData, testData := base.InstancesTrainTestSplit(rawData, 0.50)
cls.Fit(trainData)
//Calculates the Euclidean distance and returns the most popular label

examples/trees/trees.go Normal file (75 lines)

@ -0,0 +1,75 @@
// Demonstrates decision tree classification
package main
import (
"fmt"
base "github.com/sjwhitworth/golearn/base"
ensemble "github.com/sjwhitworth/golearn/ensemble"
eval "github.com/sjwhitworth/golearn/evaluation"
filters "github.com/sjwhitworth/golearn/filters"
trees "github.com/sjwhitworth/golearn/trees"
"math/rand"
"time"
)
func main() {
var tree base.Classifier
rand.Seed(time.Now().UTC().UnixNano())
// Load in the iris dataset
iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
if err != nil {
panic(err)
}
// Discretise the iris dataset with Chi-Merge
filt := filters.NewChiMergeFilter(iris, 0.99)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(iris)
// Create a 60-40 training-test split
trainData, testData := base.InstancesTrainTestSplit(iris, 0.60)
//
// First up, use ID3
//
tree = trees.NewID3DecisionTree(0.6)
// (Parameter controls train-prune split.)
// Train the ID3 tree
tree.Fit(trainData)
// Generate predictions
predictions := tree.Predict(testData)
// Evaluate
fmt.Println("ID3 Performance")
cf := eval.GetConfusionMatrix(testData, predictions)
fmt.Println(eval.GetSummary(cf))
//
// Next up, Random Trees
//
// Consider two randomly-chosen attributes
tree = trees.NewRandomTree(2)
tree.Fit(trainData)
predictions = tree.Predict(testData)
fmt.Println("RandomTree Performance")
cf = eval.GetConfusionMatrix(testData, predictions)
fmt.Println(eval.GetSummary(cf))
//
// Finally, Random Forests
//
tree = ensemble.NewRandomForest(100, 3)
tree.Fit(trainData)
predictions = tree.Predict(testData)
fmt.Println("RandomForest Performance")
cf = eval.GetConfusionMatrix(testData, predictions)
fmt.Println(eval.GetSummary(cf))
}

filters/chimerge.go

@ -43,6 +43,21 @@ func (c *ChiMergeFilter) Build() {
}
}
// AddAllNumericAttributes adds every suitable attribute
// to the ChiMergeFilter for discretisation
func (b *ChiMergeFilter) AddAllNumericAttributes() {
for i := 0; i < b.Instances.Cols; i++ {
if i == b.Instances.ClassIndex {
continue
}
attr := b.Instances.GetAttr(i)
if attr.GetType() != base.Float64Type {
continue
}
b.Attributes = append(b.Attributes, i)
}
}
// Run discretises the set of Instances `on'
//
// IMPORTANT: ChiMergeFilter discretises in place.

meta/bagging.go Normal file (190 lines)

@ -0,0 +1,190 @@
package meta
import (
"fmt"
base "github.com/sjwhitworth/golearn/base"
"math/rand"
"runtime"
"strings"
"sync"
)
// BaggedModel trains base.Classifiers on subsets of the original
// Instances and combines the results through voting
type BaggedModel struct {
base.BaseClassifier
Models []base.Classifier
RandomFeatures int
lock sync.Mutex
selectedAttributes map[int][]base.Attribute
}
// generateTrainingAttrs selects RandomFeatures number of base.Attributes from
// the provided base.Instances.
func (b *BaggedModel) generateTrainingAttrs(model int, from *base.Instances) []base.Attribute {
ret := make([]base.Attribute, 0)
if b.RandomFeatures == 0 {
for j := 0; j < from.Cols; j++ {
attr := from.GetAttr(j)
ret = append(ret, attr)
}
} else {
for {
if len(ret) >= b.RandomFeatures {
break
}
attrIndex := rand.Intn(from.Cols)
if attrIndex == from.ClassIndex {
continue
}
attr := from.GetAttr(attrIndex)
matched := false
for _, a := range ret {
if a.Equals(attr) {
matched = true
break
}
}
if !matched {
ret = append(ret, attr)
}
}
}
ret = append(ret, from.GetClassAttr())
b.lock.Lock()
b.selectedAttributes[model] = ret
b.lock.Unlock()
return ret
}
// generatePredictionInstances returns a modified version of the
// requested base.Instances with only the base.Attributes selected
// for training the model.
func (b *BaggedModel) generatePredictionInstances(model int, from *base.Instances) *base.Instances {
selected := b.selectedAttributes[model]
return from.SelectAttributes(selected)
}
// generateTrainingInstances generates RandomFeatures number of
// attributes and returns a modified version of base.Instances
// for training the model
func (b *BaggedModel) generateTrainingInstances(model int, from *base.Instances) *base.Instances {
insts := from.SampleWithReplacement(from.Rows)
selected := b.generateTrainingAttrs(model, from)
return insts.SelectAttributes(selected)
}
// AddModel adds a base.Classifier to the current model
func (b *BaggedModel) AddModel(m base.Classifier) {
b.Models = append(b.Models, m)
}
// Fit generates and trains each model on a randomised subset of
// Instances.
func (b *BaggedModel) Fit(from *base.Instances) {
var wait sync.WaitGroup
b.selectedAttributes = make(map[int][]base.Attribute)
for i, m := range b.Models {
wait.Add(1)
go func(c base.Classifier, f *base.Instances, model int) {
l := b.generateTrainingInstances(model, f)
c.Fit(l)
wait.Done()
}(m, from, i)
}
wait.Wait()
}
// Predict gathers predictions from all the classifiers
// and outputs the most common (majority) class
//
// IMPORTANT: in the event of a tie, the first class which
// achieved the tie value is output.
func (b *BaggedModel) Predict(from *base.Instances) *base.Instances {
n := runtime.NumCPU()
// Channel to receive the results as they come in
votes := make(chan *base.Instances, n)
// Count the votes for each class
voting := make(map[int](map[string]int))
// Create a goroutine to collect the votes
var votingwait sync.WaitGroup
votingwait.Add(1)
go func() {
for {
incoming, ok := <-votes
if ok {
// Step through each prediction
for j := 0; j < incoming.Rows; j++ {
// Check if we've seen this class before...
if _, ok := voting[j]; !ok {
// If we haven't, create an entry
voting[j] = make(map[string]int)
// Continue on the current row
j--
continue
}
voting[j][incoming.GetClass(j)]++
}
} else {
votingwait.Done()
break
}
}
}()
// Create workers to process the predictions
processpipe := make(chan int, n)
var processwait sync.WaitGroup
for i := 0; i < n; i++ {
processwait.Add(1)
go func() {
for {
if i, ok := <-processpipe; ok {
c := b.Models[i]
l := b.generatePredictionInstances(i, from)
votes <- c.Predict(l)
} else {
processwait.Done()
break
}
}
}()
}
// Send all the models to the workers for prediction
for i := range b.Models {
processpipe <- i
}
close(processpipe) // Finished sending models to be predicted
processwait.Wait() // Predictors all finished processing
close(votes) // Close the vote channel and allow it to drain
votingwait.Wait() // All the votes are in
// Generate the overall consensus
ret := from.GeneratePredictionVector()
for i := range voting {
maxClass := ""
maxCount := 0
// Find the most popular class
for c := range voting[i] {
votes := voting[i][c]
if votes > maxCount {
maxClass = c
maxCount = votes
}
}
ret.SetAttrStr(i, 0, maxClass)
}
return ret
}
// String returns a human-readable representation of the
// BaggedModel and everything it contains
func (b *BaggedModel) String() string {
children := make([]string, 0)
for i, m := range b.Models {
children = append(children, fmt.Sprintf("%d: %s", i, m))
}
return fmt.Sprintf("BaggedModel(\n%s)", strings.Join(children, "\n\t"))
}
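For orientation, a minimal sketch of driving BaggedModel directly, mirroring the test below (the CSV path is the iris file the tests use, assumed relative to the repo root; ID3-based trees expect categorical attributes, so real use would discretise first with filters.ChiMergeFilter):

```go
package main

import (
	"fmt"

	base "github.com/sjwhitworth/golearn/base"
	meta "github.com/sjwhitworth/golearn/meta"
	trees "github.com/sjwhitworth/golearn/trees"
)

func main() {
	inst, err := base.ParseCSVToInstances("examples/datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}
	bag := new(meta.BaggedModel)
	bag.RandomFeatures = 2 // each model sees two random attributes plus the class
	for i := 0; i < 10; i++ {
		bag.AddModel(trees.NewRandomTree(2))
	}
	bag.Fit(inst) // trains the ten trees concurrently on bootstrap samples
	fmt.Println(bag.Predict(inst))
}
```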

meta/bagging_test.go Normal file (83 lines)

@ -0,0 +1,83 @@
package meta
import (
"fmt"
base "github.com/sjwhitworth/golearn/base"
eval "github.com/sjwhitworth/golearn/evaluation"
filters "github.com/sjwhitworth/golearn/filters"
trees "github.com/sjwhitworth/golearn/trees"
"math/rand"
"testing"
"time"
)
func BenchmarkBaggingRandomForestFit(testEnv *testing.B) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
panic(err)
}
rand.Seed(time.Now().UnixNano())
filt := filters.NewChiMergeFilter(inst, 0.90)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(inst)
rf := new(BaggedModel)
for i := 0; i < 10; i++ {
rf.AddModel(trees.NewRandomTree(2))
}
testEnv.ResetTimer()
for i := 0; i < 20; i++ {
rf.Fit(inst)
}
}
func BenchmarkBaggingRandomForestPredict(testEnv *testing.B) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
panic(err)
}
rand.Seed(time.Now().UnixNano())
filt := filters.NewChiMergeFilter(inst, 0.90)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(inst)
rf := new(BaggedModel)
for i := 0; i < 10; i++ {
rf.AddModel(trees.NewRandomTree(2))
}
rf.Fit(inst)
testEnv.ResetTimer()
for i := 0; i < 20; i++ {
rf.Predict(inst)
}
}
func TestRandomForest1(testEnv *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
panic(err)
}
rand.Seed(time.Now().UnixNano())
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
filt := filters.NewChiMergeFilter(inst, 0.90)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(testData)
filt.Run(trainData)
rf := new(BaggedModel)
for i := 0; i < 10; i++ {
rf.AddModel(trees.NewRandomTree(2))
}
rf.Fit(trainData)
fmt.Println(rf)
predictions := rf.Predict(testData)
fmt.Println(predictions)
confusionMat := eval.GetConfusionMatrix(testData, predictions)
fmt.Println(confusionMat)
fmt.Println(eval.GetMacroPrecision(confusionMat))
fmt.Println(eval.GetMacroRecall(confusionMat))
fmt.Println(eval.GetSummary(confusionMat))
}

trees (deleted file)

@ -1,12 +0,0 @@
- package trees
- import base "github.com/sjwhitworth/golearn/base"
- type DecisionTree struct {
- base.BaseEstimator
- }
- type Branch struct {
- LeftBranch Branch
- RightBranch Branch
- }

trees/entropy.go Normal file (100 lines)

@ -0,0 +1,100 @@
package trees
import (
base "github.com/sjwhitworth/golearn/base"
"math"
)
//
// Information gain rule generator
//
type InformationGainRuleGenerator struct {
}
// GenerateSplitAttribute returns the non-class Attribute which maximises the
// information gain.
//
// IMPORTANT: passing a base.Instances with no Attributes other than the class
// variable will panic()
func (r *InformationGainRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
allAttributes := make([]int, 0)
for i := 0; i < f.Cols; i++ {
if i != f.ClassIndex {
allAttributes = append(allAttributes, i)
}
}
return r.GetSplitAttributeFromSelection(allAttributes, f)
}
// GetSplitAttributeFromSelection returns the non-class Attribute amongst
// consideredAttributes which maximises the information gain
//
// IMPORTANT: passing a zero-length consideredAttributes parameter will panic()
func (r *InformationGainRuleGenerator) GetSplitAttributeFromSelection(consideredAttributes []int, f *base.Instances) base.Attribute {
// Next step is to compute the information gain at this node
// for each randomly chosen attribute, and pick the one
// which maximises it
maxGain := math.Inf(-1)
selectedAttribute := -1
// Compute the base entropy
classDist := f.GetClassDistribution()
baseEntropy := getBaseEntropy(classDist)
// Compute the information gain for each attribute
for _, s := range consideredAttributes {
proposedClassDist := f.GetClassDistributionAfterSplit(f.GetAttr(s))
localEntropy := getSplitEntropy(proposedClassDist)
informationGain := baseEntropy - localEntropy
if informationGain > maxGain {
maxGain = informationGain
selectedAttribute = s
}
}
// Pick the one which maximises IG
return f.GetAttr(selectedAttribute)
}
//
// Entropy functions
//
// getSplitEntropy determines the entropy of the target
// class distribution after splitting on an base.Attribute
func getSplitEntropy(s map[string]map[string]int) float64 {
ret := 0.0
count := 0
for a := range s {
for c := range s[a] {
count += s[a][c]
}
}
for a := range s {
total := 0.0
for c := range s[a] {
total += float64(s[a][c])
}
for c := range s[a] {
ret -= float64(s[a][c]) / float64(count) * math.Log(float64(s[a][c])/float64(count)) / math.Log(2)
}
ret += total / float64(count) * math.Log(total/float64(count)) / math.Log(2)
}
return ret
}
// getBaseEntropy determines the entropy of the target
// class distribution before splitting on an base.Attribute
func getBaseEntropy(s map[string]int) float64 {
ret := 0.0
count := 0
for k := range s {
count += s[k]
}
for k := range s {
ret -= float64(s[k]) / float64(count) * math.Log(float64(s[k])/float64(count)) / math.Log(2)
}
return ret
}
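As a worked check (it matches TestInformationGain in trees/tree_test.go below), consider splitting the 14-row tennis dataset on outlook: 5 sunny rows (2 yes / 3 no), 4 overcast rows (all yes) and 5 rainy rows (3 yes / 2 no). getSplitEntropy computes the count-weighted sum of per-branch entropies:

H_split(outlook) = (5/14)*H(2/5, 3/5) + (4/14)*H(1, 0) + (5/14)*H(3/5, 2/5)
                 = (10/14)*(0.971) + 0
                 ≈ 0.694 bits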

trees/id3.go Normal file (267 lines)

@ -0,0 +1,267 @@
package trees
import (
"bytes"
"fmt"
base "github.com/sjwhitworth/golearn/base"
eval "github.com/sjwhitworth/golearn/evaluation"
"sort"
)
// NodeType determines whether a DecisionTreeNode is a leaf or not
type NodeType int
const (
// LeafNode means there are no children
LeafNode NodeType = 1
// RuleNode means we should look at the next attribute value
RuleNode NodeType = 2
)
// RuleGenerator implementations analyse instances and determine
// the best value to split on
type RuleGenerator interface {
GenerateSplitAttribute(*base.Instances) base.Attribute
}
// DecisionTreeNode represents a given portion of a decision tree
type DecisionTreeNode struct {
Type NodeType
Children map[string]*DecisionTreeNode
SplitAttr base.Attribute
ClassDist map[string]int
Class string
ClassAttr *base.Attribute
}
// InferID3Tree builds a decision tree using a RuleGenerator
// from a set of Instances (implements the ID3 algorithm)
func InferID3Tree(from *base.Instances, with RuleGenerator) *DecisionTreeNode {
// Count the number of classes at this node
classes := from.CountClassValues()
// If there's only one class, return a DecisionTreeLeaf with
// the only class available
if len(classes) == 1 {
maxClass := ""
for i := range classes {
maxClass = i
}
ret := &DecisionTreeNode{
LeafNode,
nil,
nil,
classes,
maxClass,
from.GetClassAttrPtr(),
}
return ret
}
// Only have the class attribute
maxVal := 0
maxClass := ""
for i := range classes {
if classes[i] > maxVal {
maxClass = i
maxVal = classes[i]
}
}
// If there are no more Attributes left to split on,
// return a DecisionTreeLeaf with the majority class
if from.GetAttributeCount() == 2 {
ret := &DecisionTreeNode{
LeafNode,
nil,
nil,
classes,
maxClass,
from.GetClassAttrPtr(),
}
return ret
}
ret := &DecisionTreeNode{
RuleNode,
nil,
nil,
classes,
maxClass,
from.GetClassAttrPtr(),
}
// Generate a return structure
// Generate the splitting attribute
splitOnAttribute := with.GenerateSplitAttribute(from)
if splitOnAttribute == nil {
// Can't determine, just return what we have
return ret
}
// Split the attributes based on this attribute's value
splitInstances := from.DecomposeOnAttributeValues(splitOnAttribute)
// Create new children from these attributes
ret.Children = make(map[string]*DecisionTreeNode)
for k := range splitInstances {
newInstances := splitInstances[k]
ret.Children[k] = InferID3Tree(newInstances, with)
}
ret.SplitAttr = splitOnAttribute
return ret
}
// getNestedString returns the contents of node d
// prefixed by level tab characters (also prints children)
func (d *DecisionTreeNode) getNestedString(level int) string {
buf := bytes.NewBuffer(nil)
tmp := bytes.NewBuffer(nil)
for i := 0; i < level; i++ {
tmp.WriteString("\t")
}
buf.WriteString(tmp.String())
if d.Children == nil {
buf.WriteString(fmt.Sprintf("Leaf(%s)", d.Class))
} else {
buf.WriteString(fmt.Sprintf("Rule(%s)", d.SplitAttr.GetName()))
keys := make([]string, 0)
for k := range d.Children {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
buf.WriteString("\n")
buf.WriteString(tmp.String())
buf.WriteString("\t")
buf.WriteString(k)
buf.WriteString("\n")
buf.WriteString(d.Children[k].getNestedString(level + 1))
}
}
return buf.String()
}
// String returns a human-readable representation of a given node
// and its children
func (d *DecisionTreeNode) String() string {
return d.getNestedString(0)
}
// computeAccuracy is a helper method for Prune()
func computeAccuracy(predictions *base.Instances, from *base.Instances) float64 {
cf := eval.GetConfusionMatrix(from, predictions)
return eval.GetAccuracy(cf)
}
// Prune eliminates branches which hurt accuracy
func (d *DecisionTreeNode) Prune(using *base.Instances) {
// If you're a leaf, you're already pruned
if d.Children == nil {
return
} else {
if d.SplitAttr == nil {
return
}
// Recursively prune children of this node
sub := using.DecomposeOnAttributeValues(d.SplitAttr)
for k := range d.Children {
if sub[k] == nil {
continue
}
d.Children[k].Prune(sub[k])
}
}
// Get a baseline accuracy
baselineAccuracy := computeAccuracy(d.Predict(using), using)
// Speculatively remove the children and re-evaluate
tmpChildren := d.Children
d.Children = nil
newAccuracy := computeAccuracy(d.Predict(using), using)
// Keep the children removed if better, else restore
if newAccuracy < baselineAccuracy {
d.Children = tmpChildren
}
}
// Predict outputs a base.Instances containing predictions from this tree
func (d *DecisionTreeNode) Predict(what *base.Instances) *base.Instances {
outputAttrs := make([]base.Attribute, 1)
outputAttrs[0] = what.GetClassAttr()
predictions := base.NewInstances(outputAttrs, what.Rows)
for i := 0; i < what.Rows; i++ {
cur := d
for {
if cur.Children == nil {
predictions.SetAttrStr(i, 0, cur.Class)
break
} else {
at := cur.SplitAttr
j := what.GetAttrIndex(at)
if j == -1 {
predictions.SetAttrStr(i, 0, cur.Class)
break
}
classVar := at.GetStringFromSysVal(what.Get(i, j))
if next, ok := cur.Children[classVar]; ok {
cur = next
} else {
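// This attribute value was never seen during training, so fall back
// to one of the existing branches. (Go randomises map iteration order,
// so the fallback branch chosen here is not deterministic.)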
var bestChild string
for c := range cur.Children {
bestChild = c
if c > classVar {
break
}
}
cur = cur.Children[bestChild]
}
}
}
}
return predictions
}
//
// ID3 Tree type
//
// ID3DecisionTree represents an ID3-based decision tree
// using the Information Gain metric to select which attributes
// to split on at each node.
type ID3DecisionTree struct {
base.BaseClassifier
Root *DecisionTreeNode
PruneSplit float64
}
// NewID3DecisionTree returns a new ID3DecisionTree with the specified test-prune
// ratio. If the ratio is less than 0.001, the tree isn't pruned.
func NewID3DecisionTree(prune float64) *ID3DecisionTree {
return &ID3DecisionTree{
base.BaseClassifier{},
nil,
prune,
}
}
// Fit builds the ID3 decision tree
func (t *ID3DecisionTree) Fit(on *base.Instances) {
rule := new(InformationGainRuleGenerator)
if t.PruneSplit > 0.001 {
trainData, testData := base.InstancesTrainTestSplit(on, t.PruneSplit)
t.Root = InferID3Tree(trainData, rule)
t.Root.Prune(testData)
} else {
t.Root = InferID3Tree(on, rule)
}
}
// Predict outputs predictions from the ID3 decision tree
func (t *ID3DecisionTree) Predict(what *base.Instances) *base.Instances {
return t.Root.Predict(what)
}
// String returns a human-readable version of this ID3 tree
func (t *ID3DecisionTree) String() string {
return fmt.Sprintf("ID3DecisionTree(%s\n)", t.Root)
}

trees/random.go Normal file (88 lines)

@ -0,0 +1,88 @@
package trees
import (
"fmt"
base "github.com/sjwhitworth/golearn/base"
"math/rand"
)
// RandomTreeRuleGenerator is used to generate decision rules for Random Trees
type RandomTreeRuleGenerator struct {
Attributes int
internalRule InformationGainRuleGenerator
}
// GenerateSplitAttribute returns the best attribute out of those randomly chosen
// which maximises Information Gain
func (r *RandomTreeRuleGenerator) GenerateSplitAttribute(f *base.Instances) base.Attribute {
// First step is to generate the random attributes that we'll consider
maximumAttribute := f.GetAttributeCount()
// Start with an empty selection; attributes are appended as they
// pass the class-index and duplicate checks below.
consideredAttributes := make([]int, 0, r.Attributes)
attrCounter := 0
for {
if len(consideredAttributes) >= r.Attributes {
break
}
selectedAttribute := rand.Intn(maximumAttribute)
if selectedAttribute != f.ClassIndex {
matched := false
for _, a := range consideredAttributes {
if a == selectedAttribute {
matched = true
break
}
}
if matched {
continue
}
consideredAttributes = append(consideredAttributes, selectedAttribute)
attrCounter++
}
}
return r.internalRule.GetSplitAttributeFromSelection(consideredAttributes, f)
}
// RandomTree builds a decision tree by considering a fixed number
// of randomly-chosen attributes at each node
type RandomTree struct {
base.BaseClassifier
Root *DecisionTreeNode
Rule *RandomTreeRuleGenerator
}
// NewRandomTree returns a new RandomTree which considers attrs randomly
// chosen attributes at each node.
func NewRandomTree(attrs int) *RandomTree {
return &RandomTree{
base.BaseClassifier{},
nil,
&RandomTreeRuleGenerator{
attrs,
InformationGainRuleGenerator{},
},
}
}
// Fit builds a RandomTree suitable for prediction
func (rt *RandomTree) Fit(from *base.Instances) {
rt.Root = InferID3Tree(from, rt.Rule)
}
// Predict returns a set of Instances containing predictions
func (rt *RandomTree) Predict(from *base.Instances) *base.Instances {
return rt.Root.Predict(from)
}
// String returns a human-readable representation of this structure
func (rt *RandomTree) String() string {
return fmt.Sprintf("RandomTree(%s)", rt.Root)
}
// Prune removes branches which hurt classification
// accuracy on the validation set (with)
func (rt *RandomTree) Prune(with *base.Instances) {
rt.Root.Prune(with)
}

trees/tree_test.go Normal file (250 lines)

@ -0,0 +1,250 @@
package trees
import (
"fmt"
base "github.com/sjwhitworth/golearn/base"
eval "github.com/sjwhitworth/golearn/evaluation"
filters "github.com/sjwhitworth/golearn/filters"
"math"
"testing"
)
func TestRandomTree(testEnv *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
panic(err)
}
filt := filters.NewChiMergeFilter(inst, 0.90)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(inst)
fmt.Println(inst)
r := new(RandomTreeRuleGenerator)
r.Attributes = 2
root := InferID3Tree(inst, r)
fmt.Println(root)
}
func TestRandomTreeClassification(testEnv *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
panic(err)
}
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
filt := filters.NewChiMergeFilter(inst, 0.90)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(trainData)
filt.Run(testData)
fmt.Println(inst)
r := new(RandomTreeRuleGenerator)
r.Attributes = 2
root := InferID3Tree(trainData, r)
fmt.Println(root)
predictions := root.Predict(testData)
fmt.Println(predictions)
confusionMat := eval.GetConfusionMatrix(testData, predictions)
fmt.Println(confusionMat)
fmt.Println(eval.GetMacroPrecision(confusionMat))
fmt.Println(eval.GetMacroRecall(confusionMat))
fmt.Println(eval.GetSummary(confusionMat))
}
func TestRandomTreeClassification2(testEnv *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
panic(err)
}
trainData, testData := base.InstancesTrainTestSplit(inst, 0.4)
filt := filters.NewChiMergeFilter(inst, 0.90)
filt.AddAllNumericAttributes()
filt.Build()
fmt.Println(testData)
filt.Run(testData)
filt.Run(trainData)
root := NewRandomTree(2)
root.Fit(trainData)
fmt.Println(root)
predictions := root.Predict(testData)
fmt.Println(predictions)
confusionMat := eval.GetConfusionMatrix(testData, predictions)
fmt.Println(confusionMat)
fmt.Println(eval.GetMacroPrecision(confusionMat))
fmt.Println(eval.GetMacroRecall(confusionMat))
fmt.Println(eval.GetSummary(confusionMat))
}
func TestPruning(testEnv *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
panic(err)
}
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
filt := filters.NewChiMergeFilter(inst, 0.90)
filt.AddAllNumericAttributes()
filt.Build()
fmt.Println(testData)
filt.Run(testData)
filt.Run(trainData)
root := NewRandomTree(2)
fittrainData, fittestData := base.InstancesTrainTestSplit(trainData, 0.6)
root.Fit(fittrainData)
root.Prune(fittestData)
fmt.Println(root)
predictions := root.Predict(testData)
fmt.Println(predictions)
confusionMat := eval.GetConfusionMatrix(testData, predictions)
fmt.Println(confusionMat)
fmt.Println(eval.GetMacroPrecision(confusionMat))
fmt.Println(eval.GetMacroRecall(confusionMat))
fmt.Println(eval.GetSummary(confusionMat))
}
func TestInformationGain(testEnv *testing.T) {
outlook := make(map[string]map[string]int)
outlook["sunny"] = make(map[string]int)
outlook["overcast"] = make(map[string]int)
outlook["rain"] = make(map[string]int)
outlook["sunny"]["play"] = 2
outlook["sunny"]["noplay"] = 3
outlook["overcast"]["play"] = 4
outlook["rain"]["play"] = 3
outlook["rain"]["noplay"] = 2
entropy := getSplitEntropy(outlook)
if math.Abs(entropy-0.694) > 0.001 {
testEnv.Error(entropy)
}
}
func TestID3Inference(testEnv *testing.T) {
// Import the "PlayTennis" dataset
inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
if err != nil {
panic(err)
}
// Build the decision tree
rule := new(InformationGainRuleGenerator)
root := InferID3Tree(inst, rule)
// Verify the tree
// First attribute should be "outlook"
if root.SplitAttr.GetName() != "outlook" {
testEnv.Error(root)
}
sunnyChild := root.Children["sunny"]
overcastChild := root.Children["overcast"]
rainyChild := root.Children["rainy"]
if sunnyChild.SplitAttr.GetName() != "humidity" {
testEnv.Error(sunnyChild)
}
if rainyChild.SplitAttr.GetName() != "windy" {
testEnv.Error(rainyChild)
}
if overcastChild.SplitAttr != nil {
testEnv.Error(overcastChild)
}
sunnyLeafHigh := sunnyChild.Children["high"]
sunnyLeafNormal := sunnyChild.Children["normal"]
if sunnyLeafHigh.Class != "no" {
testEnv.Error(sunnyLeafHigh)
}
if sunnyLeafNormal.Class != "yes" {
testEnv.Error(sunnyLeafNormal)
}
windyLeafFalse := rainyChild.Children["false"]
windyLeafTrue := rainyChild.Children["true"]
if windyLeafFalse.Class != "yes" {
testEnv.Error(windyLeafFalse)
}
if windyLeafTrue.Class != "no" {
testEnv.Error(windyLeafTrue)
}
if overcastChild.Class != "yes" {
testEnv.Error(overcastChild)
}
}
func TestID3Classification(testEnv *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
panic(err)
}
filt := filters.NewBinningFilter(inst, 10)
filt.AddAllNumericAttributes()
filt.Build()
filt.Run(inst)
fmt.Println(inst)
trainData, testData := base.InstancesTrainTestSplit(inst, 0.70)
// Build the decision tree
rule := new(InformationGainRuleGenerator)
root := InferID3Tree(trainData, rule)
fmt.Println(root)
predictions := root.Predict(testData)
fmt.Println(predictions)
confusionMat := eval.GetConfusionMatrix(testData, predictions)
fmt.Println(confusionMat)
fmt.Println(eval.GetMacroPrecision(confusionMat))
fmt.Println(eval.GetMacroRecall(confusionMat))
fmt.Println(eval.GetSummary(confusionMat))
}
func TestID3(testEnv *testing.T) {
// Import the "PlayTennis" dataset
inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
if err != nil {
panic(err)
}
// Build the decision tree
tree := NewID3DecisionTree(0.0)
tree.Fit(inst)
root := tree.Root
// Verify the tree
// First attribute should be "outlook"
if root.SplitAttr.GetName() != "outlook" {
testEnv.Error(root)
}
sunnyChild := root.Children["sunny"]
overcastChild := root.Children["overcast"]
rainyChild := root.Children["rainy"]
if sunnyChild.SplitAttr.GetName() != "humidity" {
testEnv.Error(sunnyChild)
}
if rainyChild.SplitAttr.GetName() != "windy" {
testEnv.Error(rainyChild)
}
if overcastChild.SplitAttr != nil {
testEnv.Error(overcastChild)
}
sunnyLeafHigh := sunnyChild.Children["high"]
sunnyLeafNormal := sunnyChild.Children["normal"]
if sunnyLeafHigh.Class != "no" {
testEnv.Error(sunnyLeafHigh)
}
if sunnyLeafNormal.Class != "yes" {
testEnv.Error(sunnyLeafNormal)
}
windyLeafFalse := rainyChild.Children["false"]
windyLeafTrue := rainyChild.Children["true"]
if windyLeafFalse.Class != "yes" {
testEnv.Error(windyLeafFalse)
}
if windyLeafTrue.Class != "no" {
testEnv.Error(windyLeafTrue)
}
if overcastChild.Class != "yes" {
testEnv.Error(overcastChild)
}
}

trees/trees.go

@ -1,2 +1,26 @@
- // Package trees provides a number of tree based ensemble learners.
/*
This package implements decision trees.
ID3DecisionTree:
Builds a decision tree using the ID3 algorithm
by picking the Attribute which maximises
Information Gain at each node.
Attributes must be CategoricalAttributes at
present, so discretise beforehand (see
filters)
RandomTree:
Builds a decision tree using the ID3 algorithm
by picking the Attribute amongst those
randomly selected that maximises Information
Gain
Attributes must be CategoricalAttributes at
present, so discretise beforehand (see
filters)
*/
package trees