1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-26 13:49:14 +08:00

Merge pull request #261 from Oliveirakun/fix-model-load

Fix model loading issues for decision tree and random forest
This commit is contained in:
Richard Townsend 2021-01-17 16:49:41 +00:00 committed by GitHub
commit cde96fa826
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 25 additions and 10 deletions

View File

@ -49,6 +49,7 @@ func (f *FunctionalTarReader) GetNamedFile(name string) ([]byte, error) {
if err != nil {
return nil, WrapError(err)
}
if int64(len(ret)) != hdr.Size {
if int64(len(ret)) < hdr.Size {
log.Printf("Size mismatch, got %d byte(s) for %s, expected %d (err was %s)", len(ret), hdr.Name, hdr.Size, err)
@ -56,10 +57,9 @@ func (f *FunctionalTarReader) GetNamedFile(name string) ([]byte, error) {
return nil, WrapError(fmt.Errorf("Size mismatch, expected %d byte(s) for %s, got %d", len(ret), hdr.Name, hdr.Size))
}
}
if err != nil {
return nil, err
}
returnCandidate = ret
break
}
}
if returnCandidate == nil {

View File

@ -3,6 +3,7 @@ package ensemble
import (
"errors"
"fmt"
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/meta"
"github.com/sjwhitworth/golearn/trees"
@ -95,5 +96,10 @@ func (f *RandomForest) Load(filePath string) error {
func (f *RandomForest) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
f.Model = new(meta.BaggedModel)
for i := 0; i < f.ForestSize; i++ {
tree := trees.NewID3DecisionTree(0.00)
f.Model.AddModel(tree)
}
return f.Model.LoadWithPrefix(reader, prefix)
}

View File

@ -3,12 +3,13 @@ package ensemble
import (
"testing"
"io/ioutil"
"os"
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/evaluation"
"github.com/sjwhitworth/golearn/filters"
. "github.com/smartystreets/goconvey/convey"
"io/ioutil"
"os"
)
func TestRandomForest(t *testing.T) {
@ -92,6 +93,7 @@ func TestRandomForestSerialization(t *testing.T) {
newRf := NewRandomForest(10, 3)
err := newRf.Load(f.Name())
So(err, ShouldBeNil)
So(len(newRf.Model.Models), ShouldEqual, 10)
Convey("Predictions should be the same...", func() {
newPredictions, err := newRf.Predict(testData)
So(err, ShouldBeNil)

View File

@ -4,9 +4,10 @@ import (
"bytes"
"encoding/json"
"fmt"
"sort"
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/evaluation"
"sort"
)
// NodeType determines whether a DecisionTreeNode is a leaf or not.
@ -587,5 +588,5 @@ func (t *ID3DecisionTree) Load(filePath string) error {
func (t *ID3DecisionTree) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
t.Root = &DecisionTreeNode{}
return t.Root.LoadWithPrefix(reader, "")
return t.Root.LoadWithPrefix(reader, prefix)
}

View File

@ -2,9 +2,10 @@ package trees
import (
"fmt"
"testing"
"github.com/sjwhitworth/golearn/base"
. "github.com/smartystreets/goconvey/convey"
"testing"
)
func TestId3(t *testing.T) {
@ -52,10 +53,15 @@ func TestId3(t *testing.T) {
_, err = id3tree.Predict(trainData)
So(err, ShouldBeNil)
// Test save and load model
err = id3tree.Save("tmp")
So(err, ShouldBeNil)
err = id3tree.Load("tmp")
So(err, ShouldNotBeNil) // temp
id3tree = NewID3DecisionTree(0.1)
err = id3tree.Load("tmp")
So(err, ShouldBeNil)
_, err = id3tree.Predict(trainData)
So(err, ShouldBeNil)
})
}