mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-26 13:49:14 +08:00
Merge pull request #261 from Oliveirakun/fix-model-load
Fix model loading issues for decision tree and random forest
This commit is contained in:
commit
cde96fa826
@ -49,6 +49,7 @@ func (f *FunctionalTarReader) GetNamedFile(name string) ([]byte, error) {
|
||||
if err != nil {
|
||||
return nil, WrapError(err)
|
||||
}
|
||||
|
||||
if int64(len(ret)) != hdr.Size {
|
||||
if int64(len(ret)) < hdr.Size {
|
||||
log.Printf("Size mismatch, got %d byte(s) for %s, expected %d (err was %s)", len(ret), hdr.Name, hdr.Size, err)
|
||||
@ -56,10 +57,9 @@ func (f *FunctionalTarReader) GetNamedFile(name string) ([]byte, error) {
|
||||
return nil, WrapError(fmt.Errorf("Size mismatch, expected %d byte(s) for %s, got %d", len(ret), hdr.Name, hdr.Size))
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
returnCandidate = ret
|
||||
break
|
||||
}
|
||||
}
|
||||
if returnCandidate == nil {
|
||||
|
@ -3,6 +3,7 @@ package ensemble
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"github.com/sjwhitworth/golearn/meta"
|
||||
"github.com/sjwhitworth/golearn/trees"
|
||||
@ -95,5 +96,10 @@ func (f *RandomForest) Load(filePath string) error {
|
||||
|
||||
func (f *RandomForest) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
|
||||
f.Model = new(meta.BaggedModel)
|
||||
for i := 0; i < f.ForestSize; i++ {
|
||||
tree := trees.NewID3DecisionTree(0.00)
|
||||
f.Model.AddModel(tree)
|
||||
}
|
||||
|
||||
return f.Model.LoadWithPrefix(reader, prefix)
|
||||
}
|
||||
|
@ -3,12 +3,13 @@ package ensemble
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"io/ioutil"
|
||||
"os"
|
||||
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"github.com/sjwhitworth/golearn/evaluation"
|
||||
"github.com/sjwhitworth/golearn/filters"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
)
|
||||
|
||||
func TestRandomForest(t *testing.T) {
|
||||
@ -92,6 +93,7 @@ func TestRandomForestSerialization(t *testing.T) {
|
||||
newRf := NewRandomForest(10, 3)
|
||||
err := newRf.Load(f.Name())
|
||||
So(err, ShouldBeNil)
|
||||
So(len(newRf.Model.Models), ShouldEqual, 10)
|
||||
Convey("Predictions should be the same...", func() {
|
||||
newPredictions, err := newRf.Predict(testData)
|
||||
So(err, ShouldBeNil)
|
||||
|
@ -4,9 +4,10 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"github.com/sjwhitworth/golearn/evaluation"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// NodeType determines whether a DecisionTreeNode is a leaf or not.
|
||||
@ -587,5 +588,5 @@ func (t *ID3DecisionTree) Load(filePath string) error {
|
||||
|
||||
func (t *ID3DecisionTree) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
|
||||
t.Root = &DecisionTreeNode{}
|
||||
return t.Root.LoadWithPrefix(reader, "")
|
||||
return t.Root.LoadWithPrefix(reader, prefix)
|
||||
}
|
||||
|
@ -2,9 +2,10 @@ package trees
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestId3(t *testing.T) {
|
||||
@ -52,10 +53,15 @@ func TestId3(t *testing.T) {
|
||||
_, err = id3tree.Predict(trainData)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Test save and load model
|
||||
err = id3tree.Save("tmp")
|
||||
So(err, ShouldBeNil)
|
||||
err = id3tree.Load("tmp")
|
||||
So(err, ShouldNotBeNil) // temp
|
||||
|
||||
id3tree = NewID3DecisionTree(0.1)
|
||||
err = id3tree.Load("tmp")
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
_, err = id3tree.Predict(trainData)
|
||||
So(err, ShouldBeNil)
|
||||
})
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user