From d33eb47a05bc935cb1e4ba668aed4f36231e5e16 Mon Sep 17 00:00:00 2001 From: Francis Oliveira Date: Sun, 10 Jan 2021 00:56:45 -0300 Subject: [PATCH] Fix random forest model loading --- ensemble/randomforest.go | 6 ++++++ ensemble/randomforest_test.go | 6 ++++-- trees/id3.go | 5 +++-- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/ensemble/randomforest.go b/ensemble/randomforest.go index d60e166..0d9492a 100644 --- a/ensemble/randomforest.go +++ b/ensemble/randomforest.go @@ -3,6 +3,7 @@ package ensemble import ( "errors" "fmt" + "github.com/sjwhitworth/golearn/base" "github.com/sjwhitworth/golearn/meta" "github.com/sjwhitworth/golearn/trees" @@ -95,5 +96,10 @@ func (f *RandomForest) Load(filePath string) error { func (f *RandomForest) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error { f.Model = new(meta.BaggedModel) + for i := 0; i < f.ForestSize; i++ { + tree := trees.NewID3DecisionTree(0.00) + f.Model.AddModel(tree) + } + return f.Model.LoadWithPrefix(reader, prefix) } diff --git a/ensemble/randomforest_test.go b/ensemble/randomforest_test.go index 8d599e8..1e68391 100644 --- a/ensemble/randomforest_test.go +++ b/ensemble/randomforest_test.go @@ -3,12 +3,13 @@ package ensemble import ( "testing" + "io/ioutil" + "os" + "github.com/sjwhitworth/golearn/base" "github.com/sjwhitworth/golearn/evaluation" "github.com/sjwhitworth/golearn/filters" . "github.com/smartystreets/goconvey/convey" - "io/ioutil" - "os" ) func TestRandomForest(t *testing.T) { @@ -92,6 +93,7 @@ func TestRandomForestSerialization(t *testing.T) { newRf := NewRandomForest(10, 3) err := newRf.Load(f.Name()) So(err, ShouldBeNil) + So(len(newRf.Model.Models), ShouldEqual, 10) Convey("Predictions should be the same...", func() { newPredictions, err := newRf.Predict(testData) So(err, ShouldBeNil) diff --git a/trees/id3.go b/trees/id3.go index f37c42d..8574d48 100644 --- a/trees/id3.go +++ b/trees/id3.go @@ -4,9 +4,10 @@ import ( "bytes" "encoding/json" "fmt" + "sort" + "github.com/sjwhitworth/golearn/base" "github.com/sjwhitworth/golearn/evaluation" - "sort" ) // NodeType determines whether a DecisionTreeNode is a leaf or not. @@ -587,5 +588,5 @@ func (t *ID3DecisionTree) Load(filePath string) error { func (t *ID3DecisionTree) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error { t.Root = &DecisionTreeNode{} - return t.Root.LoadWithPrefix(reader, "") + return t.Root.LoadWithPrefix(reader, prefix) }