mirror of
https://github.com/sjwhitworth/golearn.git
synced 2025-04-28 13:48:56 +08:00
Fix random forest model loading
This commit is contained in:
parent
6489b3bf7c
commit
d33eb47a05
@ -3,6 +3,7 @@ package ensemble
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
"github.com/sjwhitworth/golearn/meta"
|
"github.com/sjwhitworth/golearn/meta"
|
||||||
"github.com/sjwhitworth/golearn/trees"
|
"github.com/sjwhitworth/golearn/trees"
|
||||||
@ -95,5 +96,10 @@ func (f *RandomForest) Load(filePath string) error {
|
|||||||
|
|
||||||
func (f *RandomForest) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
|
func (f *RandomForest) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
|
||||||
f.Model = new(meta.BaggedModel)
|
f.Model = new(meta.BaggedModel)
|
||||||
|
for i := 0; i < f.ForestSize; i++ {
|
||||||
|
tree := trees.NewID3DecisionTree(0.00)
|
||||||
|
f.Model.AddModel(tree)
|
||||||
|
}
|
||||||
|
|
||||||
return f.Model.LoadWithPrefix(reader, prefix)
|
return f.Model.LoadWithPrefix(reader, prefix)
|
||||||
}
|
}
|
||||||
|
@ -3,12 +3,13 @@ package ensemble
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
|
||||||
"github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
"github.com/sjwhitworth/golearn/evaluation"
|
"github.com/sjwhitworth/golearn/evaluation"
|
||||||
"github.com/sjwhitworth/golearn/filters"
|
"github.com/sjwhitworth/golearn/filters"
|
||||||
. "github.com/smartystreets/goconvey/convey"
|
. "github.com/smartystreets/goconvey/convey"
|
||||||
"io/ioutil"
|
|
||||||
"os"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRandomForest(t *testing.T) {
|
func TestRandomForest(t *testing.T) {
|
||||||
@ -92,6 +93,7 @@ func TestRandomForestSerialization(t *testing.T) {
|
|||||||
newRf := NewRandomForest(10, 3)
|
newRf := NewRandomForest(10, 3)
|
||||||
err := newRf.Load(f.Name())
|
err := newRf.Load(f.Name())
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
|
So(len(newRf.Model.Models), ShouldEqual, 10)
|
||||||
Convey("Predictions should be the same...", func() {
|
Convey("Predictions should be the same...", func() {
|
||||||
newPredictions, err := newRf.Predict(testData)
|
newPredictions, err := newRf.Predict(testData)
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
|
@ -4,9 +4,10 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
|
||||||
"github.com/sjwhitworth/golearn/base"
|
"github.com/sjwhitworth/golearn/base"
|
||||||
"github.com/sjwhitworth/golearn/evaluation"
|
"github.com/sjwhitworth/golearn/evaluation"
|
||||||
"sort"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// NodeType determines whether a DecisionTreeNode is a leaf or not.
|
// NodeType determines whether a DecisionTreeNode is a leaf or not.
|
||||||
@ -587,5 +588,5 @@ func (t *ID3DecisionTree) Load(filePath string) error {
|
|||||||
|
|
||||||
func (t *ID3DecisionTree) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
|
func (t *ID3DecisionTree) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
|
||||||
t.Root = &DecisionTreeNode{}
|
t.Root = &DecisionTreeNode{}
|
||||||
return t.Root.LoadWithPrefix(reader, "")
|
return t.Root.LoadWithPrefix(reader, prefix)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user