mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-25 13:48:49 +08:00
Fix random forest model loading
This commit is contained in:
parent
6489b3bf7c
commit
d33eb47a05
@ -3,6 +3,7 @@ package ensemble
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"github.com/sjwhitworth/golearn/meta"
|
||||
"github.com/sjwhitworth/golearn/trees"
|
||||
@ -95,5 +96,10 @@ func (f *RandomForest) Load(filePath string) error {
|
||||
|
||||
func (f *RandomForest) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
|
||||
f.Model = new(meta.BaggedModel)
|
||||
for i := 0; i < f.ForestSize; i++ {
|
||||
tree := trees.NewID3DecisionTree(0.00)
|
||||
f.Model.AddModel(tree)
|
||||
}
|
||||
|
||||
return f.Model.LoadWithPrefix(reader, prefix)
|
||||
}
|
||||
|
@ -3,12 +3,13 @@ package ensemble
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"io/ioutil"
|
||||
"os"
|
||||
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"github.com/sjwhitworth/golearn/evaluation"
|
||||
"github.com/sjwhitworth/golearn/filters"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
)
|
||||
|
||||
func TestRandomForest(t *testing.T) {
|
||||
@ -92,6 +93,7 @@ func TestRandomForestSerialization(t *testing.T) {
|
||||
newRf := NewRandomForest(10, 3)
|
||||
err := newRf.Load(f.Name())
|
||||
So(err, ShouldBeNil)
|
||||
So(len(newRf.Model.Models), ShouldEqual, 10)
|
||||
Convey("Predictions should be the same...", func() {
|
||||
newPredictions, err := newRf.Predict(testData)
|
||||
So(err, ShouldBeNil)
|
||||
|
@ -4,9 +4,10 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/sjwhitworth/golearn/base"
|
||||
"github.com/sjwhitworth/golearn/evaluation"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// NodeType determines whether a DecisionTreeNode is a leaf or not.
|
||||
@ -587,5 +588,5 @@ func (t *ID3DecisionTree) Load(filePath string) error {
|
||||
|
||||
func (t *ID3DecisionTree) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
|
||||
t.Root = &DecisionTreeNode{}
|
||||
return t.Root.LoadWithPrefix(reader, "")
|
||||
return t.Root.LoadWithPrefix(reader, prefix)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user