2014-05-14 14:00:22 +01:00
|
|
|
package trees
|
|
|
|
|
|
|
|
import (
|
2014-08-22 07:21:24 +00:00
|
|
|
"github.com/sjwhitworth/golearn/base"
|
2014-08-22 09:33:42 +00:00
|
|
|
"github.com/sjwhitworth/golearn/evaluation"
|
2014-08-22 07:21:24 +00:00
|
|
|
"github.com/sjwhitworth/golearn/filters"
|
2014-05-14 14:00:22 +01:00
|
|
|
"math"
|
|
|
|
"testing"
|
|
|
|
)
|
|
|
|
|
2014-08-22 08:13:19 +00:00
|
|
|
func TestRandomTree(t *testing.T) {
|
2014-05-14 14:00:22 +01:00
|
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
|
|
|
if err != nil {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
2014-05-14 14:00:22 +01:00
|
|
|
}
|
2014-08-22 07:58:01 +00:00
|
|
|
|
2014-05-17 16:20:56 +01:00
|
|
|
filt := filters.NewChiMergeFilter(inst, 0.90)
|
2014-08-02 16:22:15 +01:00
|
|
|
for _, a := range base.NonClassFloatAttributes(inst) {
|
|
|
|
filt.AddAttribute(a)
|
|
|
|
}
|
|
|
|
filt.Train()
|
|
|
|
instf := base.NewLazilyFilteredInstances(inst, filt)
|
|
|
|
|
2014-05-14 14:00:22 +01:00
|
|
|
r := new(RandomTreeRuleGenerator)
|
|
|
|
r.Attributes = 2
|
2014-08-22 07:58:01 +00:00
|
|
|
|
|
|
|
_ = InferID3Tree(instf, r)
|
2014-05-14 14:00:22 +01:00
|
|
|
}
|
|
|
|
|
2014-08-22 08:13:19 +00:00
|
|
|
func TestRandomTreeClassification(t *testing.T) {
|
2014-05-14 14:00:22 +01:00
|
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
|
|
|
if err != nil {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
2014-05-14 14:00:22 +01:00
|
|
|
}
|
2014-06-06 20:30:24 +02:00
|
|
|
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
|
2014-08-22 07:58:01 +00:00
|
|
|
|
2014-05-17 16:20:56 +01:00
|
|
|
filt := filters.NewChiMergeFilter(inst, 0.90)
|
2014-08-02 16:22:15 +01:00
|
|
|
for _, a := range base.NonClassFloatAttributes(inst) {
|
|
|
|
filt.AddAttribute(a)
|
|
|
|
}
|
|
|
|
filt.Train()
|
|
|
|
trainDataF := base.NewLazilyFilteredInstances(trainData, filt)
|
|
|
|
testDataF := base.NewLazilyFilteredInstances(testData, filt)
|
|
|
|
|
2014-05-14 14:00:22 +01:00
|
|
|
r := new(RandomTreeRuleGenerator)
|
|
|
|
r.Attributes = 2
|
2014-08-22 07:58:01 +00:00
|
|
|
|
2014-08-02 16:22:15 +01:00
|
|
|
root := InferID3Tree(trainDataF, r)
|
2014-08-20 07:16:11 +00:00
|
|
|
predictions, err := root.Predict(testDataF)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Predicting failed: %s", err.Error())
|
|
|
|
}
|
2014-08-22 07:58:01 +00:00
|
|
|
|
2014-08-22 09:33:42 +00:00
|
|
|
confusionMat, err := evaluation.GetConfusionMatrix(testDataF, predictions)
|
2014-08-22 08:52:37 +00:00
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Unable to get confusion matrix: %s", err.Error())
|
|
|
|
}
|
2014-08-22 09:33:42 +00:00
|
|
|
_ = evaluation.GetSummary(confusionMat)
|
2014-05-14 14:00:22 +01:00
|
|
|
}
|
|
|
|
|
2014-08-22 08:13:19 +00:00
|
|
|
func TestRandomTreeClassification2(t *testing.T) {
|
2014-05-14 14:00:22 +01:00
|
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
|
|
|
if err != nil {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
2014-05-14 14:00:22 +01:00
|
|
|
}
|
2014-06-06 20:30:24 +02:00
|
|
|
trainData, testData := base.InstancesTrainTestSplit(inst, 0.4)
|
2014-08-22 07:58:01 +00:00
|
|
|
|
2014-05-17 16:20:56 +01:00
|
|
|
filt := filters.NewChiMergeFilter(inst, 0.90)
|
2014-08-02 16:22:15 +01:00
|
|
|
for _, a := range base.NonClassFloatAttributes(inst) {
|
|
|
|
filt.AddAttribute(a)
|
|
|
|
}
|
|
|
|
filt.Train()
|
|
|
|
trainDataF := base.NewLazilyFilteredInstances(trainData, filt)
|
|
|
|
testDataF := base.NewLazilyFilteredInstances(testData, filt)
|
|
|
|
|
2014-05-14 14:00:22 +01:00
|
|
|
root := NewRandomTree(2)
|
2014-08-20 07:16:11 +00:00
|
|
|
err = root.Fit(trainDataF)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Fitting failed: %s", err.Error())
|
|
|
|
}
|
|
|
|
|
|
|
|
predictions, err := root.Predict(testDataF)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Predicting failed: %s", err.Error())
|
|
|
|
}
|
2014-08-22 07:58:01 +00:00
|
|
|
|
2014-08-22 09:33:42 +00:00
|
|
|
confusionMat, err := evaluation.GetConfusionMatrix(testDataF, predictions)
|
2014-08-22 08:52:37 +00:00
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Unable to get confusion matrix: %s", err.Error())
|
|
|
|
}
|
2014-08-22 09:33:42 +00:00
|
|
|
_ = evaluation.GetSummary(confusionMat)
|
2014-05-17 18:06:01 +01:00
|
|
|
}
|
|
|
|
|
2014-08-22 08:13:19 +00:00
|
|
|
func TestPruning(t *testing.T) {
|
2014-05-17 18:06:01 +01:00
|
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
|
|
|
if err != nil {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
2014-05-17 18:06:01 +01:00
|
|
|
}
|
2014-06-06 20:30:24 +02:00
|
|
|
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
|
2014-08-22 07:58:01 +00:00
|
|
|
|
2014-05-17 18:06:01 +01:00
|
|
|
filt := filters.NewChiMergeFilter(inst, 0.90)
|
2014-08-02 16:22:15 +01:00
|
|
|
for _, a := range base.NonClassFloatAttributes(inst) {
|
|
|
|
filt.AddAttribute(a)
|
|
|
|
}
|
|
|
|
filt.Train()
|
|
|
|
trainDataF := base.NewLazilyFilteredInstances(trainData, filt)
|
|
|
|
testDataF := base.NewLazilyFilteredInstances(testData, filt)
|
|
|
|
|
2014-05-17 18:06:01 +01:00
|
|
|
root := NewRandomTree(2)
|
2014-08-02 16:22:15 +01:00
|
|
|
fittrainData, fittestData := base.InstancesTrainTestSplit(trainDataF, 0.6)
|
2014-08-20 07:16:11 +00:00
|
|
|
|
|
|
|
err = root.Fit(fittrainData)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Fitting failed: %s", err.Error())
|
|
|
|
}
|
|
|
|
|
2014-06-06 20:30:24 +02:00
|
|
|
root.Prune(fittestData)
|
2014-08-20 07:16:11 +00:00
|
|
|
predictions, err := root.Predict(testDataF)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Predicting failed: %s", err.Error())
|
|
|
|
}
|
2014-08-22 07:58:01 +00:00
|
|
|
|
2014-08-22 09:33:42 +00:00
|
|
|
confusionMat, err := evaluation.GetConfusionMatrix(testDataF, predictions)
|
2014-08-22 08:52:37 +00:00
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Unable to get confusion matrix: %s", err.Error())
|
|
|
|
}
|
2014-08-22 09:33:42 +00:00
|
|
|
_ = evaluation.GetSummary(confusionMat)
|
2014-05-14 14:00:22 +01:00
|
|
|
}
|
|
|
|
|
2014-08-22 08:13:19 +00:00
|
|
|
func TestInformationGain(t *testing.T) {
|
2014-05-14 14:00:22 +01:00
|
|
|
outlook := make(map[string]map[string]int)
|
|
|
|
outlook["sunny"] = make(map[string]int)
|
|
|
|
outlook["overcast"] = make(map[string]int)
|
|
|
|
outlook["rain"] = make(map[string]int)
|
|
|
|
outlook["sunny"]["play"] = 2
|
|
|
|
outlook["sunny"]["noplay"] = 3
|
|
|
|
outlook["overcast"]["play"] = 4
|
|
|
|
outlook["rain"]["play"] = 3
|
|
|
|
outlook["rain"]["noplay"] = 2
|
|
|
|
|
|
|
|
entropy := getSplitEntropy(outlook)
|
|
|
|
if math.Abs(entropy-0.694) > 0.001 {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Error(entropy)
|
2014-05-14 14:00:22 +01:00
|
|
|
}
|
|
|
|
}
|
2014-05-17 17:28:51 +01:00
|
|
|
|
2014-08-22 08:13:19 +00:00
|
|
|
func TestID3Inference(t *testing.T) {
|
2014-05-19 12:42:03 +01:00
|
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
|
2014-05-17 17:28:51 +01:00
|
|
|
if err != nil {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
2014-05-17 17:28:51 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
// Build the decision tree
|
|
|
|
rule := new(InformationGainRuleGenerator)
|
|
|
|
root := InferID3Tree(inst, rule)
|
|
|
|
|
|
|
|
// Verify the tree
|
|
|
|
// First attribute should be "outlook"
|
|
|
|
if root.SplitAttr.GetName() != "outlook" {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Error(root)
|
2014-05-17 17:28:51 +01:00
|
|
|
}
|
|
|
|
sunnyChild := root.Children["sunny"]
|
|
|
|
overcastChild := root.Children["overcast"]
|
|
|
|
rainyChild := root.Children["rainy"]
|
|
|
|
if sunnyChild.SplitAttr.GetName() != "humidity" {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Error(sunnyChild)
|
2014-05-17 17:28:51 +01:00
|
|
|
}
|
|
|
|
if rainyChild.SplitAttr.GetName() != "windy" {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Error(rainyChild)
|
2014-05-17 17:28:51 +01:00
|
|
|
}
|
|
|
|
if overcastChild.SplitAttr != nil {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Error(overcastChild)
|
2014-05-17 17:28:51 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
sunnyLeafHigh := sunnyChild.Children["high"]
|
|
|
|
sunnyLeafNormal := sunnyChild.Children["normal"]
|
|
|
|
if sunnyLeafHigh.Class != "no" {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Error(sunnyLeafHigh)
|
2014-05-17 17:28:51 +01:00
|
|
|
}
|
|
|
|
if sunnyLeafNormal.Class != "yes" {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Error(sunnyLeafNormal)
|
2014-05-17 17:28:51 +01:00
|
|
|
}
|
|
|
|
windyLeafFalse := rainyChild.Children["false"]
|
|
|
|
windyLeafTrue := rainyChild.Children["true"]
|
|
|
|
if windyLeafFalse.Class != "yes" {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Error(windyLeafFalse)
|
2014-05-17 17:28:51 +01:00
|
|
|
}
|
|
|
|
if windyLeafTrue.Class != "no" {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Error(windyLeafTrue)
|
2014-05-17 17:28:51 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
if overcastChild.Class != "yes" {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Error(overcastChild)
|
2014-05-17 17:28:51 +01:00
|
|
|
}
|
|
|
|
}
|
2014-05-17 20:37:19 +01:00
|
|
|
|
2014-08-22 08:13:19 +00:00
|
|
|
func TestID3Classification(t *testing.T) {
|
2014-05-17 20:37:19 +01:00
|
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
|
|
|
if err != nil {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
2014-05-17 20:37:19 +01:00
|
|
|
}
|
2014-08-22 07:58:01 +00:00
|
|
|
|
2014-08-02 16:22:15 +01:00
|
|
|
filt := filters.NewBinningFilter(inst, 10)
|
|
|
|
for _, a := range base.NonClassFloatAttributes(inst) {
|
|
|
|
filt.AddAttribute(a)
|
|
|
|
}
|
|
|
|
filt.Train()
|
|
|
|
instf := base.NewLazilyFilteredInstances(inst, filt)
|
2014-08-22 07:58:01 +00:00
|
|
|
|
2014-08-02 16:22:15 +01:00
|
|
|
trainData, testData := base.InstancesTrainTestSplit(instf, 0.70)
|
|
|
|
|
2014-05-17 20:37:19 +01:00
|
|
|
// Build the decision tree
|
|
|
|
rule := new(InformationGainRuleGenerator)
|
2014-06-06 20:30:24 +02:00
|
|
|
root := InferID3Tree(trainData, rule)
|
2014-08-22 07:58:01 +00:00
|
|
|
|
2014-08-20 07:16:11 +00:00
|
|
|
predictions, err := root.Predict(testData)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Predicting failed: %s", err.Error())
|
|
|
|
}
|
|
|
|
|
2014-08-22 09:33:42 +00:00
|
|
|
confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)
|
2014-08-22 08:52:37 +00:00
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Unable to get confusion matrix: %s", err.Error())
|
|
|
|
}
|
2014-08-22 09:33:42 +00:00
|
|
|
_ = evaluation.GetSummary(confusionMat)
|
2014-05-17 20:37:19 +01:00
|
|
|
}
|
2014-05-17 21:45:26 +01:00
|
|
|
|
2014-08-22 08:13:19 +00:00
|
|
|
func TestID3(t *testing.T) {
|
2014-05-19 12:42:03 +01:00
|
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
|
2014-05-17 21:45:26 +01:00
|
|
|
if err != nil {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
|
2014-05-17 21:45:26 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
// Build the decision tree
|
|
|
|
tree := NewID3DecisionTree(0.0)
|
|
|
|
tree.Fit(inst)
|
|
|
|
root := tree.Root
|
|
|
|
|
|
|
|
// Verify the tree
|
|
|
|
// First attribute should be "outlook"
|
|
|
|
if root.SplitAttr.GetName() != "outlook" {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Error(root)
|
2014-05-17 21:45:26 +01:00
|
|
|
}
|
|
|
|
sunnyChild := root.Children["sunny"]
|
|
|
|
overcastChild := root.Children["overcast"]
|
|
|
|
rainyChild := root.Children["rainy"]
|
|
|
|
if sunnyChild.SplitAttr.GetName() != "humidity" {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Error(sunnyChild)
|
2014-05-17 21:45:26 +01:00
|
|
|
}
|
|
|
|
if rainyChild.SplitAttr.GetName() != "windy" {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Error(rainyChild)
|
2014-05-17 21:45:26 +01:00
|
|
|
}
|
|
|
|
if overcastChild.SplitAttr != nil {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Error(overcastChild)
|
2014-05-17 21:45:26 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
sunnyLeafHigh := sunnyChild.Children["high"]
|
|
|
|
sunnyLeafNormal := sunnyChild.Children["normal"]
|
|
|
|
if sunnyLeafHigh.Class != "no" {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Error(sunnyLeafHigh)
|
2014-05-17 21:45:26 +01:00
|
|
|
}
|
|
|
|
if sunnyLeafNormal.Class != "yes" {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Error(sunnyLeafNormal)
|
2014-05-17 21:45:26 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
windyLeafFalse := rainyChild.Children["false"]
|
|
|
|
windyLeafTrue := rainyChild.Children["true"]
|
|
|
|
if windyLeafFalse.Class != "yes" {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Error(windyLeafFalse)
|
2014-05-17 21:45:26 +01:00
|
|
|
}
|
|
|
|
if windyLeafTrue.Class != "no" {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Error(windyLeafTrue)
|
2014-05-17 21:45:26 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
if overcastChild.Class != "yes" {
|
2014-08-22 08:13:19 +00:00
|
|
|
t.Error(overcastChild)
|
2014-05-17 21:45:26 +01:00
|
|
|
}
|
|
|
|
}
|