From e7fee0a2d164678e91dcd761c56440d5b4711c41 Mon Sep 17 00:00:00 2001 From: Richard Townsend Date: Sun, 10 Sep 2017 21:10:54 +0100 Subject: [PATCH] Reformat, fix tests --- base/serialize.go | 23 +++++++++++----- base/serialize_test.go | 2 +- ensemble/randomforest.go | 37 ++++++++++++++++++++++++++ ensemble/randomforest_test.go | 50 +++++++++++++++++++++++++++++++++++ meta/bagging.go | 10 +++---- trees/id3.go | 2 +- 6 files changed, 108 insertions(+), 16 deletions(-) diff --git a/base/serialize.go b/base/serialize.go index 4e08a8d..8dc7aff 100644 --- a/base/serialize.go +++ b/base/serialize.go @@ -45,13 +45,15 @@ func (f *FunctionalTarReader) GetNamedFile(name string) ([]byte, error) { } if hdr.Name == name { - ret := make([]byte, hdr.Size) - n, err := tr.Read(ret) - if int64(n) != hdr.Size { - if int64(n) < hdr.Size { - log.Printf("Size mismatch, expected %d byte(s) for %s, got %d", n, hdr.Name, hdr.Size) + ret, err := ioutil.ReadAll(tr) + if err != nil { + return nil, WrapError(err) + } + if int64(len(ret)) != hdr.Size { + if int64(len(ret)) < hdr.Size { + log.Printf("Size mismatch, got %d byte(s) for %s, expected %d (err was %s)", len(ret), hdr.Name, hdr.Size, err) } else { - return nil, WrapError(fmt.Errorf("Size mismatch, expected %d byte(s) for %s, got %d", n, hdr.Name, hdr.Size)) + return nil, WrapError(fmt.Errorf("Size mismatch, expected %d byte(s) for %s, got %d", len(ret), hdr.Name, hdr.Size)) } } if err != nil { @@ -248,7 +250,7 @@ func (c *ClassifierDeserializer) Close() { // ClassifierSerializer is an object used by SaveableClassifiers. type ClassifierSerializer struct { gzipWriter *gzip.Writer - fileWriter io.WriteCloser + fileWriter *os.File tarWriter *tar.Writer f *os.File filePath string @@ -274,6 +276,10 @@ func (c *ClassifierSerializer) Close() error { return fmt.Errorf("Could not close gz: %s", err) } + if err := c.fileWriter.Sync(); err != nil { + return fmt.Errorf("Could not close file writer: %s", err) + } + if err := c.fileWriter.Close(); err != nil { return fmt.Errorf("Could not close file writer: %s", err) } @@ -310,6 +316,9 @@ func (c *ClassifierSerializer) WriteBytesForKey(key string, b []byte) error { return fmt.Errorf("Could not write data for '%s': %s", key, err) } + c.tarWriter.Flush() + c.gzipWriter.Flush() + c.fileWriter.Sync() return nil } diff --git a/base/serialize_test.go b/base/serialize_test.go index 26eff70..a80d4ac 100644 --- a/base/serialize_test.go +++ b/base/serialize_test.go @@ -23,7 +23,7 @@ func TestSerializeToCSV(t *testing.T) { Convey("What's written out should match what's read in", func() { dinst, err := ParseCSVToInstances(f.Name(), true) So(err, ShouldBeNil) - So(inst.String(), ShouldEqual, dinst.String()) + So(InstancesAreEqual(inst, dinst), ShouldBeTrue) }) }) }) diff --git a/ensemble/randomforest.go b/ensemble/randomforest.go index 55feea0..d60e166 100644 --- a/ensemble/randomforest.go +++ b/ensemble/randomforest.go @@ -60,3 +60,40 @@ func (f *RandomForest) Predict(with base.FixedDataGrid) (base.FixedDataGrid, err func (f *RandomForest) String() string { return fmt.Sprintf("RandomForest(ForestSize: %d, Features:%d, %s\n)", f.ForestSize, f.Features, f.Model) } + +func (f *RandomForest) GetMetadata() base.ClassifierMetadataV1 { + return base.ClassifierMetadataV1{ + FormatVersion: 1, + ClassifierName: "KNN", + ClassifierVersion: "1.0", + ClassifierMetadata: nil, + } +} + +func (f *RandomForest) Save(filePath string) error { + writer, err := base.CreateSerializedClassifierStub(filePath, f.GetMetadata()) + if err != nil { + return err + } + + err = f.SaveWithPrefix(writer, "model") + writer.Close() + return err +} + +func (f *RandomForest) SaveWithPrefix(writer *base.ClassifierSerializer, prefix string) error { + return f.Model.SaveWithPrefix(writer, prefix) +} + +func (f *RandomForest) Load(filePath string) error { + reader, err := base.ReadSerializedClassifierStub(filePath) + if err != nil { + return err + } + return f.LoadWithPrefix(reader, "model") +} + +func (f *RandomForest) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error { + f.Model = new(meta.BaggedModel) + return f.Model.LoadWithPrefix(reader, prefix) +} diff --git a/ensemble/randomforest_test.go b/ensemble/randomforest_test.go index 2789493..2915ca4 100644 --- a/ensemble/randomforest_test.go +++ b/ensemble/randomforest_test.go @@ -7,6 +7,8 @@ import ( "github.com/sjwhitworth/golearn/evaluation" "github.com/sjwhitworth/golearn/filters" . "github.com/smartystreets/goconvey/convey" + "io/ioutil" + "os" ) func TestRandomForest(t *testing.T) { @@ -53,3 +55,51 @@ func TestRandomForest(t *testing.T) { }) }) } + +func TestRandomForestSerialization(t *testing.T) { + Convey("Given a valid CSV file", t, func() { + inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true) + So(err, ShouldBeNil) + + Convey("When Chi-Merge filtering the data", func() { + filt := filters.NewChiMergeFilter(inst, 0.90) + for _, a := range base.NonClassFloatAttributes(inst) { + filt.AddAttribute(a) + } + filt.Train() + instf := base.NewLazilyFilteredInstances(inst, filt) + + Convey("Splitting the data into test and training sets", func() { + trainData, testData := base.InstancesTrainTestSplit(instf, 0.60) + + Convey("Fitting and predicting with a Random Forest", func() { + rf := NewRandomForest(10, 3) + err = rf.Fit(trainData) + So(err, ShouldBeNil) + + oldPredictions, err := rf.Predict(testData) + So(err, ShouldBeNil) + + Convey("Saving the model should work...", func() { + f, err := ioutil.TempFile(os.TempDir(), "rf") + err = rf.Save(f.Name()) + defer func() { + f.Close() + }() + So(err, ShouldBeNil) + Convey("Loading the model should work...", func() { + newRf := NewRandomForest(10, 3) + err := newRf.Load(f.Name()) + So(err, ShouldBeNil) + Convey("Predictions should be the same...", func() { + newPredictions, err := newRf.Predict(testData) + So(err, ShouldBeNil) + So(base.InstancesAreEqual(newPredictions, oldPredictions), ShouldBeTrue) + }) + }) + }) + }) + }) + }) + }) +} diff --git a/meta/bagging.go b/meta/bagging.go index bd16be5..0088c5c 100644 --- a/meta/bagging.go +++ b/meta/bagging.go @@ -253,7 +253,7 @@ func (b *BaggedModel) SaveWithPrefix(writer *base.ClassifierSerializer, prefix s // Save the classifiers for i, c := range b.Models { - clsPrefix := pI(writer.Prefix(prefix, "CLASSIFIERS"), i) + clsPrefix := fmt.Sprintf("%s/", pI( "CLASSIFIERS", i)) err = c.SaveWithPrefix(writer, clsPrefix) if err != nil { return base.FormatError(err, "Can't save classifier %d", i) @@ -303,16 +303,12 @@ func (b *BaggedModel) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix if err != nil { return base.DescribeError("Can't read NUM_RANDOM_FEATURES", err) } - /*classifiersKey := reader.Prefix(prefix, "NUM_CLASSIFIERS") - numClassifiers, err := reader.GetU64ForKey(classifiersKey) - if err != nil { - return base.DescribeError("Can't read NUM_CLASSIFIERS", err) - }*/ + b.RandomFeatures = int(randomFeatures) // Reload the classifiers for i, m := range b.Models { - clsPrefix := pI(reader.Prefix(prefix, "CLASSIFIERS"), i) + clsPrefix := fmt.Sprintf("%s/", pI( "CLASSIFIERS", i)) err := m.LoadWithPrefix(reader, clsPrefix) if err != nil { return base.DescribeError("Can't read classifier", err) diff --git a/trees/id3.go b/trees/id3.go index c859423..97f2b26 100644 --- a/trees/id3.go +++ b/trees/id3.go @@ -657,7 +657,7 @@ func (t *ID3DecisionTree) String() string { func (t *ID3DecisionTree) GetMetadata() base.ClassifierMetadataV1 { return base.ClassifierMetadataV1{ FormatVersion: 1, - ClassifierName: "KNN", + ClassifierName: "ID3", ClassifierVersion: "1.0", ClassifierMetadata: nil, }