// Mirror of https://github.com/sjwhitworth/golearn.git (synced 2025-04-26 13:49:14 +08:00).
// File: golearn/ensemble/randomforest_test.go (109 lines, 3.1 KiB, Go).

package ensemble
import (
"testing"
2021-01-10 00:56:45 -03:00
"io/ioutil"
"os"
2014-08-22 07:21:24 +00:00
"github.com/sjwhitworth/golearn/base"
2014-08-22 09:33:42 +00:00
"github.com/sjwhitworth/golearn/evaluation"
2014-08-22 07:21:24 +00:00
"github.com/sjwhitworth/golearn/filters"
. "github.com/smartystreets/goconvey/convey"
)
// TestRandomForest trains a RandomForest on a Chi-Merge filtered view of the
// iris dataset and checks that its predictions on the held-out split clear a
// minimal accuracy bar. It also checks that Fit reports an error when the
// forest is asked to use more features than the data has non-class attributes.
func TestRandomForest(t *testing.T) {
	Convey("Given a valid CSV file", t, func() {
		iris, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldBeNil)

		Convey("When Chi-Merge filtering the data", func() {
			chiMerge := filters.NewChiMergeFilter(iris, 0.90)
			floatAttrs := base.NonClassFloatAttributes(iris)
			for i := range floatAttrs {
				chiMerge.AddAttribute(floatAttrs[i])
			}
			chiMerge.Train()
			filtered := base.NewLazilyFilteredInstances(iris, chiMerge)

			Convey("Splitting the data into test and training sets", func() {
				train, test := base.InstancesTrainTestSplit(filtered, 0.60)

				Convey("Fitting and predicting with a Random Forest", func() {
					forest := NewRandomForest(10, 3)
					fitErr := forest.Fit(train)
					So(fitErr, ShouldBeNil)

					predicted, predictErr := forest.Predict(test)
					So(predictErr, ShouldBeNil)

					matrix, matrixErr := evaluation.GetConfusionMatrix(test, predicted)
					So(matrixErr, ShouldBeNil)

					Convey("Predictions should be somewhat accurate", func() {
						accuracy := evaluation.GetAccuracy(matrix)
						So(accuracy, ShouldBeGreaterThan, 0.35)
					})
				})
			})
		})

		Convey("Fitting with a Random Forest with too many features compared to the data", func() {
			// One more feature than there are non-class attributes.
			tooManyFeatures := len(base.NonClassAttributes(iris)) + 1
			forest := NewRandomForest(10, tooManyFeatures)
			fitErr := forest.Fit(iris)

			Convey("Should return an error", func() {
				So(fitErr, ShouldNotBeNil)
			})
		})
	})
}
2017-09-10 21:10:54 +01:00
func TestRandomForestSerialization(t *testing.T) {
Convey("Given a valid CSV file", t, func() {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
So(err, ShouldBeNil)
Convey("When Chi-Merge filtering the data", func() {
filt := filters.NewChiMergeFilter(inst, 0.90)
for _, a := range base.NonClassFloatAttributes(inst) {
filt.AddAttribute(a)
}
filt.Train()
instf := base.NewLazilyFilteredInstances(inst, filt)
Convey("Splitting the data into test and training sets", func() {
trainData, testData := base.InstancesTrainTestSplit(instf, 0.60)
Convey("Fitting and predicting with a Random Forest", func() {
rf := NewRandomForest(10, 3)
err = rf.Fit(trainData)
So(err, ShouldBeNil)
oldPredictions, err := rf.Predict(testData)
So(err, ShouldBeNil)
Convey("Saving the model should work...", func() {
f, err := ioutil.TempFile(os.TempDir(), "rf")
So(err, ShouldBeNil)
2017-09-10 21:10:54 +01:00
err = rf.Save(f.Name())
defer func() {
f.Close()
}()
So(err, ShouldBeNil)
Convey("Loading the model should work...", func() {
newRf := NewRandomForest(10, 3)
err := newRf.Load(f.Name())
So(err, ShouldBeNil)
2021-01-10 00:56:45 -03:00
So(len(newRf.Model.Models), ShouldEqual, 10)
2017-09-10 21:10:54 +01:00
Convey("Predictions should be the same...", func() {
newPredictions, err := newRf.Predict(testData)
So(err, ShouldBeNil)
So(base.InstancesAreEqual(newPredictions, oldPredictions), ShouldBeTrue)
})
})
})
})
})
})
})
}