diff --git a/base/serialize.go b/base/serialize.go index 7f4ba7b..5cf7ebe 100644 --- a/base/serialize.go +++ b/base/serialize.go @@ -49,6 +49,7 @@ func (f *FunctionalTarReader) GetNamedFile(name string) ([]byte, error) { if err != nil { return nil, WrapError(err) } + if int64(len(ret)) != hdr.Size { if int64(len(ret)) < hdr.Size { log.Printf("Size mismatch, got %d byte(s) for %s, expected %d (err was %s)", len(ret), hdr.Name, hdr.Size, err) @@ -56,10 +57,9 @@ func (f *FunctionalTarReader) GetNamedFile(name string) ([]byte, error) { return nil, WrapError(fmt.Errorf("Size mismatch, expected %d byte(s) for %s, got %d", len(ret), hdr.Name, hdr.Size)) } } - if err != nil { - return nil, err - } + returnCandidate = ret + break } } if returnCandidate == nil { diff --git a/ensemble/randomforest.go b/ensemble/randomforest.go index d60e166..0d9492a 100644 --- a/ensemble/randomforest.go +++ b/ensemble/randomforest.go @@ -3,6 +3,7 @@ package ensemble import ( "errors" "fmt" + "github.com/sjwhitworth/golearn/base" "github.com/sjwhitworth/golearn/meta" "github.com/sjwhitworth/golearn/trees" @@ -95,5 +96,10 @@ func (f *RandomForest) Load(filePath string) error { func (f *RandomForest) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error { f.Model = new(meta.BaggedModel) + for i := 0; i < f.ForestSize; i++ { + tree := trees.NewID3DecisionTree(0.00) + f.Model.AddModel(tree) + } + return f.Model.LoadWithPrefix(reader, prefix) } diff --git a/ensemble/randomforest_test.go b/ensemble/randomforest_test.go index 8d599e8..1e68391 100644 --- a/ensemble/randomforest_test.go +++ b/ensemble/randomforest_test.go @@ -3,12 +3,13 @@ package ensemble import ( "testing" + "io/ioutil" + "os" + "github.com/sjwhitworth/golearn/base" "github.com/sjwhitworth/golearn/evaluation" "github.com/sjwhitworth/golearn/filters" . "github.com/smartystreets/goconvey/convey" - "io/ioutil" - "os" ) func TestRandomForest(t *testing.T) { @@ -92,6 +93,7 @@ func TestRandomForestSerialization(t *testing.T) { newRf := NewRandomForest(10, 3) err := newRf.Load(f.Name()) So(err, ShouldBeNil) + So(len(newRf.Model.Models), ShouldEqual, 10) Convey("Predictions should be the same...", func() { newPredictions, err := newRf.Predict(testData) So(err, ShouldBeNil) diff --git a/trees/id3.go b/trees/id3.go index f37c42d..8574d48 100644 --- a/trees/id3.go +++ b/trees/id3.go @@ -4,9 +4,10 @@ import ( "bytes" "encoding/json" "fmt" + "sort" + "github.com/sjwhitworth/golearn/base" "github.com/sjwhitworth/golearn/evaluation" - "sort" ) // NodeType determines whether a DecisionTreeNode is a leaf or not. @@ -587,5 +588,5 @@ func (t *ID3DecisionTree) Load(filePath string) error { func (t *ID3DecisionTree) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error { t.Root = &DecisionTreeNode{} - return t.Root.LoadWithPrefix(reader, "") + return t.Root.LoadWithPrefix(reader, prefix) } diff --git a/trees/id3_test.go b/trees/id3_test.go index 33d4ec5..37b9ad5 100644 --- a/trees/id3_test.go +++ b/trees/id3_test.go @@ -2,9 +2,10 @@ package trees import ( "fmt" + "testing" + "github.com/sjwhitworth/golearn/base" . "github.com/smartystreets/goconvey/convey" - "testing" ) func TestId3(t *testing.T) { @@ -52,10 +53,15 @@ func TestId3(t *testing.T) { _, err = id3tree.Predict(trainData) So(err, ShouldBeNil) + // Test save and load model err = id3tree.Save("tmp") So(err, ShouldBeNil) - err = id3tree.Load("tmp") - So(err, ShouldNotBeNil) // temp + id3tree = NewID3DecisionTree(0.1) + err = id3tree.Load("tmp") + So(err, ShouldBeNil) + + _, err = id3tree.Predict(trainData) + So(err, ShouldBeNil) }) }