1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00
golearn/trees/id3_test.go
2018-06-16 22:14:18 +08:00

62 lines
1.4 KiB
Go

package trees
import (
	"fmt"
	"os"
	"testing"

	"github.com/sjwhitworth/golearn/base"
	. "github.com/smartystreets/goconvey/convey"
)
// TestId3 exercises the ID3 decision-tree API on the "onerow.csv" fixture:
// rule stringification, tree induction with a Gini rule generator,
// probability and class prediction, metadata, and a Save/Load round trip.
func TestId3(t *testing.T) {
	Convey("Doing a id3 test", t, func() {
		// A zero-valued rule must still stringify.
		var rule DecisionTreeRule
		s := rule.String()
		So(s, ShouldNotBeNil)
		// Redundant with the zero value; kept so the nil SplitAttr is explicit.
		rule.SplitAttr = nil
		s = rule.String()
		So(s, ShouldNotBeNil)

		instances, err := base.ParseCSVToInstances("onerow.csv", true)
		So(err, ShouldBeNil)
		// Only the training partition is used; the other partition is discarded.
		trainData, _ := base.InstancesTrainTestSplit(instances, 0.6)

		gRuleGen := new(GiniCoefficientRuleGenerator)
		root := InferID3Tree(trainData, gRuleGen)
		s = root.getNestedString(3)
		So(s, ShouldNotBeNil)
		s = root.String()
		So(s, ShouldNotBeNil)

		_, rc := trainData.Size()
		fmt.Println(rc)

		id3tree := NewID3DecisionTree(0.1)
		So(id3tree, ShouldNotBeNil)
		id3tree.Root = root

		// Only the error is checked. The returned probabilities used to be
		// stored in probas and then overwritten unread (a dead store).
		_, err = id3tree.PredictProba(trainData)
		So(err, ShouldBeNil)

		// Exercise Len/Swap/Less on a hand-built ClassesProba.
		// NOTE(review): both elements are zero values, so this only checks that
		// the methods run; distinct values would make Swap/Less meaningful.
		var proba1, proba2 ClassProba
		probas := ClassesProba{proba1, proba2}
		So(probas.Len(), ShouldEqual, 2)
		probas.Swap(0, 1)
		So(probas.Less(0, 1), ShouldEqual, false)

		data := id3tree.GetMetadata()
		So(data, ShouldNotBeNil)
		s = id3tree.String()
		So(s, ShouldNotBeNil)

		_, err = id3tree.Predict(trainData)
		So(err, ShouldBeNil)

		// Save writes "tmp" into the working directory. Remove it when this
		// closure returns so runs leave no artifact behind. This is best
		// effort: a failure to remove is not a test failure.
		defer os.Remove("tmp")
		err = id3tree.Save("tmp")
		So(err, ShouldBeNil)

		// Load is currently asserted to fail.
		// NOTE(review): presumably a known limitation (marked "temp"); flip this
		// assertion once loading a saved tree is supported.
		err = id3tree.Load("tmp")
		So(err, ShouldNotBeNil) // temp
	})
}