1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-28 13:48:56 +08:00
golearn/trees/tree_test.go

279 lines
7.4 KiB
Go
Raw Normal View History

package trees
import (
2014-08-22 07:21:24 +00:00
"github.com/sjwhitworth/golearn/base"
2014-08-22 09:33:42 +00:00
"github.com/sjwhitworth/golearn/evaluation"
2014-08-22 07:21:24 +00:00
"github.com/sjwhitworth/golearn/filters"
"math"
"testing"
)
func TestRandomTree(t *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
}
2014-05-17 16:20:56 +01:00
filt := filters.NewChiMergeFilter(inst, 0.90)
2014-08-02 16:22:15 +01:00
for _, a := range base.NonClassFloatAttributes(inst) {
filt.AddAttribute(a)
}
filt.Train()
instf := base.NewLazilyFilteredInstances(inst, filt)
r := new(RandomTreeRuleGenerator)
r.Attributes = 2
_ = InferID3Tree(instf, r)
}
func TestRandomTreeClassification(t *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
}
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
2014-05-17 16:20:56 +01:00
filt := filters.NewChiMergeFilter(inst, 0.90)
2014-08-02 16:22:15 +01:00
for _, a := range base.NonClassFloatAttributes(inst) {
filt.AddAttribute(a)
}
filt.Train()
trainDataF := base.NewLazilyFilteredInstances(trainData, filt)
testDataF := base.NewLazilyFilteredInstances(testData, filt)
r := new(RandomTreeRuleGenerator)
r.Attributes = 2
2014-08-02 16:22:15 +01:00
root := InferID3Tree(trainDataF, r)
predictions, err := root.Predict(testDataF)
if err != nil {
t.Fatalf("Predicting failed: %s", err.Error())
}
2014-08-22 09:33:42 +00:00
confusionMat, err := evaluation.GetConfusionMatrix(testDataF, predictions)
if err != nil {
t.Fatalf("Unable to get confusion matrix: %s", err.Error())
}
2014-08-22 09:33:42 +00:00
_ = evaluation.GetSummary(confusionMat)
}
func TestRandomTreeClassification2(t *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
}
trainData, testData := base.InstancesTrainTestSplit(inst, 0.4)
2014-05-17 16:20:56 +01:00
filt := filters.NewChiMergeFilter(inst, 0.90)
2014-08-02 16:22:15 +01:00
for _, a := range base.NonClassFloatAttributes(inst) {
filt.AddAttribute(a)
}
filt.Train()
trainDataF := base.NewLazilyFilteredInstances(trainData, filt)
testDataF := base.NewLazilyFilteredInstances(testData, filt)
root := NewRandomTree(2)
err = root.Fit(trainDataF)
if err != nil {
t.Fatalf("Fitting failed: %s", err.Error())
}
predictions, err := root.Predict(testDataF)
if err != nil {
t.Fatalf("Predicting failed: %s", err.Error())
}
2014-08-22 09:33:42 +00:00
confusionMat, err := evaluation.GetConfusionMatrix(testDataF, predictions)
if err != nil {
t.Fatalf("Unable to get confusion matrix: %s", err.Error())
}
2014-08-22 09:33:42 +00:00
_ = evaluation.GetSummary(confusionMat)
2014-05-17 18:06:01 +01:00
}
func TestPruning(t *testing.T) {
2014-05-17 18:06:01 +01:00
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
2014-05-17 18:06:01 +01:00
}
trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
2014-05-17 18:06:01 +01:00
filt := filters.NewChiMergeFilter(inst, 0.90)
2014-08-02 16:22:15 +01:00
for _, a := range base.NonClassFloatAttributes(inst) {
filt.AddAttribute(a)
}
filt.Train()
trainDataF := base.NewLazilyFilteredInstances(trainData, filt)
testDataF := base.NewLazilyFilteredInstances(testData, filt)
2014-05-17 18:06:01 +01:00
root := NewRandomTree(2)
2014-08-02 16:22:15 +01:00
fittrainData, fittestData := base.InstancesTrainTestSplit(trainDataF, 0.6)
err = root.Fit(fittrainData)
if err != nil {
t.Fatalf("Fitting failed: %s", err.Error())
}
root.Prune(fittestData)
predictions, err := root.Predict(testDataF)
if err != nil {
t.Fatalf("Predicting failed: %s", err.Error())
}
2014-08-22 09:33:42 +00:00
confusionMat, err := evaluation.GetConfusionMatrix(testDataF, predictions)
if err != nil {
t.Fatalf("Unable to get confusion matrix: %s", err.Error())
}
2014-08-22 09:33:42 +00:00
_ = evaluation.GetSummary(confusionMat)
}
// TestInformationGain checks getSplitEntropy against a hand-computed value
// for the per-class counts of each "outlook" value.
func TestInformationGain(t *testing.T) {
	// Class counts ("play"/"noplay") for each value of the outlook attribute.
	counts := map[string]map[string]int{
		"sunny":    {"play": 2, "noplay": 3},
		"overcast": {"play": 4},
		"rain":     {"play": 3, "noplay": 2},
	}
	const (
		want      = 0.694
		tolerance = 0.001
	)
	got := getSplitEntropy(counts)
	if diff := math.Abs(got - want); diff > tolerance {
		t.Error(got)
	}
}
2014-05-17 17:28:51 +01:00
func TestID3Inference(t *testing.T) {
inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
2014-05-17 17:28:51 +01:00
if err != nil {
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
2014-05-17 17:28:51 +01:00
}
// Build the decision tree
rule := new(InformationGainRuleGenerator)
root := InferID3Tree(inst, rule)
// Verify the tree
// First attribute should be "outlook"
if root.SplitAttr.GetName() != "outlook" {
t.Error(root)
2014-05-17 17:28:51 +01:00
}
sunnyChild := root.Children["sunny"]
overcastChild := root.Children["overcast"]
rainyChild := root.Children["rainy"]
if sunnyChild.SplitAttr.GetName() != "humidity" {
t.Error(sunnyChild)
2014-05-17 17:28:51 +01:00
}
if rainyChild.SplitAttr.GetName() != "windy" {
t.Error(rainyChild)
2014-05-17 17:28:51 +01:00
}
if overcastChild.SplitAttr != nil {
t.Error(overcastChild)
2014-05-17 17:28:51 +01:00
}
sunnyLeafHigh := sunnyChild.Children["high"]
sunnyLeafNormal := sunnyChild.Children["normal"]
if sunnyLeafHigh.Class != "no" {
t.Error(sunnyLeafHigh)
2014-05-17 17:28:51 +01:00
}
if sunnyLeafNormal.Class != "yes" {
t.Error(sunnyLeafNormal)
2014-05-17 17:28:51 +01:00
}
windyLeafFalse := rainyChild.Children["false"]
windyLeafTrue := rainyChild.Children["true"]
if windyLeafFalse.Class != "yes" {
t.Error(windyLeafFalse)
2014-05-17 17:28:51 +01:00
}
if windyLeafTrue.Class != "no" {
t.Error(windyLeafTrue)
2014-05-17 17:28:51 +01:00
}
if overcastChild.Class != "yes" {
t.Error(overcastChild)
2014-05-17 17:28:51 +01:00
}
}
2014-05-17 20:37:19 +01:00
func TestID3Classification(t *testing.T) {
2014-05-17 20:37:19 +01:00
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
if err != nil {
t.Fatal("Unable to parse CSV to instances: %s", err.Error())
2014-05-17 20:37:19 +01:00
}
2014-08-02 16:22:15 +01:00
filt := filters.NewBinningFilter(inst, 10)
for _, a := range base.NonClassFloatAttributes(inst) {
filt.AddAttribute(a)
}
filt.Train()
instf := base.NewLazilyFilteredInstances(inst, filt)
2014-08-02 16:22:15 +01:00
trainData, testData := base.InstancesTrainTestSplit(instf, 0.70)
2014-05-17 20:37:19 +01:00
// Build the decision tree
rule := new(InformationGainRuleGenerator)
root := InferID3Tree(trainData, rule)
predictions, err := root.Predict(testData)
if err != nil {
t.Fatalf("Predicting failed: %s", err.Error())
}
2014-08-22 09:33:42 +00:00
confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)
if err != nil {
t.Fatalf("Unable to get confusion matrix: %s", err.Error())
}
2014-08-22 09:33:42 +00:00
_ = evaluation.GetSummary(confusionMat)
2014-05-17 20:37:19 +01:00
}
// TestID3 fits an ID3DecisionTree (created with NewID3DecisionTree(0.0)) on
// the tennis dataset and checks the structure of tree.Root. The expected
// shape is the same as in TestID3Inference.
//
// A Fit error, or a missing branch, fails the test with a message instead
// of being ignored or causing a nil-pointer panic.
func TestID3(t *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/tennis.csv", true)
	if err != nil {
		// Fatalf (not Fatal) so the %s verb is actually formatted.
		t.Fatalf("Unable to parse CSV to instances: %s", err.Error())
	}
	// Build the decision tree
	tree := NewID3DecisionTree(0.0)
	if err := tree.Fit(inst); err != nil {
		t.Fatalf("Fitting failed: %s", err.Error())
	}
	root := tree.Root
	// Verify the tree
	// First attribute should be "outlook"
	if root.SplitAttr == nil || root.SplitAttr.GetName() != "outlook" {
		t.Error(root)
	}
	// Guard every lookup whose result is dereferenced below.
	for _, v := range []string{"sunny", "overcast", "rainy"} {
		if _, ok := root.Children[v]; !ok {
			t.Fatalf("root has no %q branch: %v", v, root)
		}
	}
	sunnyChild := root.Children["sunny"]
	overcastChild := root.Children["overcast"]
	rainyChild := root.Children["rainy"]
	if sunnyChild.SplitAttr == nil || sunnyChild.SplitAttr.GetName() != "humidity" {
		t.Error(sunnyChild)
	}
	if rainyChild.SplitAttr == nil || rainyChild.SplitAttr.GetName() != "windy" {
		t.Error(rainyChild)
	}
	if overcastChild.SplitAttr != nil {
		t.Error(overcastChild)
	}
	for _, v := range []string{"high", "normal"} {
		if _, ok := sunnyChild.Children[v]; !ok {
			t.Fatalf("sunny branch has no %q child: %v", v, sunnyChild)
		}
	}
	sunnyLeafHigh := sunnyChild.Children["high"]
	sunnyLeafNormal := sunnyChild.Children["normal"]
	if sunnyLeafHigh.Class != "no" {
		t.Error(sunnyLeafHigh)
	}
	if sunnyLeafNormal.Class != "yes" {
		t.Error(sunnyLeafNormal)
	}
	for _, v := range []string{"false", "true"} {
		if _, ok := rainyChild.Children[v]; !ok {
			t.Fatalf("rainy branch has no %q child: %v", v, rainyChild)
		}
	}
	windyLeafFalse := rainyChild.Children["false"]
	windyLeafTrue := rainyChild.Children["true"]
	if windyLeafFalse.Class != "yes" {
		t.Error(windyLeafFalse)
	}
	if windyLeafTrue.Class != "no" {
		t.Error(windyLeafTrue)
	}
	if overcastChild.Class != "yes" {
		t.Error(overcastChild)
	}
}