From 6489b3bf7c135474883e1f5fef1fb823ac0ce5ad Mon Sep 17 00:00:00 2001 From: Francis Oliveira Date: Fri, 8 Jan 2021 00:22:00 -0300 Subject: [PATCH 1/2] Fix id3 model loading An error was ocurring everytime an id3 model was loaded --- base/serialize.go | 6 +++--- trees/id3_test.go | 12 +++++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/base/serialize.go b/base/serialize.go index 7f4ba7b..5cf7ebe 100644 --- a/base/serialize.go +++ b/base/serialize.go @@ -49,6 +49,7 @@ func (f *FunctionalTarReader) GetNamedFile(name string) ([]byte, error) { if err != nil { return nil, WrapError(err) } + if int64(len(ret)) != hdr.Size { if int64(len(ret)) < hdr.Size { log.Printf("Size mismatch, got %d byte(s) for %s, expected %d (err was %s)", len(ret), hdr.Name, hdr.Size, err) @@ -56,10 +57,9 @@ func (f *FunctionalTarReader) GetNamedFile(name string) ([]byte, error) { return nil, WrapError(fmt.Errorf("Size mismatch, expected %d byte(s) for %s, got %d", len(ret), hdr.Name, hdr.Size)) } } - if err != nil { - return nil, err - } + returnCandidate = ret + break } } if returnCandidate == nil { diff --git a/trees/id3_test.go b/trees/id3_test.go index 33d4ec5..37b9ad5 100644 --- a/trees/id3_test.go +++ b/trees/id3_test.go @@ -2,9 +2,10 @@ package trees import ( "fmt" + "testing" + "github.com/sjwhitworth/golearn/base" . "github.com/smartystreets/goconvey/convey" - "testing" ) func TestId3(t *testing.T) { @@ -52,10 +53,15 @@ func TestId3(t *testing.T) { _, err = id3tree.Predict(trainData) So(err, ShouldBeNil) + // Test save and load model err = id3tree.Save("tmp") So(err, ShouldBeNil) - err = id3tree.Load("tmp") - So(err, ShouldNotBeNil) // temp + id3tree = NewID3DecisionTree(0.1) + err = id3tree.Load("tmp") + So(err, ShouldBeNil) + + _, err = id3tree.Predict(trainData) + So(err, ShouldBeNil) }) } From d33eb47a05bc935cb1e4ba668aed4f36231e5e16 Mon Sep 17 00:00:00 2001 From: Francis Oliveira Date: Sun, 10 Jan 2021 00:56:45 -0300 Subject: [PATCH 2/2] Fix random forest model loading --- ensemble/randomforest.go | 6 ++++++ ensemble/randomforest_test.go | 6 ++++-- trees/id3.go | 5 +++-- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/ensemble/randomforest.go b/ensemble/randomforest.go index d60e166..0d9492a 100644 --- a/ensemble/randomforest.go +++ b/ensemble/randomforest.go @@ -3,6 +3,7 @@ package ensemble import ( "errors" "fmt" + "github.com/sjwhitworth/golearn/base" "github.com/sjwhitworth/golearn/meta" "github.com/sjwhitworth/golearn/trees" @@ -95,5 +96,10 @@ func (f *RandomForest) Load(filePath string) error { func (f *RandomForest) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error { f.Model = new(meta.BaggedModel) + for i := 0; i < f.ForestSize; i++ { + tree := trees.NewID3DecisionTree(0.00) + f.Model.AddModel(tree) + } + return f.Model.LoadWithPrefix(reader, prefix) } diff --git a/ensemble/randomforest_test.go b/ensemble/randomforest_test.go index 8d599e8..1e68391 100644 --- a/ensemble/randomforest_test.go +++ b/ensemble/randomforest_test.go @@ -3,12 +3,13 @@ package ensemble import ( "testing" + "io/ioutil" + "os" + "github.com/sjwhitworth/golearn/base" "github.com/sjwhitworth/golearn/evaluation" "github.com/sjwhitworth/golearn/filters" . "github.com/smartystreets/goconvey/convey" - "io/ioutil" - "os" ) func TestRandomForest(t *testing.T) { @@ -92,6 +93,7 @@ func TestRandomForestSerialization(t *testing.T) { newRf := NewRandomForest(10, 3) err := newRf.Load(f.Name()) So(err, ShouldBeNil) + So(len(newRf.Model.Models), ShouldEqual, 10) Convey("Predictions should be the same...", func() { newPredictions, err := newRf.Predict(testData) So(err, ShouldBeNil) diff --git a/trees/id3.go b/trees/id3.go index f37c42d..8574d48 100644 --- a/trees/id3.go +++ b/trees/id3.go @@ -4,9 +4,10 @@ import ( "bytes" "encoding/json" "fmt" + "sort" + "github.com/sjwhitworth/golearn/base" "github.com/sjwhitworth/golearn/evaluation" - "sort" ) // NodeType determines whether a DecisionTreeNode is a leaf or not. @@ -587,5 +588,5 @@ func (t *ID3DecisionTree) Load(filePath string) error { func (t *ID3DecisionTree) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error { t.Root = &DecisionTreeNode{} - return t.Root.LoadWithPrefix(reader, "") + return t.Root.LoadWithPrefix(reader, prefix) }