1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-25 13:48:49 +08:00

Fix random forest model loading

This commit is contained in:
Francis Oliveira 2021-01-10 00:56:45 -03:00
parent 6489b3bf7c
commit d33eb47a05
3 changed files with 13 additions and 4 deletions

View File

@ -3,6 +3,7 @@ package ensemble
import (
"errors"
"fmt"
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/meta"
"github.com/sjwhitworth/golearn/trees"
@ -95,5 +96,10 @@ func (f *RandomForest) Load(filePath string) error {
func (f *RandomForest) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
f.Model = new(meta.BaggedModel)
for i := 0; i < f.ForestSize; i++ {
tree := trees.NewID3DecisionTree(0.00)
f.Model.AddModel(tree)
}
return f.Model.LoadWithPrefix(reader, prefix)
}

View File

@ -3,12 +3,13 @@ package ensemble
import (
"testing"
"io/ioutil"
"os"
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/evaluation"
"github.com/sjwhitworth/golearn/filters"
. "github.com/smartystreets/goconvey/convey"
"io/ioutil"
"os"
)
func TestRandomForest(t *testing.T) {
@ -92,6 +93,7 @@ func TestRandomForestSerialization(t *testing.T) {
newRf := NewRandomForest(10, 3)
err := newRf.Load(f.Name())
So(err, ShouldBeNil)
So(len(newRf.Model.Models), ShouldEqual, 10)
Convey("Predictions should be the same...", func() {
newPredictions, err := newRf.Predict(testData)
So(err, ShouldBeNil)

View File

@ -4,9 +4,10 @@ import (
"bytes"
"encoding/json"
"fmt"
"sort"
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/evaluation"
"sort"
)
// NodeType determines whether a DecisionTreeNode is a leaf or not.
@ -587,5 +588,5 @@ func (t *ID3DecisionTree) Load(filePath string) error {
func (t *ID3DecisionTree) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
t.Root = &DecisionTreeNode{}
return t.Root.LoadWithPrefix(reader, "")
return t.Root.LoadWithPrefix(reader, prefix)
}