mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
61 lines
1.4 KiB
Go
61 lines
1.4 KiB
Go
package trees
|
|
import (
|
|
"github.com/sjwhitworth/golearn/base"
|
|
. "github.com/smartystreets/goconvey/convey"
|
|
"testing"
|
|
"fmt"
|
|
)
|
|
|
|
func TestId3(t *testing.T) {
|
|
Convey("Doing a id3 test", t, func(){
|
|
var rule DecisionTreeRule
|
|
s := rule.String()
|
|
So(s, ShouldNotBeNil)
|
|
rule.SplitAttr = nil
|
|
s = rule.String()
|
|
So(s, ShouldNotBeNil)
|
|
|
|
instances, err := base.ParseCSVToInstances("onerow.csv", true)
|
|
So(err, ShouldBeNil)
|
|
|
|
trainData, _ := base.InstancesTrainTestSplit(instances, 0.6)
|
|
gRuleGen := new(GiniCoefficientRuleGenerator)
|
|
root := InferID3Tree(trainData, gRuleGen)
|
|
s = root.getNestedString(3)
|
|
So(s, ShouldNotBeNil)
|
|
s = root.String()
|
|
So(s, ShouldNotBeNil)
|
|
|
|
//var proba1 ClassProba
|
|
//var proba2 ClassProba
|
|
//probas := ClassesProba{proba1, proba2}
|
|
_, rc := trainData.Size()
|
|
fmt.Println(rc)
|
|
id3tree := NewID3DecisionTree(0.1)
|
|
So(id3tree, ShouldNotBeNil)
|
|
id3tree.Root = root
|
|
probas, err := id3tree.PredictProba(trainData)
|
|
So(err, ShouldBeNil)
|
|
var proba1, proba2 ClassProba
|
|
probas = ClassesProba{proba1, proba2}
|
|
L := probas.Len()
|
|
So(L, ShouldEqual, 2)
|
|
probas.Swap(0, 1)
|
|
less := probas.Less(0, 1)
|
|
So(less, ShouldEqual, false)
|
|
|
|
data := id3tree.GetMetadata()
|
|
So(data, ShouldNotBeNil)
|
|
s = id3tree.String()
|
|
So(s, ShouldNotBeNil)
|
|
_, err = id3tree.Predict(trainData)
|
|
So(err, ShouldBeNil)
|
|
|
|
err = id3tree.Save("tmp")
|
|
So(err, ShouldBeNil)
|
|
err = id3tree.Load("tmp")
|
|
So(err, ShouldNotBeNil) // temp
|
|
|
|
})
|
|
}
|