1
0
mirror of https://github.com/sjwhitworth/golearn.git synced 2025-04-25 13:48:49 +08:00

Reformat, fix tests

This commit is contained in:
Richard Townsend 2017-09-10 21:10:54 +01:00
parent fc110aab48
commit e7fee0a2d1
6 changed files with 108 additions and 16 deletions

View File

@ -45,13 +45,15 @@ func (f *FunctionalTarReader) GetNamedFile(name string) ([]byte, error) {
}
if hdr.Name == name {
ret := make([]byte, hdr.Size)
n, err := tr.Read(ret)
if int64(n) != hdr.Size {
if int64(n) < hdr.Size {
log.Printf("Size mismatch, expected %d byte(s) for %s, got %d", n, hdr.Name, hdr.Size)
ret, err := ioutil.ReadAll(tr)
if err != nil {
return nil, WrapError(err)
}
if int64(len(ret)) != hdr.Size {
if int64(len(ret)) < hdr.Size {
log.Printf("Size mismatch, got %d byte(s) for %s, expected %d (err was %s)", len(ret), hdr.Name, hdr.Size, err)
} else {
return nil, WrapError(fmt.Errorf("Size mismatch, expected %d byte(s) for %s, got %d", n, hdr.Name, hdr.Size))
return nil, WrapError(fmt.Errorf("Size mismatch, expected %d byte(s) for %s, got %d", len(ret), hdr.Name, hdr.Size))
}
}
if err != nil {
@ -248,7 +250,7 @@ func (c *ClassifierDeserializer) Close() {
// ClassifierSerializer is an object used by SaveableClassifiers.
type ClassifierSerializer struct {
gzipWriter *gzip.Writer
fileWriter io.WriteCloser
fileWriter *os.File
tarWriter *tar.Writer
f *os.File
filePath string
@ -274,6 +276,10 @@ func (c *ClassifierSerializer) Close() error {
return fmt.Errorf("Could not close gz: %s", err)
}
if err := c.fileWriter.Sync(); err != nil {
return fmt.Errorf("Could not close file writer: %s", err)
}
if err := c.fileWriter.Close(); err != nil {
return fmt.Errorf("Could not close file writer: %s", err)
}
@ -310,6 +316,9 @@ func (c *ClassifierSerializer) WriteBytesForKey(key string, b []byte) error {
return fmt.Errorf("Could not write data for '%s': %s", key, err)
}
c.tarWriter.Flush()
c.gzipWriter.Flush()
c.fileWriter.Sync()
return nil
}

View File

@ -23,7 +23,7 @@ func TestSerializeToCSV(t *testing.T) {
Convey("What's written out should match what's read in", func() {
dinst, err := ParseCSVToInstances(f.Name(), true)
So(err, ShouldBeNil)
So(inst.String(), ShouldEqual, dinst.String())
So(InstancesAreEqual(inst, dinst), ShouldBeTrue)
})
})
})

View File

@ -60,3 +60,40 @@ func (f *RandomForest) Predict(with base.FixedDataGrid) (base.FixedDataGrid, err
// String returns a human-readable description of the forest, listing its
// ForestSize, its Features setting and the formatted underlying Model.
func (f *RandomForest) String() string {
	const layout = "RandomForest(ForestSize: %d, Features:%d, %s\n)"
	return fmt.Sprintf(layout, f.ForestSize, f.Features, f.Model)
}
// GetMetadata returns the metadata record written alongside a serialized
// RandomForest: format version 1, classifier version "1.0" and no additional
// classifier-specific metadata.
func (f *RandomForest) GetMetadata() base.ClassifierMetadataV1 {
	return base.ClassifierMetadataV1{
		FormatVersion: 1,
		// This was "KNN", copied over from another classifier; a saved
		// forest should identify itself as a random forest.
		// NOTE(review): confirm nothing matches on the old name when
		// reading previously saved files.
		ClassifierName:     "RandomForest",
		ClassifierVersion:  "1.0",
		ClassifierMetadata: nil,
	}
}
// Save serializes the forest to the file at filePath.
//
// It creates a serialized-classifier stub carrying this classifier's
// metadata, writes the model under the "model" prefix and then closes the
// writer. The first error encountered is returned.
func (f *RandomForest) Save(filePath string) error {
	writer, err := base.CreateSerializedClassifierStub(filePath, f.GetMetadata())
	if err != nil {
		return err
	}
	if err := f.SaveWithPrefix(writer, "model"); err != nil {
		// Still release the writer, but report the write failure: it is
		// the root cause and any close error would only mask it.
		_ = writer.Close()
		return err
	}
	// Closing finalizes the archive on disk, so a failure here means the
	// saved file cannot be trusted and must be surfaced to the caller.
	return writer.Close()
}
// SaveWithPrefix writes the forest's underlying model to writer, storing its
// entries under prefix. It returns an error if the forest has no model yet,
// or whatever error the model's own SaveWithPrefix reports.
func (f *RandomForest) SaveWithPrefix(writer *base.ClassifierSerializer, prefix string) error {
	// Without a model there is nothing to serialize; delegating would
	// dereference a nil model instead of reporting a usable error.
	if f.Model == nil {
		return fmt.Errorf("random forest has no model to save (not fitted or loaded)")
	}
	return f.Model.SaveWithPrefix(writer, prefix)
}
// Load restores the forest from the serialized classifier file at filePath.
//
// It opens the file through base.ReadSerializedClassifierStub, reads the
// model stored under the "model" prefix and closes the reader before
// returning. The first error encountered is returned.
func (f *RandomForest) Load(filePath string) error {
	reader, err := base.ReadSerializedClassifierStub(filePath)
	if err != nil {
		return err
	}
	// Release the deserializer's resources on every return path.
	defer reader.Close()
	return f.LoadWithPrefix(reader, "model")
}
// LoadWithPrefix replaces the forest's model with a freshly allocated
// meta.BaggedModel and populates it from reader using the entries stored
// under prefix. The model is replaced before loading starts, so it is
// overwritten even when the load returns an error.
func (f *RandomForest) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
	bagged := new(meta.BaggedModel)
	f.Model = bagged
	err := bagged.LoadWithPrefix(reader, prefix)
	return err
}

View File

@ -7,6 +7,8 @@ import (
"github.com/sjwhitworth/golearn/evaluation"
"github.com/sjwhitworth/golearn/filters"
. "github.com/smartystreets/goconvey/convey"
"io/ioutil"
"os"
)
func TestRandomForest(t *testing.T) {
@ -53,3 +55,51 @@ func TestRandomForest(t *testing.T) {
})
})
}
// TestRandomForestSerialization checks that a fitted RandomForest survives a
// Save/Load round trip: the reloaded forest must produce the same predictions
// on the held-out data as the forest that was saved.
func TestRandomForestSerialization(t *testing.T) {
	Convey("Given a valid CSV file", t, func() {
		inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldBeNil)

		Convey("When Chi-Merge filtering the data", func() {
			filt := filters.NewChiMergeFilter(inst, 0.90)
			for _, a := range base.NonClassFloatAttributes(inst) {
				filt.AddAttribute(a)
			}
			filt.Train()
			instf := base.NewLazilyFilteredInstances(inst, filt)

			Convey("Splitting the data into test and training sets", func() {
				trainData, testData := base.InstancesTrainTestSplit(instf, 0.60)

				Convey("Fitting and predicting with a Random Forest", func() {
					rf := NewRandomForest(10, 3)
					err = rf.Fit(trainData)
					So(err, ShouldBeNil)

					oldPredictions, err := rf.Predict(testData)
					So(err, ShouldBeNil)

					Convey("Saving the model should work...", func() {
						f, err := ioutil.TempFile(os.TempDir(), "rf")
						// Check the temp-file error before it is overwritten
						// by the result of Save.
						So(err, ShouldBeNil)
						defer func() {
							f.Close()
							// TempFile leaves removal to the caller; clean up
							// so repeated runs do not litter the temp dir.
							os.Remove(f.Name())
						}()

						err = rf.Save(f.Name())
						So(err, ShouldBeNil)

						Convey("Loading the model should work...", func() {
							newRf := NewRandomForest(10, 3)
							err := newRf.Load(f.Name())
							So(err, ShouldBeNil)

							Convey("Predictions should be the same...", func() {
								newPredictions, err := newRf.Predict(testData)
								So(err, ShouldBeNil)
								So(base.InstancesAreEqual(newPredictions, oldPredictions), ShouldBeTrue)
							})
						})
					})
				})
			})
		})
	})
}

View File

@ -253,7 +253,7 @@ func (b *BaggedModel) SaveWithPrefix(writer *base.ClassifierSerializer, prefix s
// Save the classifiers
for i, c := range b.Models {
clsPrefix := pI(writer.Prefix(prefix, "CLASSIFIERS"), i)
clsPrefix := fmt.Sprintf("%s/", pI( "CLASSIFIERS", i))
err = c.SaveWithPrefix(writer, clsPrefix)
if err != nil {
return base.FormatError(err, "Can't save classifier %d", i)
@ -303,16 +303,12 @@ func (b *BaggedModel) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix
if err != nil {
return base.DescribeError("Can't read NUM_RANDOM_FEATURES", err)
}
/*classifiersKey := reader.Prefix(prefix, "NUM_CLASSIFIERS")
numClassifiers, err := reader.GetU64ForKey(classifiersKey)
if err != nil {
return base.DescribeError("Can't read NUM_CLASSIFIERS", err)
}*/
b.RandomFeatures = int(randomFeatures)
// Reload the classifiers
for i, m := range b.Models {
clsPrefix := pI(reader.Prefix(prefix, "CLASSIFIERS"), i)
clsPrefix := fmt.Sprintf("%s/", pI( "CLASSIFIERS", i))
err := m.LoadWithPrefix(reader, clsPrefix)
if err != nil {
return base.DescribeError("Can't read classifier", err)

View File

@ -657,7 +657,7 @@ func (t *ID3DecisionTree) String() string {
func (t *ID3DecisionTree) GetMetadata() base.ClassifierMetadataV1 {
return base.ClassifierMetadataV1{
FormatVersion: 1,
ClassifierName: "KNN",
ClassifierName: "ID3",
ClassifierVersion: "1.0",
ClassifierMetadata: nil,
}