1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-25 13:48:49 +08:00

Fix id3 model loading

An error was ocurring everytime an id3 model was loaded
This commit is contained in:
Francis Oliveira 2021-01-08 00:22:00 -03:00
parent 294d65fca3
commit 6489b3bf7c
2 changed files with 12 additions and 6 deletions

View File

@ -49,6 +49,7 @@ func (f *FunctionalTarReader) GetNamedFile(name string) ([]byte, error) {
if err != nil {
return nil, WrapError(err)
}
if int64(len(ret)) != hdr.Size {
if int64(len(ret)) < hdr.Size {
log.Printf("Size mismatch, got %d byte(s) for %s, expected %d (err was %s)", len(ret), hdr.Name, hdr.Size, err)
@ -56,10 +57,9 @@ func (f *FunctionalTarReader) GetNamedFile(name string) ([]byte, error) {
return nil, WrapError(fmt.Errorf("Size mismatch, expected %d byte(s) for %s, got %d", len(ret), hdr.Name, hdr.Size))
}
}
if err != nil {
return nil, err
}
returnCandidate = ret
break
}
}
if returnCandidate == nil {

View File

@ -2,9 +2,10 @@ package trees
import (
"fmt"
"testing"
"github.com/sjwhitworth/golearn/base"
. "github.com/smartystreets/goconvey/convey"
"testing"
)
func TestId3(t *testing.T) {
@ -52,10 +53,15 @@ func TestId3(t *testing.T) {
_, err = id3tree.Predict(trainData)
So(err, ShouldBeNil)
// Test save and load model
err = id3tree.Save("tmp")
So(err, ShouldBeNil)
err = id3tree.Load("tmp")
So(err, ShouldNotBeNil) // temp
id3tree = NewID3DecisionTree(0.1)
err = id3tree.Load("tmp")
So(err, ShouldBeNil)
_, err = id3tree.Predict(trainData)
So(err, ShouldBeNil)
})
}