2014-05-14 14:00:22 +01:00
|
|
|
package ensemble
|
|
|
|
|
|
|
|
import (
|
2014-08-22 13:16:11 +00:00
|
|
|
"testing"
|
|
|
|
|
2021-01-10 00:56:45 -03:00
|
|
|
"io/ioutil"
|
|
|
|
"os"
|
|
|
|
|
2014-08-22 07:21:24 +00:00
|
|
|
"github.com/sjwhitworth/golearn/base"
|
2014-08-22 09:33:42 +00:00
|
|
|
"github.com/sjwhitworth/golearn/evaluation"
|
2014-08-22 07:21:24 +00:00
|
|
|
"github.com/sjwhitworth/golearn/filters"
|
2014-08-22 13:16:11 +00:00
|
|
|
. "github.com/smartystreets/goconvey/convey"
|
2014-05-14 14:00:22 +01:00
|
|
|
)
|
|
|
|
|
2014-08-22 13:16:11 +00:00
|
|
|
func TestRandomForest(t *testing.T) {
|
|
|
|
Convey("Given a valid CSV file", t, func() {
|
|
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
|
|
|
So(err, ShouldBeNil)
|
|
|
|
|
|
|
|
Convey("When Chi-Merge filtering the data", func() {
|
|
|
|
filt := filters.NewChiMergeFilter(inst, 0.90)
|
|
|
|
for _, a := range base.NonClassFloatAttributes(inst) {
|
|
|
|
filt.AddAttribute(a)
|
|
|
|
}
|
|
|
|
filt.Train()
|
|
|
|
instf := base.NewLazilyFilteredInstances(inst, filt)
|
|
|
|
|
|
|
|
Convey("Splitting the data into test and training sets", func() {
|
|
|
|
trainData, testData := base.InstancesTrainTestSplit(instf, 0.60)
|
|
|
|
|
|
|
|
Convey("Fitting and predicting with a Random Forest", func() {
|
|
|
|
rf := NewRandomForest(10, 3)
|
|
|
|
err = rf.Fit(trainData)
|
|
|
|
So(err, ShouldBeNil)
|
|
|
|
|
|
|
|
predictions, err := rf.Predict(testData)
|
|
|
|
So(err, ShouldBeNil)
|
|
|
|
|
|
|
|
confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)
|
|
|
|
So(err, ShouldBeNil)
|
|
|
|
|
|
|
|
Convey("Predictions should be somewhat accurate", func() {
|
|
|
|
So(evaluation.GetAccuracy(confusionMat), ShouldBeGreaterThan, 0.35)
|
|
|
|
})
|
|
|
|
})
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
|
|
|
Convey("Fitting with a Random Forest with too many features compared to the data", func() {
|
|
|
|
rf := NewRandomForest(10, len(base.NonClassAttributes(inst))+1)
|
|
|
|
err = rf.Fit(inst)
|
|
|
|
|
|
|
|
Convey("Should return an error", func() {
|
|
|
|
So(err, ShouldNotBeNil)
|
|
|
|
})
|
|
|
|
})
|
|
|
|
})
|
2014-08-20 07:16:11 +00:00
|
|
|
}
|
2017-09-10 21:10:54 +01:00
|
|
|
|
|
|
|
func TestRandomForestSerialization(t *testing.T) {
|
|
|
|
Convey("Given a valid CSV file", t, func() {
|
|
|
|
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
|
|
|
|
So(err, ShouldBeNil)
|
|
|
|
|
|
|
|
Convey("When Chi-Merge filtering the data", func() {
|
|
|
|
filt := filters.NewChiMergeFilter(inst, 0.90)
|
|
|
|
for _, a := range base.NonClassFloatAttributes(inst) {
|
|
|
|
filt.AddAttribute(a)
|
|
|
|
}
|
|
|
|
filt.Train()
|
|
|
|
instf := base.NewLazilyFilteredInstances(inst, filt)
|
|
|
|
|
|
|
|
Convey("Splitting the data into test and training sets", func() {
|
|
|
|
trainData, testData := base.InstancesTrainTestSplit(instf, 0.60)
|
|
|
|
|
|
|
|
Convey("Fitting and predicting with a Random Forest", func() {
|
|
|
|
rf := NewRandomForest(10, 3)
|
|
|
|
err = rf.Fit(trainData)
|
|
|
|
So(err, ShouldBeNil)
|
|
|
|
|
|
|
|
oldPredictions, err := rf.Predict(testData)
|
|
|
|
So(err, ShouldBeNil)
|
|
|
|
|
|
|
|
Convey("Saving the model should work...", func() {
|
|
|
|
f, err := ioutil.TempFile(os.TempDir(), "rf")
|
2019-12-27 12:10:38 +13:00
|
|
|
So(err, ShouldBeNil)
|
2017-09-10 21:10:54 +01:00
|
|
|
err = rf.Save(f.Name())
|
|
|
|
defer func() {
|
|
|
|
f.Close()
|
|
|
|
}()
|
|
|
|
So(err, ShouldBeNil)
|
|
|
|
Convey("Loading the model should work...", func() {
|
|
|
|
newRf := NewRandomForest(10, 3)
|
|
|
|
err := newRf.Load(f.Name())
|
|
|
|
So(err, ShouldBeNil)
|
2021-01-10 00:56:45 -03:00
|
|
|
So(len(newRf.Model.Models), ShouldEqual, 10)
|
2017-09-10 21:10:54 +01:00
|
|
|
Convey("Predictions should be the same...", func() {
|
|
|
|
newPredictions, err := newRf.Predict(testData)
|
|
|
|
So(err, ShouldBeNil)
|
|
|
|
So(base.InstancesAreEqual(newPredictions, oldPredictions), ShouldBeTrue)
|
|
|
|
})
|
|
|
|
})
|
|
|
|
})
|
|
|
|
})
|
|
|
|
})
|
|
|
|
})
|
|
|
|
})
|
|
|
|
}
|